diff --git a/rq/worker.py b/rq/worker.py index db68ee2..286403e 100644 --- a/rq/worker.py +++ b/rq/worker.py @@ -21,6 +21,7 @@ except ImportError: from redis import WatchError +from . import worker_registration from .compat import PY2, as_text, string_types, text_type from .connections import get_current_connection, push_connection, pop_connection from .defaults import DEFAULT_RESULT_TTL, DEFAULT_WORKER_TTL @@ -34,6 +35,7 @@ from .timeouts import UnixSignalDeathPenalty from .utils import (backend_class, ensure_list, enum, make_colorizer, utcformat, utcnow, utcparse) from .version import VERSION +from .worker_registration import get_keys try: from procname import setprocname @@ -90,7 +92,7 @@ WorkerStatus = enum( class Worker(object): redis_worker_namespace_prefix = 'rq:worker:' - redis_workers_keys = 'rq:workers' + redis_workers_keys = worker_registration.REDIS_WORKER_KEYS death_penalty_class = UnixSignalDeathPenalty queue_class = Queue job_class = Job @@ -99,19 +101,32 @@ class Worker(object): log_result_lifespan = True @classmethod - def all(cls, connection=None, job_class=None, queue_class=None): + def all(cls, connection=None, job_class=None, queue_class=None, queue=None): """Returns an iterable of all Workers. """ - if connection is None: + if queue: + connection = queue.connection + elif connection is None: connection = get_current_connection() - reported_working = connection.smembers(cls.redis_workers_keys) + + worker_keys = get_keys(queue=queue, connection=connection) workers = [cls.find_by_key(as_text(key), connection=connection, job_class=job_class, queue_class=queue_class) - for key in reported_working] + for key in worker_keys] return compact(workers) + @classmethod + def all_keys(cls, connection=None, queue=None): + return [as_text(key) + for key in get_keys(queue=queue, connection=connection)] + + @classmethod + def count(cls, connection=None, queue=None): + """Returns the number of workers by queue or connection""" + return len(get_keys(queue=queue, connection=connection)) + @classmethod def find_by_key(cls, worker_key, connection=None, job_class=None, queue_class=None): @@ -121,7 +136,7 @@ class Worker(object): """ prefix = cls.redis_worker_namespace_prefix if not worker_key.startswith(prefix): - raise ValueError('Not a valid RQ worker key: {0}'.format(worker_key)) + raise ValueError('Not a valid RQ worker key: %s' % worker_key) if connection is None: connection = get_current_connection() @@ -188,7 +203,7 @@ class Worker(object): if exc_handler is not None: self.push_exc_handler(exc_handler) warnings.warn( - "use of exc_handler is deprecated, pass a list to exception_handlers instead.", + "exc_handler is deprecated, pass a list to exception_handlers instead.", DeprecationWarning ) elif isinstance(exception_handlers, list): @@ -271,7 +286,7 @@ class Worker(object): p.hset(key, 'birth', now_in_string) p.hset(key, 'last_heartbeat', now_in_string) p.hset(key, 'queues', queues) - p.sadd(self.redis_workers_keys, key) + worker_registration.register(self, p) p.expire(key, self.default_worker_ttl) p.execute() @@ -281,7 +296,7 @@ class Worker(object): with self.connection._pipeline() as p: # We cannot use self.state = 'dead' here, because that would # rollback the pipeline - p.srem(self.redis_workers_keys, self.key) + worker_registration.unregister(self, p) p.hset(self.key, 'death', utcformat(utcnow())) p.expire(self.key, 60) p.execute() diff --git a/rq/worker_registration.py b/rq/worker_registration.py new file mode 100644 index 0000000..73cb0ef --- /dev/null +++ b/rq/worker_registration.py @@ -0,0 +1,45 @@ +from .compat import as_text + + +WORKERS_BY_QUEUE_KEY = 'rq:workers:%s' +REDIS_WORKER_KEYS = 'rq:workers' + + +def register(worker, pipeline=None): + """Store worker key in Redis so we can easily discover active workers.""" + connection = pipeline if pipeline is not None else worker.connection + connection.sadd(worker.redis_workers_keys, worker.key) + for name in worker.queue_names(): + redis_key = WORKERS_BY_QUEUE_KEY % name + connection.sadd(redis_key, worker.key) + + +def unregister(worker, pipeline=None): + """Remove worker key from Redis.""" + if pipeline is None: + connection = worker.connection._pipeline() + else: + connection = pipeline + + connection.srem(worker.redis_workers_keys, worker.key) + for name in worker.queue_names(): + redis_key = WORKERS_BY_QUEUE_KEY % name + connection.srem(redis_key, worker.key) + + if pipeline is None: + connection.execute() + + +def get_keys(queue=None, connection=None): + """Returnes a list of worker keys for a queue""" + if queue is None and connection is None: + raise ValueError('"queue" or "connection" argument is required') + + if queue: + redis = queue.connection + redis_key = WORKERS_BY_QUEUE_KEY % queue.name + else: + redis = connection + redis_key = REDIS_WORKER_KEYS + + return {as_text(key) for key in redis.smembers(redis_key)} diff --git a/tests/test_worker.py b/tests/test_worker.py index 9b8bb81..5367023 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -108,6 +108,26 @@ class TestWorker(RQTestCase): 'Expected at least some work done.' ) + def test_worker_all(self): + """Worker.all() works properly""" + foo_queue = Queue('foo') + bar_queue = Queue('bar') + + w1 = Worker([foo_queue, bar_queue], name='w1') + w1.register_birth() + w2 = Worker([foo_queue], name='w2') + w2.register_birth() + + self.assertEqual( + set(Worker.all(connection=foo_queue.connection)), + set([w1, w2]) + ) + self.assertEqual(set(Worker.all(queue=foo_queue)), set([w1, w2])) + self.assertEqual(set(Worker.all(queue=bar_queue)), set([w1])) + + w1.register_death() + w2.register_death() + def test_find_by_key(self): """Worker.find_by_key restores queues, state and job_id.""" queues = [Queue('foo'), Queue('bar')] @@ -119,7 +139,12 @@ class TestWorker(RQTestCase): self.assertEqual(worker.queues, queues) self.assertEqual(worker.get_state(), WorkerStatus.STARTED) self.assertEqual(worker._job_id, None) - w.register_death() + self.assertTrue(worker.key in Worker.all_keys(worker.connection)) + + # If worker is gone, its keys should also be removed + worker.connection.delete(worker.key) + Worker.find_by_key(worker.key) + self.assertFalse(worker.key in Worker.all_keys(worker.connection)) def test_worker_ttl(self): """Worker ttl.""" @@ -183,7 +208,7 @@ class TestWorker(RQTestCase): # importable from the worker process. job = Job.create(func=div_by_zero, args=(3,)) job.save() - + job_data = job.data invalid_data = job_data.replace(b'div_by_zero', b'nonexisting') assert job_data != invalid_data diff --git a/tests/test_worker_registration.py b/tests/test_worker_registration.py new file mode 100644 index 0000000..1fe6612 --- /dev/null +++ b/tests/test_worker_registration.py @@ -0,0 +1,70 @@ +from tests import RQTestCase + +from rq import Queue, Worker +from rq.worker_registration import (get_keys, register, unregister, + WORKERS_BY_QUEUE_KEY) + + +class TestWorkerRegistry(RQTestCase): + + def test_worker_registration(self): + """Ensure worker.key is correctly set in Redis.""" + foo_queue = Queue(name='foo') + bar_queue = Queue(name='bar') + worker = Worker([foo_queue, bar_queue]) + + register(worker) + redis = worker.connection + + self.assertTrue(redis.sismember(worker.redis_workers_keys, worker.key)) + self.assertTrue( + redis.sismember(WORKERS_BY_QUEUE_KEY % foo_queue.name, worker.key) + ) + self.assertTrue( + redis.sismember(WORKERS_BY_QUEUE_KEY % bar_queue.name, worker.key) + ) + + unregister(worker) + self.assertFalse(redis.sismember(worker.redis_workers_keys, worker.key)) + self.assertFalse( + redis.sismember(WORKERS_BY_QUEUE_KEY % foo_queue.name, worker.key) + ) + self.assertFalse( + redis.sismember(WORKERS_BY_QUEUE_KEY % bar_queue.name, worker.key) + ) + + def test_get_keys_by_queue(self): + """get_keys_by_queue only returns active workers for that queue""" + foo_queue = Queue(name='foo') + bar_queue = Queue(name='bar') + baz_queue = Queue(name='baz') + + worker1 = Worker([foo_queue, bar_queue]) + worker2 = Worker([foo_queue]) + worker3 = Worker([baz_queue]) + + self.assertEqual(set(), get_keys(foo_queue)) + + register(worker1) + register(worker2) + register(worker3) + + # get_keys(queue) will return worker keys for that queue + self.assertEqual( + set([worker1.key, worker2.key]), + get_keys(foo_queue) + ) + self.assertEqual(set([worker1.key]), get_keys(bar_queue)) + + # get_keys(connection=connection) will return all worker keys + self.assertEqual( + set([worker1.key, worker2.key, worker3.key]), + get_keys(connection=worker1.connection) + ) + + # Calling get_keys without arguments raises an exception + self.assertRaises(ValueError, get_keys) + + unregister(worker1) + unregister(worker2) + unregister(worker3)