Merge pull request #357 from selwin/job-class

Make it easier to use custom job class
main
Vincent Driessen 11 years ago
commit 69adec5bc7

@ -6,7 +6,7 @@ import uuid
from .connections import resolve_connection from .connections import resolve_connection
from .job import Job, Status from .job import Job, Status
from .utils import utcnow from .utils import import_attribute, utcnow
from .exceptions import (DequeueTimeout, InvalidJobOperationError, from .exceptions import (DequeueTimeout, InvalidJobOperationError,
NoSuchJobError, UnpickleError) NoSuchJobError, UnpickleError)
@ -55,7 +55,7 @@ class Queue(object):
return cls(name, connection=connection) return cls(name, connection=connection)
def __init__(self, name='default', default_timeout=None, connection=None, def __init__(self, name='default', default_timeout=None, connection=None,
async=True): async=True, job_class=None):
self.connection = resolve_connection(connection) self.connection = resolve_connection(connection)
prefix = self.redis_queue_namespace_prefix prefix = self.redis_queue_namespace_prefix
self.name = name self.name = name
@ -63,6 +63,11 @@ class Queue(object):
self._default_timeout = default_timeout self._default_timeout = default_timeout
self._async = async self._async = async
if job_class is not None:
if isinstance(job_class, string_types):
job_class = import_attribute(job_class)
self.job_class = job_class
@property @property
def key(self): def key(self):
"""Returns the Redis key for this Queue.""" """Returns the Redis key for this Queue."""

@ -27,6 +27,7 @@ def parse_args():
parser.add_argument('--burst', '-b', action='store_true', default=False, help='Run in burst mode (quit after all work is done)') # noqa parser.add_argument('--burst', '-b', action='store_true', default=False, help='Run in burst mode (quit after all work is done)') # noqa
parser.add_argument('--name', '-n', default=None, help='Specify a different name') parser.add_argument('--name', '-n', default=None, help='Specify a different name')
parser.add_argument('--worker-class', '-w', action='store', default='rq.Worker', help='RQ Worker class to use') parser.add_argument('--worker-class', '-w', action='store', default='rq.Worker', help='RQ Worker class to use')
parser.add_argument('--job-class', '-j', action='store', default='rq.job.Job', help='RQ Job class to use')
parser.add_argument('--path', '-P', default='.', help='Specify the import path.') parser.add_argument('--path', '-P', default='.', help='Specify the import path.')
parser.add_argument('--results-ttl', default=None, help='Default results timeout to be used') parser.add_argument('--results-ttl', default=None, help='Default results timeout to be used')
parser.add_argument('--worker-ttl', default=None, help='Default worker timeout to be used') parser.add_argument('--worker-ttl', default=None, help='Default worker timeout to be used')
@ -88,7 +89,8 @@ def main():
w = worker_class(queues, w = worker_class(queues,
name=args.name, name=args.name,
default_worker_ttl=args.worker_ttl, default_worker_ttl=args.worker_ttl,
default_result_ttl=args.results_ttl) default_result_ttl=args.results_ttl,
job_class=args.job_class)
# Should we configure Sentry? # Should we configure Sentry?
if args.sentry_dsn: if args.sentry_dsn:

@ -12,7 +12,7 @@ import sys
import time import time
import traceback import traceback
from rq.compat import as_text, text_type from rq.compat import as_text, string_types, text_type
from .connections import get_current_connection from .connections import get_current_connection
from .exceptions import DequeueTimeout, NoQueueError from .exceptions import DequeueTimeout, NoQueueError
@ -20,7 +20,7 @@ from .job import Job, Status
from .logutils import setup_loghandlers from .logutils import setup_loghandlers
from .queue import get_failed_queue, Queue from .queue import get_failed_queue, Queue
from .timeouts import UnixSignalDeathPenalty from .timeouts import UnixSignalDeathPenalty
from .utils import make_colorizer, utcformat, utcnow from .utils import import_attribute, make_colorizer, utcformat, utcnow
from .version import VERSION from .version import VERSION
try: try:
@ -109,7 +109,7 @@ class Worker(object):
def __init__(self, queues, name=None, def __init__(self, queues, name=None,
default_result_ttl=None, connection=None, default_result_ttl=None, connection=None,
exc_handler=None, default_worker_ttl=None): # noqa exc_handler=None, default_worker_ttl=None, job_class=None): # noqa
if connection is None: if connection is None:
connection = get_current_connection() connection = get_current_connection()
self.connection = connection self.connection = connection
@ -141,6 +141,11 @@ class Worker(object):
if exc_handler is not None: if exc_handler is not None:
self.push_exc_handler(exc_handler) self.push_exc_handler(exc_handler)
if job_class is not None:
if isinstance(job_class, string_types):
job_class = import_attribute(job_class)
self.job_class = job_class
def validate_queues(self): def validate_queues(self):
"""Sanity check for the given queues.""" """Sanity check for the given queues."""
if not iterable(self.queues): if not iterable(self.queues):

@ -12,6 +12,10 @@ from tests.fixtures import (div_by_zero, echo, Number, say_hello,
some_calculation) some_calculation)
class CustomJob(Job):
pass
class TestQueue(RQTestCase): class TestQueue(RQTestCase):
def test_create_queue(self): def test_create_queue(self):
"""Creating queues.""" """Creating queues."""
@ -437,3 +441,8 @@ class TestFailedQueue(RQTestCase):
q = Queue(async=False) q = Queue(async=False)
job = q.enqueue(some_calculation, args=(2, 3)) job = q.enqueue(some_calculation, args=(2, 3))
self.assertEqual(job.return_value, 6) self.assertEqual(job.return_value, 6)
def test_custom_job_class(self):
"""Ensure custom job class assignment works as expected."""
q = Queue(job_class=CustomJob)
self.assertEqual(q.job_class, CustomJob)

@ -14,6 +14,10 @@ from tests.fixtures import (create_file, create_file_after_timeout, div_by_zero,
from tests.helpers import strip_microseconds from tests.helpers import strip_microseconds
class CustomJob(Job):
pass
class TestWorker(RQTestCase): class TestWorker(RQTestCase):
def test_create_worker(self): def test_create_worker(self):
"""Worker creation.""" """Worker creation."""
@ -269,3 +273,9 @@ class TestWorker(RQTestCase):
as_text(self.testconn.hget(worker.key, 'current_job')) as_text(self.testconn.hget(worker.key, 'current_job'))
) )
self.assertEqual(worker.get_current_job(), job) self.assertEqual(worker.get_current_job(), job)
def test_custom_job_class(self):
"""Ensure Worker accepts custom job class."""
q = Queue()
worker = Worker([q], job_class=CustomJob)
self.assertEqual(worker.job_class, CustomJob)

Loading…
Cancel
Save