diff --git a/rq/queue.py b/rq/queue.py index 4296001..baac6ca 100644 --- a/rq/queue.py +++ b/rq/queue.py @@ -6,7 +6,7 @@ import uuid from .connections import resolve_connection from .job import Job, Status -from .utils import utcnow +from .utils import import_attribute, utcnow from .exceptions import (DequeueTimeout, InvalidJobOperationError, NoSuchJobError, UnpickleError) @@ -55,7 +55,7 @@ class Queue(object): return cls(name, connection=connection) def __init__(self, name='default', default_timeout=None, connection=None, - async=True): + async=True, job_class=None): self.connection = resolve_connection(connection) prefix = self.redis_queue_namespace_prefix self.name = name @@ -63,6 +63,11 @@ class Queue(object): self._default_timeout = default_timeout self._async = async + if job_class is not None: + if isinstance(job_class, string_types): + job_class = import_attribute(job_class) + self.job_class = job_class + @property def key(self): """Returns the Redis key for this Queue.""" diff --git a/rq/scripts/rqworker.py b/rq/scripts/rqworker.py index 84f68f7..01850dd 100644 --- a/rq/scripts/rqworker.py +++ b/rq/scripts/rqworker.py @@ -27,6 +27,7 @@ def parse_args(): parser.add_argument('--burst', '-b', action='store_true', default=False, help='Run in burst mode (quit after all work is done)') # noqa parser.add_argument('--name', '-n', default=None, help='Specify a different name') parser.add_argument('--worker-class', '-w', action='store', default='rq.Worker', help='RQ Worker class to use') + parser.add_argument('--job-class', '-j', action='store', default='rq.job.Job', help='RQ Job class to use') parser.add_argument('--path', '-P', default='.', help='Specify the import path.') parser.add_argument('--results-ttl', default=None, help='Default results timeout to be used') parser.add_argument('--worker-ttl', default=None, help='Default worker timeout to be used') @@ -88,7 +89,8 @@ def main(): w = worker_class(queues, name=args.name, default_worker_ttl=args.worker_ttl, - default_result_ttl=args.results_ttl) + default_result_ttl=args.results_ttl, + job_class=args.job_class) # Should we configure Sentry? if args.sentry_dsn: diff --git a/rq/worker.py b/rq/worker.py index a269620..8f13723 100644 --- a/rq/worker.py +++ b/rq/worker.py @@ -12,7 +12,7 @@ import sys import time import traceback -from rq.compat import as_text, text_type +from rq.compat import as_text, string_types, text_type from .connections import get_current_connection from .exceptions import DequeueTimeout, NoQueueError @@ -20,7 +20,7 @@ from .job import Job, Status from .logutils import setup_loghandlers from .queue import get_failed_queue, Queue from .timeouts import UnixSignalDeathPenalty -from .utils import make_colorizer, utcformat, utcnow +from .utils import import_attribute, make_colorizer, utcformat, utcnow from .version import VERSION try: @@ -109,7 +109,7 @@ class Worker(object): def __init__(self, queues, name=None, default_result_ttl=None, connection=None, - exc_handler=None, default_worker_ttl=None): # noqa + exc_handler=None, default_worker_ttl=None, job_class=None): # noqa if connection is None: connection = get_current_connection() self.connection = connection @@ -141,6 +141,11 @@ class Worker(object): if exc_handler is not None: self.push_exc_handler(exc_handler) + if job_class is not None: + if isinstance(job_class, string_types): + job_class = import_attribute(job_class) + self.job_class = job_class + def validate_queues(self): """Sanity check for the given queues.""" if not iterable(self.queues): diff --git a/tests/test_queue.py b/tests/test_queue.py index 1d467a6..d368590 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -12,6 +12,10 @@ from tests.fixtures import (div_by_zero, echo, Number, say_hello, some_calculation) +class CustomJob(Job): + pass + + class TestQueue(RQTestCase): def test_create_queue(self): """Creating queues.""" @@ -437,3 +441,8 @@ class TestFailedQueue(RQTestCase): q = Queue(async=False) job = q.enqueue(some_calculation, args=(2, 3)) self.assertEqual(job.return_value, 6) + + def test_custom_job_class(self): + """Ensure custom job class assignment works as expected.""" + q = Queue(job_class=CustomJob) + self.assertEqual(q.job_class, CustomJob) diff --git a/tests/test_worker.py b/tests/test_worker.py index 27e85bb..5609841 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -14,6 +14,10 @@ from tests.fixtures import (create_file, create_file_after_timeout, div_by_zero, from tests.helpers import strip_microseconds +class CustomJob(Job): + pass + + class TestWorker(RQTestCase): def test_create_worker(self): """Worker creation.""" @@ -269,3 +273,9 @@ class TestWorker(RQTestCase): as_text(self.testconn.hget(worker.key, 'current_job')) ) self.assertEqual(worker.get_current_job(), job) + + def test_custom_job_class(self): + """Ensure Worker accepts custom job class.""" + q = Queue() + worker = Worker([q], job_class=CustomJob) + self.assertEqual(worker.job_class, CustomJob)