diff --git a/rq/worker.py b/rq/worker.py index a269620..8f13723 100644 --- a/rq/worker.py +++ b/rq/worker.py @@ -12,7 +12,7 @@ import sys import time import traceback -from rq.compat import as_text, text_type +from rq.compat import as_text, string_types, text_type from .connections import get_current_connection from .exceptions import DequeueTimeout, NoQueueError @@ -20,7 +20,7 @@ from .job import Job, Status from .logutils import setup_loghandlers from .queue import get_failed_queue, Queue from .timeouts import UnixSignalDeathPenalty -from .utils import make_colorizer, utcformat, utcnow +from .utils import import_attribute, make_colorizer, utcformat, utcnow from .version import VERSION try: @@ -109,7 +109,7 @@ class Worker(object): def __init__(self, queues, name=None, default_result_ttl=None, connection=None, - exc_handler=None, default_worker_ttl=None): # noqa + exc_handler=None, default_worker_ttl=None, job_class=None): # noqa if connection is None: connection = get_current_connection() self.connection = connection @@ -141,6 +141,11 @@ class Worker(object): if exc_handler is not None: self.push_exc_handler(exc_handler) + if job_class is not None: + if isinstance(job_class, string_types): + job_class = import_attribute(job_class) + self.job_class = job_class + def validate_queues(self): """Sanity check for the given queues.""" if not iterable(self.queues): diff --git a/tests/test_worker.py b/tests/test_worker.py index 27e85bb..5609841 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -14,6 +14,10 @@ from tests.fixtures import (create_file, create_file_after_timeout, div_by_zero, from tests.helpers import strip_microseconds +class CustomJob(Job): + pass + + class TestWorker(RQTestCase): def test_create_worker(self): """Worker creation.""" @@ -269,3 +273,9 @@ class TestWorker(RQTestCase): as_text(self.testconn.hget(worker.key, 'current_job')) ) self.assertEqual(worker.get_current_job(), job) + + def test_custom_job_class(self): + """Ensure Worker accepts custom job class.""" + q = Queue() + worker = Worker([q], job_class=CustomJob) + self.assertEqual(worker.job_class, CustomJob)