From e1cbc3736c55d7a72be121a170d1b8c811619dc2 Mon Sep 17 00:00:00 2001 From: Babatunde Olusola Date: Thu, 16 Apr 2020 13:53:48 +0100 Subject: [PATCH] Implement Customizable Serializer Support (#1219) * Implement Customizable Serializer Support * Refractor serializer instance methods * Update tests with other serializers * Edit function description * Edit function description * Raise appropriate exception * Update tests for better code coverage * Remove un-used imports and un-necessary code * Refractor resolve_serializer * Remove un-necessary alias from imports * Add documentation * Refractor tests, improve documentation --- .travis.yml | 3 +- docs/docs/jobs.md | 14 ++++++++ docs/docs/workers.md | 24 +++++++++++++ rq/exceptions.py | 6 ---- rq/job.py | 74 +++++++++++++-------------------------- rq/queue.py | 19 +++++----- rq/serializers.py | 25 +++++++++++++ rq/worker.py | 16 +++++---- tests/fixtures.py | 7 ++++ tests/test_job.py | 34 ++++++++++-------- tests/test_queue.py | 9 +++++ tests/test_serializers.py | 37 ++++++++++++++++++++ tests/test_worker.py | 9 +++++ 13 files changed, 190 insertions(+), 87 deletions(-) create mode 100644 rq/serializers.py create mode 100644 tests/test_serializers.py diff --git a/.travis.yml b/.travis.yml index 168e736..1c4fe39 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,13 +3,12 @@ services: - redis matrix: include: - - python: "2.7" - python: "3.4" - python: "3.5" - python: "3.6" - python: "3.7" - python: "3.8" - - python: "pypy" + - python: "pypy3" install: - pip install -e . - pip install pytest-cov sentry-sdk codecov diff --git a/docs/docs/jobs.md b/docs/docs/jobs.md index 16c9860..b91ab21 100644 --- a/docs/docs/jobs.md +++ b/docs/docs/jobs.md @@ -82,6 +82,20 @@ job = Job.create(count_words_at_url, }) ``` +## Job / Queue Creation with Custom Serializer + +When creating a job or queue, you can pass in a custom serializer that will be used for serializing / de-serializing job arguments. +Serializers used should have at least `loads` and `dumps` method. +The default serializer used is `pickle` + +```python +import json +from rq import Job, Queue + +job = Job(connection=connection, serializer=json) +queue = Queue(connection=connection, serializer=json) +``` + ## Retrieving a Job from Redis All job information is stored in Redis. You can inspect a job and its attributes diff --git a/docs/docs/workers.md b/docs/docs/workers.md index dc4fb6e..25d0095 100644 --- a/docs/docs/workers.md +++ b/docs/docs/workers.md @@ -202,6 +202,30 @@ queue = Queue('queue_name', connection=redis) workers = Worker.all(queue=queue) ``` +## Worker with Custom Serializer + +When creating a worker, you can pass in a custom serializer that will be implicitly passed to the queue. +Serializers used should have at least `loads` and `dumps` method. +The default serializer used is `pickle` + +```python +import json +from rq import Worker + +job = Worker('foo', serializer=json) +``` + +or when creating from a queue + +```python +import json +from rq import Queue, Worker + +w = Worker(Queue('foo'), serializer=json) +``` + +Queues will now use custom serializer + ### Worker Statistics diff --git a/rq/exceptions.py b/rq/exceptions.py index 684bfb0..ba97f7c 100644 --- a/rq/exceptions.py +++ b/rq/exceptions.py @@ -19,12 +19,6 @@ class InvalidJobOperation(Exception): pass -class UnpickleError(Exception): - def __init__(self, message, raw_data, inner_exception=None): - super(UnpickleError, self).__init__(message, inner_exception) - self.raw_data = raw_data - - class DequeueTimeout(Exception): pass diff --git a/rq/job.py b/rq/job.py index 69c18e1..15e706b 100644 --- a/rq/job.py +++ b/rq/job.py @@ -5,27 +5,17 @@ from __future__ import (absolute_import, division, print_function, import inspect import warnings import zlib -from functools import partial from uuid import uuid4 from rq.compat import as_text, decode_redis_hash, string_types, text_type from .connections import resolve_connection -from .exceptions import InvalidJobDependency, NoSuchJobError, UnpickleError +from .exceptions import NoSuchJobError from .local import LocalStack from .utils import (enum, import_attribute, parse_timeout, str_to_date, utcformat, utcnow) +from .serializers import resolve_serializer -try: - import cPickle as pickle -except ImportError: # noqa # pragma: no cover - import pickle - - -# Serialize pickle dumps using the highest pickle protocol (binary, default -# uses ascii) -dumps = partial(pickle.dumps, protocol=pickle.HIGHEST_PROTOCOL) -loads = pickle.loads JobStatus = enum( 'JobStatus', @@ -42,21 +32,6 @@ JobStatus = enum( UNEVALUATED = object() -def unpickle(pickled_string): - """Unpickles a string, but raises a unified UnpickleError in case anything - fails. - - This is a helper method to not have to deal with the fact that `loads()` - potentially raises many types of exceptions (e.g. AttributeError, - IndexError, TypeError, KeyError, etc.) - """ - try: - obj = loads(pickled_string) - except Exception as e: - raise UnpickleError('Could not unpickle', pickled_string, e) - return obj - - def cancel_job(job_id, connection=None): """Cancels the job with the given job ID, preventing execution. Discards any job info (i.e. it can't be requeued later). @@ -89,7 +64,7 @@ class Job(object): def create(cls, func, args=None, kwargs=None, connection=None, result_ttl=None, ttl=None, status=None, description=None, depends_on=None, timeout=None, id=None, origin=None, meta=None, - failure_ttl=None): + failure_ttl=None, serializer=None): """Creates a new Job instance for the given function, arguments, and keyword arguments. """ @@ -103,7 +78,7 @@ class Job(object): if not isinstance(kwargs, dict): raise TypeError('{0!r} is not a valid kwargs dict'.format(kwargs)) - job = cls(connection=connection) + job = cls(connection=connection, serializer=serializer) if id is not None: job.set_id(id) @@ -214,8 +189,8 @@ class Job(object): return import_attribute(self.func_name) - def _unpickle_data(self): - self._func_name, self._instance, self._args, self._kwargs = unpickle(self.data) + def _deserialize_data(self): + self._func_name, self._instance, self._args, self._kwargs = self.serializer.loads(self.data) @property def data(self): @@ -233,7 +208,7 @@ class Job(object): self._kwargs = {} job_tuple = self._func_name, self._instance, self._args, self._kwargs - self._data = dumps(job_tuple) + self._data = self.serializer.dumps(job_tuple) return self._data @data.setter @@ -247,7 +222,7 @@ class Job(object): @property def func_name(self): if self._func_name is UNEVALUATED: - self._unpickle_data() + self._deserialize_data() return self._func_name @func_name.setter @@ -258,7 +233,7 @@ class Job(object): @property def instance(self): if self._instance is UNEVALUATED: - self._unpickle_data() + self._deserialize_data() return self._instance @instance.setter @@ -269,7 +244,7 @@ class Job(object): @property def args(self): if self._args is UNEVALUATED: - self._unpickle_data() + self._deserialize_data() return self._args @args.setter @@ -280,7 +255,7 @@ class Job(object): @property def kwargs(self): if self._kwargs is UNEVALUATED: - self._unpickle_data() + self._deserialize_data() return self._kwargs @kwargs.setter @@ -295,11 +270,11 @@ class Job(object): return conn.exists(cls.key_for(job_id)) @classmethod - def fetch(cls, id, connection=None): + def fetch(cls, id, connection=None, serializer=None): """Fetches a persisted job from its corresponding Redis key and instantiates it. """ - job = cls(id, connection=connection) + job = cls(id, connection=connection, serializer=serializer) job.refresh() return job @@ -327,7 +302,7 @@ class Job(object): return jobs - def __init__(self, id=None, connection=None): + def __init__(self, id=None, connection=None, serializer=None): self.connection = resolve_connection(connection) self._id = id self.created_at = utcnow() @@ -350,6 +325,7 @@ class Job(object): self._status = None self._dependency_ids = [] self.meta = {} + self.serializer = resolve_serializer(serializer) def __repr__(self): # noqa # pragma: no cover return '{0}({1!r}, enqueued_at={2!r})'.format(self.__class__.__name__, @@ -451,7 +427,7 @@ class Job(object): rv = self.connection.hget(self.key, 'result') if rv is not None: # cache the result - self._result = loads(rv) + self._result = self.serializer.loads(rv) return self._result """Backwards-compatibility accessor property `return_value`.""" @@ -480,9 +456,9 @@ class Job(object): result = obj.get('result') if result: try: - self._result = unpickle(obj.get('result')) - except UnpickleError: - self._result = 'Unpickleable return value' + self._result = self.serializer.loads(obj.get('result')) + except Exception as e: + self._result = "Unserializable return value" self.timeout = parse_timeout(obj.get('timeout')) if obj.get('timeout') else None self.result_ttl = int(obj.get('result_ttl')) if obj.get('result_ttl') else None # noqa self.failure_ttl = int(obj.get('failure_ttl')) if obj.get('failure_ttl') else None # noqa @@ -492,7 +468,7 @@ class Job(object): self._dependency_ids = [as_text(dependency_id)] if dependency_id else [] self.ttl = int(obj.get('ttl')) if obj.get('ttl') else None - self.meta = unpickle(obj.get('meta')) if obj.get('meta') else {} + self.meta = self.serializer.loads(obj.get('meta')) if obj.get('meta') else {} raw_exc_info = obj.get('exc_info') if raw_exc_info: @@ -536,9 +512,9 @@ class Job(object): obj['ended_at'] = utcformat(self.ended_at) if self.ended_at else '' if self._result is not None: try: - obj['result'] = dumps(self._result) - except: - obj['result'] = 'Unpickleable return value' + obj['result'] = self.serializer.dumps(self._result) + except Exception as e: + obj['result'] = "Unserializable return value" if self.exc_info is not None: obj['exc_info'] = zlib.compress(str(self.exc_info).encode('utf-8')) if self.timeout is not None: @@ -552,7 +528,7 @@ class Job(object): if self._dependency_ids: obj['dependency_id'] = self._dependency_ids[0] if self.meta and include_meta: - obj['meta'] = dumps(self.meta) + obj['meta'] = self.serializer.dumps(self.meta) if self.ttl: obj['ttl'] = self.ttl @@ -575,7 +551,7 @@ class Job(object): def save_meta(self): """Stores job meta from the job instance to the corresponding Redis key.""" - meta = dumps(self.meta) + meta = self.serializer.dumps(self.meta) self.connection.hset(self.key, 'meta', meta) def cancel(self, pipeline=None): diff --git a/rq/queue.py b/rq/queue.py index cff48fb..4316653 100644 --- a/rq/queue.py +++ b/rq/queue.py @@ -12,9 +12,10 @@ from redis import WatchError from .compat import as_text, string_types, total_ordering, utc from .connections import resolve_connection from .defaults import DEFAULT_RESULT_TTL -from .exceptions import DequeueTimeout, NoSuchJobError, UnpickleError +from .exceptions import DequeueTimeout, NoSuchJobError from .job import Job, JobStatus from .utils import backend_class, import_attribute, parse_timeout, utcnow +from .serializers import resolve_serializer def compact(lst): @@ -29,7 +30,7 @@ class Queue(object): redis_queues_keys = 'rq:queues' @classmethod - def all(cls, connection=None, job_class=None): + def all(cls, connection=None, job_class=None, serializer=None): """Returns an iterable of all Queues. """ connection = resolve_connection(connection) @@ -37,13 +38,13 @@ class Queue(object): def to_queue(queue_key): return cls.from_queue_key(as_text(queue_key), connection=connection, - job_class=job_class) + job_class=job_class, serializer=serializer) return [to_queue(rq_key) for rq_key in connection.smembers(cls.redis_queues_keys) if rq_key] @classmethod - def from_queue_key(cls, queue_key, connection=None, job_class=None): + def from_queue_key(cls, queue_key, connection=None, job_class=None, serializer=None): """Returns a Queue instance, based on the naming conventions for naming the internal Redis keys. Can be used to reverse-lookup Queues by their Redis keys. @@ -52,10 +53,10 @@ class Queue(object): if not queue_key.startswith(prefix): raise ValueError('Not a valid RQ queue key: {0}'.format(queue_key)) name = queue_key[len(prefix):] - return cls(name, connection=connection, job_class=job_class) + return cls(name, connection=connection, job_class=job_class, serializer=serializer) def __init__(self, name='default', default_timeout=None, connection=None, - is_async=True, job_class=None, **kwargs): + is_async=True, job_class=None, serializer=None, **kwargs): self.connection = resolve_connection(connection) prefix = self.redis_queue_namespace_prefix self.name = name @@ -73,6 +74,8 @@ class Queue(object): job_class = import_attribute(job_class) self.job_class = job_class + self.serializer = resolve_serializer(serializer) + def __len__(self): return self.count @@ -269,7 +272,7 @@ class Queue(object): result_ttl=result_ttl, ttl=ttl, failure_ttl=failure_ttl, status=status, description=description, depends_on=depends_on, timeout=timeout, id=job_id, - origin=self.name, meta=meta + origin=self.name, meta=meta, serializer=self.serializer ) return job @@ -552,7 +555,7 @@ class Queue(object): # Silently pass on jobs that don't exist (anymore), # and continue in the look continue - except UnpickleError as e: + except Exception as e: # Attach queue information on the exception for improved error # reporting e.job_id = job_id diff --git a/rq/serializers.py b/rq/serializers.py new file mode 100644 index 0000000..c4b0e54 --- /dev/null +++ b/rq/serializers.py @@ -0,0 +1,25 @@ +import pickle + +from .compat import string_types +from .utils import import_attribute + + +def resolve_serializer(serializer): + """This function checks the user defined serializer for ('dumps', 'loads') methods + It returns a default pickle serializer if not found else it returns a MySerializer + The returned serializer objects implement ('dumps', 'loads') methods + Also accepts a string path to serializer that will be loaded as the serializer + """ + if not serializer: + return pickle + + if isinstance(serializer, string_types): + serializer = import_attribute(serializer) + + default_serializer_methods = ('dumps', 'loads') + + for instance_method in default_serializer_methods: + if not hasattr(serializer, instance_method): + raise NotImplementedError('Serializer should have (dumps, loads) methods.') + + return serializer diff --git a/rq/worker.py b/rq/worker.py index 874640d..50d1733 100644 --- a/rq/worker.py +++ b/rq/worker.py @@ -41,6 +41,7 @@ from .utils import (backend_class, ensure_list, enum, make_colorizer, utcformat, utcnow, utcparse) from .version import VERSION from .worker_registration import clean_worker_registry, get_keys +from .serializers import resolve_serializer try: from setproctitle import setproctitle as setprocname @@ -104,7 +105,7 @@ class Worker(object): log_job_description = True @classmethod - def all(cls, connection=None, job_class=None, queue_class=None, queue=None): + def all(cls, connection=None, job_class=None, queue_class=None, queue=None, serializer=None): """Returns an iterable of all Workers. """ if queue: @@ -116,7 +117,7 @@ class Worker(object): workers = [cls.find_by_key(as_text(key), connection=connection, job_class=job_class, - queue_class=queue_class) + queue_class=queue_class, serializer=serializer) for key in worker_keys] return compact(workers) @@ -132,7 +133,7 @@ class Worker(object): @classmethod def find_by_key(cls, worker_key, connection=None, job_class=None, - queue_class=None): + queue_class=None, serializer=None): """Returns a Worker instance, based on the naming conventions for naming the internal Redis keys. Can be used to reverse-lookup Workers by their Redis keys. @@ -149,7 +150,7 @@ class Worker(object): name = worker_key[len(prefix):] worker = cls([], name, connection=connection, job_class=job_class, - queue_class=queue_class, prepare_for_work=False) + queue_class=queue_class, prepare_for_work=False, serializer=serializer) worker.refresh() @@ -161,7 +162,7 @@ class Worker(object): queue_class=None, log_job_description=True, job_monitoring_interval=DEFAULT_JOB_MONITORING_INTERVAL, disable_default_exception_handler=False, - prepare_for_work=True): # noqa + prepare_for_work=True, serializer=None): # noqa if connection is None: connection = get_current_connection() self.connection = connection @@ -177,10 +178,11 @@ class Worker(object): self.queue_class = backend_class(self, 'queue_class', override=queue_class) self.version = VERSION self.python_version = sys.version + self.serializer = resolve_serializer(serializer) queues = [self.queue_class(name=q, connection=connection, - job_class=self.job_class) + job_class=self.job_class, serializer=self.serializer) if isinstance(q, string_types) else q for q in ensure_list(queues)] @@ -644,7 +646,7 @@ class Worker(object): if queues: self.queues = [self.queue_class(queue, connection=self.connection, - job_class=self.job_class) + job_class=self.job_class, serializer=self.serializer) for queue in queues.split(',')] def increment_failed_job_count(self, pipeline=None): diff --git a/tests/fixtures.py b/tests/fixtures.py index 46cdaac..a1f507a 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -175,3 +175,10 @@ def kill_worker(pid, double_kill, interval=0.5): # give the worker time to switch signal handler time.sleep(interval) os.kill(pid, signal.SIGTERM) + + +class Serializer(object): + def loads(self): pass + + def dumps(self): pass + diff --git a/tests/test_job.py b/tests/test_job.py index a7a1b27..25c37e2 100644 --- a/tests/test_job.py +++ b/tests/test_job.py @@ -2,15 +2,16 @@ from __future__ import (absolute_import, division, print_function, unicode_literals) -import sys +import json import time +import queue import zlib from datetime import datetime from redis import WatchError from rq.compat import PY2, as_text -from rq.exceptions import NoSuchJobError, UnpickleError +from rq.exceptions import NoSuchJobError from rq.job import Job, JobStatus, cancel_job, get_current_job from rq.queue import Queue from rq.registry import (DeferredJobRegistry, FailedJobRegistry, @@ -20,16 +21,7 @@ from rq.utils import utcformat from rq.worker import Worker from tests import RQTestCase, fixtures -is_py2 = sys.version[0] == '2' -if is_py2: - import Queue as queue -else: - import queue as queue - -try: - from cPickle import loads, dumps -except ImportError: - from pickle import loads, dumps +from pickle import loads, dumps class TestJob(RQTestCase): @@ -117,6 +109,17 @@ class TestJob(RQTestCase): self.assertEqual(job.instance, n) self.assertEqual(job.args, (4,)) + def test_create_job_with_serializer(self): + """Creation of jobs with serializer for instance methods.""" + # Test using json serializer + n = fixtures.Number(2) + job = Job.create(func=n.div, args=(4,), serializer=json) + + self.assertIsNotNone(job.serializer) + self.assertEqual(job.func, n.div) + self.assertEqual(job.instance, n) + self.assertEqual(job.args, (4,)) + def test_create_job_from_string_function(self): """Creation of jobs using string specifier.""" job = Job.create(func='tests.fixtures.say_hello', args=('World',)) @@ -273,7 +276,7 @@ class TestJob(RQTestCase): job.refresh() for attr in ('func_name', 'instance', 'args', 'kwargs'): - with self.assertRaises(UnpickleError): + with self.assertRaises(Exception): getattr(job, attr) def test_job_is_unimportable(self): @@ -371,13 +374,14 @@ class TestJob(RQTestCase): job = Job.create(func=fixtures.say_hello, args=('Lionel',)) job._result = queue.Queue() job.save() + self.assertEqual( self.testconn.hget(job.key, 'result').decode('utf-8'), - 'Unpickleable return value' + 'Unserializable return value' ) job = Job.fetch(job.id) - self.assertEqual(job.result, 'Unpickleable return value') + self.assertEqual(job.result, 'Unserializable return value') def test_result_ttl_is_persisted(self): """Ensure that job's result_ttl is set properly""" diff --git a/tests/test_queue.py b/tests/test_queue.py index ffe08c5..7004aad 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -2,6 +2,7 @@ from __future__ import (absolute_import, division, print_function, unicode_literals) +import json from datetime import datetime, timedelta from rq import Queue @@ -29,6 +30,14 @@ class TestQueue(RQTestCase): self.assertEqual(q.name, 'my-queue') self.assertEqual(str(q), '') + def test_create_queue_with_serializer(self): + """Creating queues with serializer.""" + # Test using json serializer + q = Queue('queue-with-serializer', serializer=json) + self.assertEqual(q.name, 'queue-with-serializer') + self.assertEqual(str(q), '') + self.assertIsNotNone(q.serializer) + def test_create_default_queue(self): """Instantiating the default queue.""" q = Queue() diff --git a/tests/test_serializers.py b/tests/test_serializers.py new file mode 100644 index 0000000..58d093f --- /dev/null +++ b/tests/test_serializers.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +from __future__ import (absolute_import, division, print_function, + unicode_literals) + +import json +import pickle +import queue +import unittest + +from rq.serializers import resolve_serializer + + +class TestSerializers(unittest.TestCase): + def test_resolve_serializer(self): + """Ensure function resolve_serializer works correctly""" + serializer = resolve_serializer(None) + self.assertIsNotNone(serializer) + self.assertEqual(serializer, pickle) + + # Test using json serializer + serializer = resolve_serializer(json) + self.assertIsNotNone(serializer) + + self.assertTrue(hasattr(serializer, 'dumps')) + self.assertTrue(hasattr(serializer, 'loads')) + + # Test raise NotImplmentedError + with self.assertRaises(NotImplementedError): + resolve_serializer(object) + + # Test raise Exception + with self.assertRaises(Exception): + resolve_serializer(queue.Queue()) + + # Test using path.to.serializer string + serializer = resolve_serializer('tests.fixtures.Serializer') + self.assertIsNotNone(serializer) diff --git a/tests/test_worker.py b/tests/test_worker.py index 5788d25..088213d 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -2,6 +2,7 @@ from __future__ import (absolute_import, division, print_function, unicode_literals) +import json import os import shutil import signal @@ -97,6 +98,14 @@ class TestWorker(RQTestCase): self.assertEqual(w.queues[0].name, 'foo') self.assertEqual(w.queues[1].name, 'bar') + # With string and serializer + w = Worker('foo', serializer=json) + self.assertEqual(w.queues[0].name, 'foo') + + # With queue having serializer + w = Worker(Queue('foo'), serializer=json) + self.assertEqual(w.queues[0].name, 'foo') + def test_work_and_quit(self): """Worker processes work, then quits.""" fooq, barq = Queue('foo'), Queue('bar')