diff --git a/rq/job.py b/rq/job.py index ef7a266..bd91f13 100644 --- a/rq/job.py +++ b/rq/job.py @@ -147,9 +147,10 @@ class Job(object): ) return self.get_status() - def set_status(self, status): + def set_status(self, status, pipeline=None): self._status = status - self.connection.hset(self.key, 'status', self._status) + connection = pipeline if pipeline is not None else self.connection + connection.hset(self.key, 'status', self._status) def _set_status(self, status): warnings.warn( diff --git a/rq/queue.py b/rq/queue.py index 86f0406..621942a 100644 --- a/rq/queue.py +++ b/rq/queue.py @@ -158,9 +158,10 @@ class Queue(object): if self.job_class.exists(job_id, self.connection): self.connection.rpush(self.key, job_id) - def push_job_id(self, job_id): + def push_job_id(self, job_id, pipeline=None): """Pushes a job ID on the corresponding Redis queue.""" - self.connection.rpush(self.key, job_id) + connection = pipeline if pipeline is not None else self.connection + connection.rpush(self.key, job_id) def enqueue_call(self, func, args=None, kwargs=None, timeout=None, result_ttl=None, description=None, depends_on=None): diff --git a/rq/registry.py b/rq/registry.py new file mode 100644 index 0000000..afa7b5b --- /dev/null +++ b/rq/registry.py @@ -0,0 +1,58 @@ +from .compat import as_text +from .connections import resolve_connection +from .queue import FailedQueue +from .utils import current_timestamp + + +class StartedJobRegistry: + """ + Registry of currently executing jobs. Each queue maintains a StartedJobRegistry. + StartedJobRegistry contains job keys that are currently being executed. + Each key is scored by job's expiration time (datetime started + timeout). + + Jobs are added to registry right before they are executed and removed + right after completion (success or failure). + + Jobs whose score are lower than current time is considered "expired". + """ + + def __init__(self, name='default', connection=None): + self.name = name + self.key = 'rq:wip:%s' % name + self.connection = resolve_connection(connection) + + def add(self, job, timeout, pipeline=None): + """Adds a job to StartedJobRegistry with expiry time of now + timeout.""" + score = current_timestamp() + timeout + if pipeline is not None: + return pipeline.zadd(self.key, score, job.id) + + return self.connection._zadd(self.key, score, job.id) + + def remove(self, job, pipeline=None): + connection = pipeline if pipeline is not None else self.connection + return connection.zrem(self.key, job.id) + + def get_expired_job_ids(self): + """Returns job ids whose score are less than current timestamp.""" + return [as_text(job_id) for job_id in + self.connection.zrangebyscore(self.key, 0, current_timestamp())] + + def get_job_ids(self, start=0, end=-1): + """Returns list of all job ids.""" + self.move_expired_jobs_to_failed_queue() + return [as_text(job_id) for job_id in + self.connection.zrange(self.key, start, end)] + + def move_expired_jobs_to_failed_queue(self): + """Remove expired jobs from registry and add them to FailedQueue.""" + job_ids = self.get_expired_job_ids() + + if job_ids: + failed_queue = FailedQueue(connection=self.connection) + with self.connection.pipeline() as pipeline: + for job_id in job_ids: + failed_queue.push_job_id(job_id, pipeline=pipeline) + pipeline.execute() + + return job_ids diff --git a/rq/utils.py b/rq/utils.py index d875b26..8233ac6 100644 --- a/rq/utils.py +++ b/rq/utils.py @@ -8,6 +8,7 @@ terminal colorizing code, originally by Georg Brandl. from __future__ import (absolute_import, division, print_function, unicode_literals) +import calendar import importlib import datetime import logging @@ -202,3 +203,8 @@ def first(iterable, default=None, key=None): return el return default + + +def current_timestamp(): + """Returns current UTC timestamp""" + return calendar.timegm(datetime.datetime.utcnow().utctimetuple()) diff --git a/rq/worker.py b/rq/worker.py index eea5e65..3956507 100644 --- a/rq/worker.py +++ b/rq/worker.py @@ -23,6 +23,7 @@ from .queue import get_failed_queue, Queue from .timeouts import UnixSignalDeathPenalty from .utils import import_attribute, make_colorizer, utcformat, utcnow from .version import VERSION +from .registry import StartedJobRegistry try: from procname import setprocname @@ -403,7 +404,7 @@ class Worker(object): self.heartbeat() return result - def heartbeat(self, timeout=0): + def heartbeat(self, timeout=0, pipeline=None): """Specifies a new worker timeout, typically by extending the expiration time of the worker, effectively making this a "heartbeat" to not expire the worker until the timeout passes. @@ -415,7 +416,8 @@ class Worker(object): only larger. """ timeout = max(timeout, self.default_worker_ttl) - self.connection.expire(self.key, timeout) + connection = pipeline if pipeline is not None else self.connection + connection.expire(self.key, timeout) self.log.debug('Sent heartbeat to prevent worker timeout. ' 'Next one should arrive within {0} seconds.'.format(timeout)) @@ -468,27 +470,40 @@ class Worker(object): # constrast to the regular sys.exit() os._exit(int(not success)) - def perform_job(self, job): - """Performs the actual work of a job. Will/should only be called - inside the work horse's process. + def prepare_job_execution(self, job): + """Performs misc bookkeeping like updating states prior to + job execution. """ + timeout = (job.timeout or 180) + 60 - self.set_state('busy') - self.set_current_job_id(job.id) - self.heartbeat((job.timeout or 180) + 60) + with self.connection._pipeline() as pipeline: + self.set_state('busy', pipeline=pipeline) + self.set_current_job_id(job.id, pipeline=pipeline) + self.heartbeat(timeout, pipeline=pipeline) + registry = StartedJobRegistry(job.origin, self.connection) + registry.add(job, timeout, pipeline=pipeline) + job.set_status(Status.STARTED, pipeline=pipeline) + pipeline.execute() self.procline('Processing %s from %s since %s' % ( job.func_name, job.origin, time.time())) + def perform_job(self, job): + """Performs the actual work of a job. Will/should only be called + inside the work horse's process. + """ + self.prepare_job_execution(job) + with self.connection._pipeline() as pipeline: + registry = StartedJobRegistry(job.origin, self.connection) + try: - job.set_status(Status.STARTED) with self.death_penalty_class(job.timeout or self.queue_class.DEFAULT_TIMEOUT): rv = job.perform() - # Pickle the result in the same try-except block since we need to - # use the same exc handling when pickling fails + # Pickle the result in the same try-except block since we need + # to use the same exc handling when pickling fails job._result = rv self.set_current_job_id(None, pipeline=pipeline) @@ -499,12 +514,15 @@ class Worker(object): job._status = Status.FINISHED job.save(pipeline=pipeline) job.cleanup(result_ttl, pipeline=pipeline) + registry.remove(job, pipeline=pipeline) pipeline.execute() except Exception: - # Use the public setter here, to immediately update Redis - job.set_status(Status.FAILED) + job.set_status(Status.FAILED, pipeline=pipeline) + registry.remove(job, pipeline=pipeline) + pipeline.execute() + self.handle_exception(job, *sys.exc_info()) return False diff --git a/tests/test_job_started_registry.py b/tests/test_job_started_registry.py new file mode 100644 index 0000000..addb1db --- /dev/null +++ b/tests/test_job_started_registry.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import + +from rq.job import Job +from rq.queue import FailedQueue, Queue +from rq.utils import current_timestamp +from rq.worker import Worker +from rq.registry import StartedJobRegistry + +from tests import RQTestCase +from tests.fixtures import div_by_zero, say_hello + + +class TestQueue(RQTestCase): + + def setUp(self): + super(TestQueue, self).setUp() + self.registry = StartedJobRegistry(connection=self.testconn) + + def test_add_and_remove(self): + """Adding and removing job to StartedJobRegistry.""" + timestamp = current_timestamp() + job = Job() + + # Test that job is added with the right score + self.registry.add(job, 1000) + self.assertLess(self.testconn.zscore(self.registry.key, job.id), + timestamp + 1002) + + # Ensure that job is properly removed from sorted set + self.registry.remove(job) + self.assertIsNone(self.testconn.zscore(self.registry.key, job.id)) + + def test_get_job_ids(self): + """Getting job ids from StartedJobRegistry.""" + self.testconn.zadd(self.registry.key, 1, 'foo') + self.testconn.zadd(self.registry.key, 10, 'bar') + self.assertEqual(self.registry.get_job_ids(), ['foo', 'bar']) + + def test_get_expired_job_ids(self): + """Getting expired job ids form StartedJobRegistry.""" + timestamp = current_timestamp() + + self.testconn.zadd(self.registry.key, 1, 'foo') + self.testconn.zadd(self.registry.key, timestamp + 10, 'bar') + + self.assertEqual(self.registry.get_expired_job_ids(), ['foo']) + + def test_cleanup(self): + """Moving expired jobs to FailedQueue.""" + failed_queue = FailedQueue(connection=self.testconn) + self.assertTrue(failed_queue.is_empty()) + self.testconn.zadd(self.registry.key, 1, 'foo') + self.registry.move_expired_jobs_to_failed_queue() + self.assertIn('foo', failed_queue.job_ids) + + def test_job_execution(self): + """Job is removed from StartedJobRegistry after execution.""" + registry = StartedJobRegistry(connection=self.testconn) + queue = Queue(connection=self.testconn) + worker = Worker([queue]) + + job = queue.enqueue(say_hello) + + worker.prepare_job_execution(job) + self.assertIn(job.id, registry.get_job_ids()) + + worker.perform_job(job) + self.assertNotIn(job.id, registry.get_job_ids()) + + # Job that fails + job = queue.enqueue(div_by_zero) + + worker.prepare_job_execution(job) + self.assertIn(job.id, registry.get_job_ids()) + + worker.perform_job(job) + self.assertNotIn(job.id, registry.get_job_ids()) diff --git a/tests/test_worker.py b/tests/test_worker.py index 7a9aadf..764cf46 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -7,10 +7,11 @@ import os from rq import get_failed_queue, Queue, Worker, SimpleWorker from rq.compat import as_text from rq.job import Job, Status +from rq.registry import StartedJobRegistry from tests import RQTestCase, slow -from tests.fixtures import (create_file, create_file_after_timeout, div_by_zero, - say_hello, say_pid) +from tests.fixtures import (create_file, create_file_after_timeout, + div_by_zero, say_hello, say_pid) from tests.helpers import strip_microseconds @@ -291,3 +292,18 @@ class TestWorker(RQTestCase): 'Expected at least some work done.') self.assertEquals(job.result, os.getpid(), 'PID mismatch, fork() is not supposed to happen here') + + def test_prepare_job_execution(self): + """Prepare job execution does the necessary bookkeeping.""" + queue = Queue(connection=self.testconn) + job = queue.enqueue(say_hello) + worker = Worker([queue]) + worker.prepare_job_execution(job) + + # Updates working queue + registry = StartedJobRegistry(connection=self.testconn) + self.assertEqual(registry.get_job_ids(), [job.id]) + + # Updates worker statuses + self.assertEqual(worker.state, 'busy') + self.assertEqual(worker.get_current_job_id(), job.id)