diff --git a/rq/registry.py b/rq/registry.py index 0242280..59becfe 100644 --- a/rq/registry.py +++ b/rq/registry.py @@ -4,21 +4,17 @@ from .queue import FailedQueue from .utils import current_timestamp -class StartedJobRegistry: +class BaseRegistry(object): """ - Registry of currently executing jobs. Each queue maintains a StartedJobRegistry. - StartedJobRegistry contains job keys that are currently being executed. - Each key is scored by job's expiration time (datetime started + timeout). + Base implementation of job registry, implemented in Redis sorted set. Each job + is stored as a key in the registry, scored by expiration time (unix timestamp). - Jobs are added to registry right before they are executed and removed - right after completion (success or failure). - - Jobs whose score are lower than current time is considered "expired". + Jobs with scores are lower than current time is considered "expired" and + should be cleaned up. """ def __init__(self, name='default', connection=None): self.name = name - self.key = 'rq:wip:%s' % name self.connection = resolve_connection(connection) def __len__(self): @@ -28,7 +24,7 @@ class StartedJobRegistry: @property def count(self): """Returns the number of jobs in this registry""" - self.move_expired_jobs_to_failed_queue() + self.cleanup() return self.connection.zcard(self.key) def add(self, job, timeout, pipeline=None): @@ -50,11 +46,28 @@ class StartedJobRegistry: def get_job_ids(self, start=0, end=-1): """Returns list of all job ids.""" - self.move_expired_jobs_to_failed_queue() + self.cleanup() return [as_text(job_id) for job_id in self.connection.zrange(self.key, start, end)] - def move_expired_jobs_to_failed_queue(self): + +class StartedJobRegistry(BaseRegistry): + """ + Registry of currently executing jobs. Each queue maintains a + StartedJobRegistry. Jobs in this registry are ones that are currently + being executed. + + Jobs are added to registry right before they are executed and removed + right after completion (success or failure). + + Jobs whose score are lower than current time is considered "expired". + """ + + def __init__(self, name='default', connection=None): + super(StartedJobRegistry, self).__init__(name, connection) + self.key = 'rq:wip:%s' % name + + def cleanup(self): """Remove expired jobs from registry and add them to FailedQueue.""" job_ids = self.get_expired_job_ids() @@ -63,6 +76,7 @@ class StartedJobRegistry: with self.connection.pipeline() as pipeline: for job_id in job_ids: failed_queue.push_job_id(job_id, pipeline=pipeline) + pipeline.zremrangebyscore(self.key, 0, current_timestamp()) pipeline.execute() return job_ids diff --git a/tests/test_job_started_registry.py b/tests/test_job_started_registry.py index f99aafd..eba1c36 100644 --- a/tests/test_job_started_registry.py +++ b/tests/test_job_started_registry.py @@ -33,8 +33,9 @@ class TestQueue(RQTestCase): def test_get_job_ids(self): """Getting job ids from StartedJobRegistry.""" - self.testconn.zadd(self.registry.key, 1, 'foo') - self.testconn.zadd(self.registry.key, 10, 'bar') + timestamp = current_timestamp() + self.testconn.zadd(self.registry.key, timestamp + 10, 'foo') + self.testconn.zadd(self.registry.key, timestamp + 20, 'bar') self.assertEqual(self.registry.get_job_ids(), ['foo', 'bar']) def test_get_expired_job_ids(self): @@ -51,8 +52,9 @@ class TestQueue(RQTestCase): failed_queue = FailedQueue(connection=self.testconn) self.assertTrue(failed_queue.is_empty()) self.testconn.zadd(self.registry.key, 1, 'foo') - self.registry.move_expired_jobs_to_failed_queue() + self.registry.cleanup() self.assertIn('foo', failed_queue.job_ids) + self.assertEqual(self.testconn.zscore(self.registry.key, 'foo'), None) def test_job_execution(self): """Job is removed from StartedJobRegistry after execution.""" @@ -79,7 +81,8 @@ class TestQueue(RQTestCase): def test_get_job_count(self): """StartedJobRegistry returns the right number of job count.""" - self.testconn.zadd(self.registry.key, 1, 'foo') - self.testconn.zadd(self.registry.key, 10, 'bar') + timestamp = current_timestamp() + 10 + self.testconn.zadd(self.registry.key, timestamp, 'foo') + self.testconn.zadd(self.registry.key, timestamp, 'bar') self.assertEqual(self.registry.count, 2) self.assertEqual(len(self.registry), 2)