Merge pull request #742 from th3hamm0r/739_fix_race_condition

#739 fix race condition
main
Selwin Ong 8 years ago committed by GitHub
commit cfd3b507d2

@ -543,7 +543,7 @@ class Job(object):
forever) forever)
""" """
if ttl == 0: if ttl == 0:
self.delete(remove_from_queue=remove_from_queue) self.delete(pipeline=pipeline, remove_from_queue=remove_from_queue)
elif not ttl: elif not ttl:
return return
elif ttl > 0: elif ttl > 0:

@ -299,24 +299,52 @@ class Queue(object):
return job return job
def enqueue_dependents(self, job, pipeline=None): def enqueue_dependents(self, job, pipeline=None):
"""Enqueues all jobs in the given job's dependents set and clears it.""" """Enqueues all jobs in the given job's dependents set and clears it.
# TODO: can probably be pipelined
When called without a pipeline, this method uses WATCH/MULTI/EXEC.
If you pass a pipeline, only MULTI is called. The rest is up to the
caller.
"""
from .registry import DeferredJobRegistry from .registry import DeferredJobRegistry
pipe = pipeline if pipeline is not None else self.connection._pipeline()
dependents_key = job.dependents_key
while True: while True:
job_id = as_text(self.connection.spop(job.dependents_key)) try:
if job_id is None: # if a pipeline is passed, the caller is responsible for calling WATCH
break # to ensure all jobs are enqueued
dependent = self.job_class.fetch(job_id, connection=self.connection) if pipeline is None:
pipe.watch(dependents_key)
dependent_jobs = [self.job_class.fetch(as_text(job_id), connection=self.connection)
for job_id in pipe.smembers(dependents_key)]
pipe.multi()
for dependent in dependent_jobs:
registry = DeferredJobRegistry(dependent.origin, self.connection) registry = DeferredJobRegistry(dependent.origin, self.connection)
with self.connection._pipeline() as pipeline: registry.remove(dependent, pipeline=pipe)
registry.remove(dependent, pipeline=pipeline)
if dependent.origin == self.name: if dependent.origin == self.name:
self.enqueue_job(dependent, pipeline=pipeline) self.enqueue_job(dependent, pipeline=pipe)
else: else:
queue = Queue(name=dependent.origin, connection=self.connection) queue = Queue(name=dependent.origin, connection=self.connection)
queue.enqueue_job(dependent, pipeline=pipeline) queue.enqueue_job(dependent, pipeline=pipe)
pipeline.execute()
pipe.delete(dependents_key)
if pipeline is None:
pipe.execute()
break
except WatchError:
if pipeline is None:
continue
else:
# if the pipeline comes from the caller, we re-raise the
# exception as it is the responsibility of the caller to
# handle it
raise
def pop_job_id(self): def pop_job_id(self):
"""Pops a given job ID from this Redis queue.""" """Pops a given job ID from this Redis queue."""

@ -14,6 +14,8 @@ import traceback
import warnings import warnings
from datetime import timedelta from datetime import timedelta
from redis import WatchError
from rq.compat import as_text, string_types, text_type from rq.compat import as_text, string_types, text_type
from .compat import PY2 from .compat import PY2
@ -535,15 +537,9 @@ class Worker(object):
# Job completed and its ttl has expired # Job completed and its ttl has expired
break break
if job_status not in [JobStatus.FINISHED, JobStatus.FAILED]: if job_status not in [JobStatus.FINISHED, JobStatus.FAILED]:
with self.connection._pipeline() as pipeline:
self.handle_job_failure( self.handle_job_failure(
job=job, job=job
pipeline=pipeline
) )
try:
pipeline.execute()
except Exception:
pass
# Unhandled failure: move the job to the failed queue # Unhandled failure: move the job to the failed queue
self.log.warning( self.log.warning(
@ -630,8 +626,7 @@ class Worker(object):
def handle_job_failure( def handle_job_failure(
self, self,
job, job,
started_job_registry=None, started_job_registry=None
pipeline=None
): ):
"""Handles the failure of an executing job by: """Handles the failure of an executing job by:
1. Setting the job status to failed 1. Setting the job status to failed
@ -639,6 +634,7 @@ class Worker(object):
3. Setting the workers current job to None 3. Setting the workers current job to None
""" """
with self.connection._pipeline() as pipeline:
if started_job_registry is None: if started_job_registry is None:
started_job_registry = StartedJobRegistry( started_job_registry = StartedJobRegistry(
job.origin, job.origin,
@ -647,32 +643,32 @@ class Worker(object):
job.set_status(JobStatus.FAILED, pipeline=pipeline) job.set_status(JobStatus.FAILED, pipeline=pipeline)
started_job_registry.remove(job, pipeline=pipeline) started_job_registry.remove(job, pipeline=pipeline)
self.set_current_job_id(None, pipeline=pipeline) self.set_current_job_id(None, pipeline=pipeline)
try:
pipeline.execute()
except Exception:
# Ensure that custom exception handlers are called
# even if Redis is down
pass
def perform_job(self, job, queue): def handle_job_success(
"""Performs the actual work of a job. Will/should only be called self,
inside the work horse's process. job,
""" queue,
self.prepare_job_execution(job) started_job_registry
):
with self.connection._pipeline() as pipeline: with self.connection._pipeline() as pipeline:
while True:
push_connection(self.connection)
started_job_registry = StartedJobRegistry(job.origin, self.connection)
try: try:
with self.death_penalty_class(job.timeout or self.queue_class.DEFAULT_TIMEOUT): # if dependencies are inserted after enqueue_dependents
rv = job.perform() # a WatchError is thrown by execute()
pipeline.watch(job.dependents_key)
# Pickle the result in the same try-except block since we need # enqueue_dependents calls multi() on the pipeline!
# to use the same exc handling when pickling fails queue.enqueue_dependents(job, pipeline=pipeline)
job._result = rv
self.set_current_job_id(None, pipeline=pipeline) self.set_current_job_id(None, pipeline=pipeline)
result_ttl = job.get_result_ttl(self.default_result_ttl) result_ttl = job.get_result_ttl(self.default_result_ttl)
if result_ttl != 0: if result_ttl != 0:
job.ended_at = utcnow()
job.set_status(JobStatus.FINISHED, pipeline=pipeline) job.set_status(JobStatus.FINISHED, pipeline=pipeline)
job.save(pipeline=pipeline) job.save(pipeline=pipeline)
@ -680,25 +676,45 @@ class Worker(object):
self.connection) self.connection)
finished_job_registry.add(job, result_ttl, pipeline) finished_job_registry.add(job, result_ttl, pipeline)
queue.enqueue_dependents(job, pipeline=pipeline)
job.cleanup(result_ttl, pipeline=pipeline, job.cleanup(result_ttl, pipeline=pipeline,
remove_from_queue=False) remove_from_queue=False)
started_job_registry.remove(job, pipeline=pipeline) started_job_registry.remove(job, pipeline=pipeline)
pipeline.execute() pipeline.execute()
break
except WatchError:
continue
def perform_job(self, job, queue):
"""Performs the actual work of a job. Will/should only be called
inside the work horse's process.
"""
self.prepare_job_execution(job)
push_connection(self.connection)
started_job_registry = StartedJobRegistry(job.origin, self.connection)
try:
with self.death_penalty_class(job.timeout or self.queue_class.DEFAULT_TIMEOUT):
rv = job.perform()
job.ended_at = utcnow()
# Pickle the result in the same try-except block since we need
# to use the same exc handling when pickling fails
job._result = rv
self.handle_job_success(
job=job,
queue=queue,
started_job_registry=started_job_registry
)
except Exception: except Exception:
self.handle_job_failure( self.handle_job_failure(
job=job, job=job,
started_job_registry=started_job_registry, started_job_registry=started_job_registry
pipeline=pipeline
) )
try:
pipeline.execute()
except Exception:
# Ensure that custom exception handlers are called
# even if Redis is down
pass
self.handle_exception(job, *sys.exc_info()) self.handle_exception(job, *sys.exc_info())
return False return False
@ -710,6 +726,7 @@ class Worker(object):
log_result = "{0!r}".format(as_text(text_type(rv))) log_result = "{0!r}".format(as_text(text_type(rv)))
self.log.debug('Result: {0}'.format(yellow(log_result))) self.log.debug('Result: {0}'.format(yellow(log_result)))
result_ttl = job.get_result_ttl(self.default_result_ttl)
if result_ttl == 0: if result_ttl == 0:
self.log.info('Result discarded immediately') self.log.info('Result discarded immediately')
elif result_ttl > 0: elif result_ttl > 0:

@ -11,6 +11,8 @@ import time
from multiprocessing import Process from multiprocessing import Process
import subprocess import subprocess
import mock
from tests import RQTestCase, slow from tests import RQTestCase, slow
from tests.fixtures import (create_file, create_file_after_timeout, from tests.fixtures import (create_file, create_file_after_timeout,
div_by_zero, do_nothing, say_hello, say_pid, div_by_zero, do_nothing, say_hello, say_pid,
@ -567,6 +569,40 @@ class TestWorker(RQTestCase):
worker.work(burst=True) worker.work(burst=True)
self.assertEqual(self.testconn.zcard(registry.key), 0) self.assertEqual(self.testconn.zcard(registry.key), 0)
def test_job_dependency_race_condition(self):
"""Dependencies added while the job gets finished shouldn't get lost."""
# This patches the enqueue_dependents to enqueue a new dependency AFTER
# the original code was executed.
orig_enqueue_dependents = Queue.enqueue_dependents
def new_enqueue_dependents(self, job, *args, **kwargs):
orig_enqueue_dependents(self, job, *args, **kwargs)
if hasattr(Queue, '_add_enqueue') and Queue._add_enqueue is not None and Queue._add_enqueue.id == job.id:
Queue._add_enqueue = None
Queue().enqueue_call(say_hello, depends_on=job)
Queue.enqueue_dependents = new_enqueue_dependents
q = Queue()
w = Worker([q])
with mock.patch.object(Worker, 'execute_job', wraps=w.execute_job) as mocked:
parent_job = q.enqueue(say_hello, result_ttl=0)
Queue._add_enqueue = parent_job
job = q.enqueue_call(say_hello, depends_on=parent_job)
w.work(burst=True)
job = Job.fetch(job.id)
self.assertEqual(job.get_status(), JobStatus.FINISHED)
# The created spy checks two issues:
# * before the fix of #739, 2 of the 3 jobs were executed due
# to the race condition
# * during the development another issue was fixed:
# due to a missing pipeline usage in Queue.enqueue_job, the job
# which was enqueued before the "rollback" was executed twice.
# So before that fix the call count was 4 instead of 3
self.assertEqual(mocked.call_count, 3)
def kill_worker(pid, double_kill): def kill_worker(pid, double_kill):
# wait for the worker to be started over on the main process # wait for the worker to be started over on the main process

Loading…
Cancel
Save