diff --git a/rq/worker.py b/rq/worker.py index c9faa7e..c015dfa 100644 --- a/rq/worker.py +++ b/rq/worker.py @@ -17,7 +17,7 @@ from .connections import get_current_connection from .job import Job, Status from .utils import make_colorizer, utcnow, utcformat from .logutils import setup_loghandlers -from .exceptions import NoQueueError, UnpickleError, DequeueTimeout +from .exceptions import NoQueueError, DequeueTimeout from .timeouts import death_penalty_after from .version import VERSION from rq.compat import text_type, as_text @@ -91,6 +91,7 @@ class Worker(object): worker = cls([], name, connection=connection) queues = as_text(connection.hget(worker.key, 'queues')) worker._state = connection.hget(worker.key, 'state') or '?' + worker._job_id = connection.hget(worker.key, 'current_job') or None if queues: worker.queues = [Queue(queue, connection=connection) for queue in queues.split(',')] @@ -213,14 +214,50 @@ class Worker(object): p.expire(self.key, 60) p.execute() - def set_state(self, new_state): - self._state = new_state - self.connection.hset(self.key, 'state', new_state) + def set_state(self, state, pipeline=None): + self._state = state + connection = pipeline if pipeline is not None else self.connection + connection.hset(self.key, 'state', state) + + def _set_state(self, state): + """Raise a DeprecationWarning if ``worker.state = X`` is used""" + raise DeprecationWarning( + "worker.state is deprecated, use worker.set_state() instead." + ) + self.set_state(state) def get_state(self): return self._state - state = property(get_state, set_state) + def _get_state(self): + """Raise a DeprecationWarning if ``worker.state == X`` is used""" + raise DeprecationWarning( + "worker.state is deprecated, use worker.get_state() instead." + ) + return self.get_state() + + state = property(_get_state, _set_state) + + def set_current_job_id(self, job_id, pipeline=None): + connection = pipeline if pipeline is not None else self.connection + + if job_id is None: + connection.hdel(self.key, 'current_job') + else: + connection.hset(self.key, 'current_job', job_id) + + def get_current_job_id(self, pipeline=None): + connection = pipeline if pipeline is not None else self.connection + return as_text(connection.hget(self.key, 'current_job')) + + def get_current_job(self): + """Returns the job id of the currently executing job.""" + job_id = self.get_current_job_id() + + if job_id is None: + return None + + return Job.fetch(job_id, self.connection) @property def stopped(self): @@ -263,7 +300,7 @@ class Worker(object): # If shutdown is requested in the middle of a job, wait until # finish before shutting down - if self.state == 'busy': + if self.get_state() == 'busy': self._stopped = True self.log.debug('Stopping after current horse is finished. ' 'Press Ctrl+C again for a cold shutdown.') @@ -289,13 +326,13 @@ class Worker(object): did_perform_work = False self.register_birth() self.log.info('RQ worker started, version %s' % VERSION) - self.state = 'starting' + self.set_state('starting') try: while True: if self.stopped: self.log.info('Stopping on request.') break - self.state = 'idle' + self.set_state('idle') qnames = self.queue_names() self.procline('Listening on %s' % ','.join(qnames)) self.log.info('') @@ -309,9 +346,11 @@ class Worker(object): except StopRequested: break - self.state = 'busy' + self.set_state('busy') job, queue = result + self.set_current_job_id(job.id) + # Use the public setter here, to immediately update Redis job.status = Status.STARTED self.log.info('%s: %s (%s)' % (green(queue.name), @@ -320,6 +359,8 @@ class Worker(object): self.heartbeat((job.timeout or 180) + 60) self.execute_job(job) self.heartbeat() + self.set_current_job_id(None) + if job.status == Status.FINISHED: queue.enqueue_dependents(job) diff --git a/tests/test_worker.py b/tests/test_worker.py index dca6074..a7fa320 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -1,10 +1,10 @@ import os -from time import sleep from tests import RQTestCase, slow -from tests.fixtures import say_hello, div_by_zero, do_nothing, create_file, \ +from tests.fixtures import say_hello, div_by_zero, create_file, \ create_file_after_timeout from tests.helpers import strip_microseconds from rq import Queue, Worker, get_failed_queue +from rq.compat import as_text from rq.job import Job, Status @@ -186,7 +186,7 @@ class TestWorker(RQTestCase): # TODO: Having to do the manual refresh() here is really ugly! res.refresh() - self.assertIn('JobTimeoutException', res.exc_info) + self.assertIn('JobTimeoutException', as_text(res.exc_info)) def test_worker_sets_result_ttl(self): """Ensure that Worker properly sets result_ttl for individual jobs.""" @@ -250,3 +250,17 @@ class TestWorker(RQTestCase): w.work(burst=True) job = Job.fetch(job.id) self.assertNotEqual(job.status, Status.FINISHED) + + def test_get_current_job(self): + """Ensure worker.get_current_job() works properly""" + q = Queue() + worker = Worker([q]) + job = q.enqueue_call(say_hello) + + self.assertEqual(self.testconn.hget(worker.key, 'current_job'), None) + worker.set_current_job_id(job.id) + self.assertEqual( + worker.get_current_job_id(), + as_text(self.testconn.hget(worker.key, 'current_job')) + ) + self.assertEqual(worker.get_current_job(), job)