Worker.__init__ should accept custom job class.

main
Selwin Ong 11 years ago
parent 141278bb42
commit 7ac1c3500a

@ -12,7 +12,7 @@ import sys
import time import time
import traceback import traceback
from rq.compat import as_text, text_type from rq.compat import as_text, string_types, text_type
from .connections import get_current_connection from .connections import get_current_connection
from .exceptions import DequeueTimeout, NoQueueError from .exceptions import DequeueTimeout, NoQueueError
@ -20,7 +20,7 @@ from .job import Job, Status
from .logutils import setup_loghandlers from .logutils import setup_loghandlers
from .queue import get_failed_queue, Queue from .queue import get_failed_queue, Queue
from .timeouts import UnixSignalDeathPenalty from .timeouts import UnixSignalDeathPenalty
from .utils import make_colorizer, utcformat, utcnow from .utils import import_attribute, make_colorizer, utcformat, utcnow
from .version import VERSION from .version import VERSION
try: try:
@ -109,7 +109,7 @@ class Worker(object):
def __init__(self, queues, name=None, def __init__(self, queues, name=None,
default_result_ttl=None, connection=None, default_result_ttl=None, connection=None,
exc_handler=None, default_worker_ttl=None): # noqa exc_handler=None, default_worker_ttl=None, job_class=None): # noqa
if connection is None: if connection is None:
connection = get_current_connection() connection = get_current_connection()
self.connection = connection self.connection = connection
@ -141,6 +141,11 @@ class Worker(object):
if exc_handler is not None: if exc_handler is not None:
self.push_exc_handler(exc_handler) self.push_exc_handler(exc_handler)
if job_class is not None:
if isinstance(job_class, string_types):
job_class = import_attribute(job_class)
self.job_class = job_class
def validate_queues(self): def validate_queues(self):
"""Sanity check for the given queues.""" """Sanity check for the given queues."""
if not iterable(self.queues): if not iterable(self.queues):

@ -14,6 +14,10 @@ from tests.fixtures import (create_file, create_file_after_timeout, div_by_zero,
from tests.helpers import strip_microseconds from tests.helpers import strip_microseconds
class CustomJob(Job):
pass
class TestWorker(RQTestCase): class TestWorker(RQTestCase):
def test_create_worker(self): def test_create_worker(self):
"""Worker creation.""" """Worker creation."""
@ -269,3 +273,9 @@ class TestWorker(RQTestCase):
as_text(self.testconn.hget(worker.key, 'current_job')) as_text(self.testconn.hget(worker.key, 'current_job'))
) )
self.assertEqual(worker.get_current_job(), job) self.assertEqual(worker.get_current_job(), job)
def test_custom_job_class(self):
"""Ensure Worker accepts custom job class."""
q = Queue()
worker = Worker([q], job_class=CustomJob)
self.assertEqual(worker.job_class, CustomJob)

Loading…
Cancel
Save