From 05ed85804ca381cd3db165e29d48a9dbb1b6b695 Mon Sep 17 00:00:00 2001 From: Antoine Leclair Date: Mon, 15 Feb 2016 22:42:24 -0500 Subject: [PATCH] Worker accepts custom queue class --- rq/cli/cli.py | 1 + rq/worker.py | 18 ++++++++++++------ tests/test_worker.py | 26 ++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/rq/cli/cli.py b/rq/cli/cli.py index 6b52bd7..bb3c5a6 100755 --- a/rq/cli/cli.py +++ b/rq/cli/cli.py @@ -191,6 +191,7 @@ def worker(url, config, burst, name, worker_class, job_class, queue_class, path, default_worker_ttl=worker_ttl, default_result_ttl=results_ttl, job_class=job_class, + queue_class=queue_class, exception_handlers=exception_handlers or None) # Should we configure Sentry? diff --git a/rq/worker.py b/rq/worker.py index af77e4e..ed78165 100644 --- a/rq/worker.py +++ b/rq/worker.py @@ -123,11 +123,22 @@ class Worker(object): def __init__(self, queues, name=None, default_result_ttl=None, connection=None, exc_handler=None, - exception_handlers=None, default_worker_ttl=None, job_class=None): # noqa + exception_handlers=None, default_worker_ttl=None, + job_class=None, queue_class=None): # noqa if connection is None: connection = get_current_connection() self.connection = connection + if job_class is not None: + if isinstance(job_class, string_types): + job_class = import_attribute(job_class) + self.job_class = job_class + + if queue_class is not None: + if isinstance(queue_class, string_types): + queue_class = import_attribute(queue_class) + self.queue_class = queue_class + queues = [self.queue_class(name=q) if isinstance(q, string_types) else q for q in ensure_list(queues)] self._name = name @@ -167,11 +178,6 @@ class Worker(object): elif exception_handlers is not None: self.push_exc_handler(exception_handlers) - if job_class is not None: - if isinstance(job_class, string_types): - job_class = import_attribute(job_class) - self.job_class = job_class - def validate_queues(self): """Sanity check for the given queues.""" for queue in self.queues: diff --git a/tests/test_worker.py b/tests/test_worker.py index 5921970..458a30a 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -26,6 +26,10 @@ class CustomJob(Job): pass +class CustomQueue(Queue): + pass + + class TestWorker(RQTestCase): def test_create_worker(self): """Worker creation using various inputs.""" @@ -347,6 +351,28 @@ class TestWorker(RQTestCase): worker = Worker([q], job_class=CustomJob) self.assertEqual(worker.job_class, CustomJob) + def test_custom_queue_class(self): + """Ensure Worker accepts custom queue class.""" + q = CustomQueue() + worker = Worker([q], queue_class=CustomQueue) + self.assertEqual(worker.queue_class, CustomQueue) + + def test_custom_queue_class_by_string(self): + """Ensure Worker accepts custom queue class using dotted notation.""" + q = CustomQueue() + worker = Worker([q], queue_class='test_worker.CustomQueue') + self.assertEqual(worker.queue_class, CustomQueue) + + def test_custom_queue_class_is_not_global(self): + """Ensure Worker custom queue class is not global.""" + q = CustomQueue() + worker_custom = Worker([q], queue_class=CustomQueue) + q_generic = Queue() + worker_generic = Worker([q_generic]) + self.assertEqual(worker_custom.queue_class, CustomQueue) + self.assertEqual(worker_generic.queue_class, Queue) + self.assertEqual(Worker.queue_class, Queue) + def test_work_via_simpleworker(self): """Worker processes work, with forking disabled, then returns."""