diff --git a/rq/utils.py b/rq/utils.py index 1d8ab5b..573c4aa 100644 --- a/rq/utils.py +++ b/rq/utils.py @@ -261,3 +261,16 @@ def get_version(connection): except ResponseError: # fakeredis doesn't implement Redis' INFO command version_string = "5.0.9" return StrictVersion('.'.join(version_string.split('.')[:3])) + + +def ceildiv(a, b): + """Ceiling division. Returns the ceiling of the quotient of a division operation""" + return -(-a // b) + + +def split_list(a_list, segment_size): + """ + Splits a list into multiple smaller lists having size `segment_size` + """ + for i in range(0, len(a_list), segment_size): + yield a_list[i:i + segment_size] diff --git a/rq/worker_registration.py b/rq/worker_registration.py index 3944bc7..6f24bcc 100644 --- a/rq/worker_registration.py +++ b/rq/worker_registration.py @@ -1,8 +1,10 @@ from .compat import as_text +from rq.utils import split_list WORKERS_BY_QUEUE_KEY = 'rq:workers:%s' REDIS_WORKER_KEYS = 'rq:workers' +MAX_KEYS = 1000 def register(worker, pipeline=None): @@ -62,6 +64,7 @@ def clean_worker_registry(queue): invalid_keys.append(keys[i]) if invalid_keys: - pipeline.srem(WORKERS_BY_QUEUE_KEY % queue.name, *invalid_keys) - pipeline.srem(REDIS_WORKER_KEYS, *invalid_keys) - pipeline.execute() + for invalid_subset in split_list(invalid_keys, MAX_KEYS): + pipeline.srem(WORKERS_BY_QUEUE_KEY % queue.name, *invalid_subset) + pipeline.srem(REDIS_WORKER_KEYS, *invalid_subset) + pipeline.execute() diff --git a/tests/test_utils.py b/tests/test_utils.py index f76be3c..69db3bd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -10,7 +10,8 @@ from distutils.version import StrictVersion from redis import Redis from tests import RQTestCase, fixtures -from rq.utils import backend_class, ensure_list, first, get_version, is_nonstring_iterable, parse_timeout, utcparse +from rq.utils import backend_class, ensure_list, first, get_version, is_nonstring_iterable, parse_timeout, utcparse, \ + split_list, ceildiv from rq.exceptions import TimeoutFormatError @@ -88,4 +89,27 @@ class TestUtils(RQTestCase): class DummyRedis(Redis): def info(*args): return {'redis_version': '3.0.7.9'} - self.assertEqual(get_version(DummyRedis()), StrictVersion('3.0.7')) \ No newline at end of file + self.assertEqual(get_version(DummyRedis()), StrictVersion('3.0.7')) + + def test_ceildiv_even(self): + """When a number is evenly divisible by another ceildiv returns the quotient""" + dividend = 12 + divisor = 4 + self.assertEqual(ceildiv(dividend, divisor), dividend // divisor) + + def test_ceildiv_uneven(self): + """When a number is not evenly divisible by another ceildiv returns the quotient plus one""" + dividend = 13 + divisor = 4 + self.assertEqual(ceildiv(dividend, divisor), dividend // divisor + 1) + + def test_split_list(self): + """Ensure split_list works properly""" + BIG_LIST_SIZE = 42 + SEGMENT_SIZE = 5 + + big_list = ['1'] * BIG_LIST_SIZE + small_lists = list(split_list(big_list, SEGMENT_SIZE)) + + expected_small_list_count = ceildiv(BIG_LIST_SIZE, SEGMENT_SIZE) + self.assertEqual(len(small_lists), expected_small_list_count) diff --git a/tests/test_worker_registration.py b/tests/test_worker_registration.py index 177dc7b..2450d64 100644 --- a/tests/test_worker_registration.py +++ b/tests/test_worker_registration.py @@ -1,4 +1,6 @@ +from rq.utils import ceildiv from tests import RQTestCase +from mock.mock import patch from rq import Queue, Worker from rq.worker_registration import (clean_worker_registry, get_keys, register, @@ -87,3 +89,30 @@ class TestWorkerRegistry(RQTestCase): clean_worker_registry(queue) self.assertFalse(redis.sismember(worker.redis_workers_keys, worker.key)) self.assertFalse(redis.sismember(REDIS_WORKER_KEYS, worker.key)) + + def test_clean_large_registry(self): + """ + clean_registry() splits invalid_keys into multiple lists for set removal to avoid sending more than redis can + receive + """ + MAX_WORKERS = 41 + MAX_KEYS = 37 + # srem is called twice per invalid key batch: once for WORKERS_BY_QUEUE_KEY; once for REDIS_WORKER_KEYS + SREM_CALL_COUNT = 2 + + queue = Queue(name='foo') + for i in range(MAX_WORKERS): + worker = Worker([queue]) + register(worker) + + with patch('rq.worker_registration.MAX_KEYS', MAX_KEYS), \ + patch.object(queue.connection, 'pipeline', wraps=queue.connection.pipeline) as pipeline_mock: + # clean_worker_registry creates a pipeline with a context manager. Configure the mock using the context + # manager entry method __enter__ + pipeline_mock.return_value.__enter__.return_value.srem.return_value = None + pipeline_mock.return_value.__enter__.return_value.execute.return_value = [0] * MAX_WORKERS + + clean_worker_registry(queue) + + expected_call_count = (ceildiv(MAX_WORKERS, MAX_KEYS)) * SREM_CALL_COUNT + self.assertEqual(pipeline_mock.return_value.__enter__.return_value.srem.call_count, expected_call_count)