diff --git a/rq/worker.py b/rq/worker.py index c52c21c..c364b18 100644 --- a/rq/worker.py +++ b/rq/worker.py @@ -129,9 +129,10 @@ class Worker(object): connection=connection, job_class=job_class, queue_class=queue_class) - queues = as_text(connection.hget(worker.key, 'queues')) - worker._state = as_text(connection.hget(worker.key, 'state') or '?') - worker._job_id = connection.hget(worker.key, 'current_job') or None + queues, state, job_id = connection.hmget(worker.key, 'queues', 'state', 'current_job') + queues = as_text(queues) + worker._state = as_text(state or '?') + worker._job_id = job_id or None if queues: worker.queues = [worker.queue_class(queue, connection=connection, @@ -139,9 +140,8 @@ class Worker(object): for queue in queues.split(',')] return worker - def __init__(self, queues, name=None, - default_result_ttl=None, connection=None, exc_handler=None, - exception_handlers=None, default_worker_ttl=None, + def __init__(self, queues, name=None, default_result_ttl=None, connection=None, + exc_handler=None, exception_handlers=None, default_worker_ttl=None, job_class=None, queue_class=None): # noqa if connection is None: connection = get_current_connection() diff --git a/tests/test_worker.py b/tests/test_worker.py index 02c6aa6..19d44a8 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -13,7 +13,6 @@ import subprocess import sys import pytest - import mock from mock import Mock @@ -31,7 +30,7 @@ from rq.job import Job, JobStatus from rq.registry import StartedJobRegistry from rq.suspension import resume, suspend from rq.utils import utcnow -from rq.worker import HerokuWorker +from rq.worker import HerokuWorker, WorkerStatus class CustomJob(Job): @@ -43,6 +42,7 @@ class CustomQueue(Queue): class TestWorker(RQTestCase): + def test_create_worker(self): """Worker creation using various inputs.""" @@ -105,10 +105,23 @@ class TestWorker(RQTestCase): 'Expected at least some work done.' ) + def test_find_by_key(self): + """Worker.find_by_key restores queues, state and job_id.""" + queues = [Queue('foo'), Queue('bar')] + w = Worker(queues) + w.register_death() + w.register_birth() + w.set_state(WorkerStatus.STARTED) + worker = Worker.find_by_key(w.key) + self.assertEqual(worker.queues, queues) + self.assertEqual(worker.get_state(), WorkerStatus.STARTED) + self.assertEqual(worker._job_id, None) + w.register_death() + def test_worker_ttl(self): """Worker ttl.""" w = Worker([]) - w.register_birth() # ugly: our test should only call public APIs + w.register_birth() [worker_key] = self.testconn.smembers(Worker.redis_workers_keys) self.assertIsNotNone(self.testconn.ttl(worker_key)) w.register_death() @@ -908,7 +921,7 @@ class HerokuWorkerShutdownTestCase(TimeoutTestCase, RQTestCase): w = HerokuWorker('foo') w._horse_pid = 19999 - w.handle_warm_shutdown_request() + w.handle_warm_shutdown_request() class TestExceptionHandlerMessageEncoding(RQTestCase):