diff --git a/docs/docs/index.md b/docs/docs/index.md index bbf7841..1c24c6b 100644 --- a/docs/docs/index.md +++ b/docs/docs/index.md @@ -202,6 +202,21 @@ If you want to execute a function whenever a job completes or fails, RQ provides queue.enqueue(say_hello, on_success=report_success, on_failure=report_failure) ``` +### Callback Class and Callback Timeouts + +_New in version 1.14.0_ + +RQ lets you configure the method and timeout for each callback - success and failure. +To configure callback timeouts, use RQ's +`Callback` object that accepts `func` and `timeout` arguments. For example: + +```python +from rq import Callback +queue.enqueue(say_hello, + on_success=Callback(report_success), # default callback timeout (60 seconds) + on_failure=Callback(report_failure, timeout=10)) # 10 seconds timeout +``` + ### Success Callback Success callbacks must be a function that accepts `job`, `connection` and `result` arguments. diff --git a/rq/__init__.py b/rq/__init__.py index ec635d7..0ab7065 100644 --- a/rq/__init__.py +++ b/rq/__init__.py @@ -1,7 +1,7 @@ # flake8: noqa from .connections import Connection, get_current_connection, pop_connection, push_connection -from .job import cancel_job, get_current_job, requeue_job, Retry +from .job import cancel_job, get_current_job, requeue_job, Retry, Callback from .queue import Queue from .version import VERSION from .worker import SimpleWorker, Worker diff --git a/rq/cli/cli.py b/rq/cli/cli.py index a2851aa..27058e8 100755 --- a/rq/cli/cli.py +++ b/rq/cli/cli.py @@ -105,7 +105,8 @@ def empty(cli_config, all, queues, serializer, **options): if all: queues = cli_config.queue_class.all( - connection=cli_config.connection, job_class=cli_config.job_class, serializer=serializer + connection=cli_config.connection, job_class=cli_config.job_class, + death_penalty_class=cli_config.death_penalty_class, serializer=serializer ) else: queues = [ diff --git a/rq/cli/helpers.py b/rq/cli/helpers.py index 0f87d22..d7238ed 100644 --- a/rq/cli/helpers.py +++ b/rq/cli/helpers.py @@ -13,7 +13,8 @@ from shutil import get_terminal_size import click from redis import Redis from redis.sentinel import Sentinel -from rq.defaults import DEFAULT_CONNECTION_CLASS, DEFAULT_JOB_CLASS, DEFAULT_QUEUE_CLASS, DEFAULT_WORKER_CLASS +from rq.defaults import DEFAULT_CONNECTION_CLASS, DEFAULT_JOB_CLASS, DEFAULT_QUEUE_CLASS, DEFAULT_WORKER_CLASS, \ + DEFAULT_DEATH_PENALTY_CLASS from rq.logutils import setup_loghandlers from rq.utils import import_attribute, parse_timeout from rq.worker import WorkerStatus @@ -302,6 +303,7 @@ class CliConfig: config=None, worker_class=DEFAULT_WORKER_CLASS, job_class=DEFAULT_JOB_CLASS, + death_penalty_class=DEFAULT_DEATH_PENALTY_CLASS, queue_class=DEFAULT_QUEUE_CLASS, connection_class=DEFAULT_CONNECTION_CLASS, path=None, @@ -325,6 +327,11 @@ class CliConfig: except (ImportError, AttributeError) as exc: raise click.BadParameter(str(exc), param_hint='--job-class') + try: + self.death_penalty_class = import_attribute(death_penalty_class) + except (ImportError, AttributeError) as exc: + raise click.BadParameter(str(exc), param_hint='--death-penalty-class') + try: self.queue_class = import_attribute(queue_class) except (ImportError, AttributeError) as exc: diff --git a/rq/defaults.py b/rq/defaults.py index bd50489..2a3d57a 100644 --- a/rq/defaults.py +++ b/rq/defaults.py @@ -88,3 +88,9 @@ DEFAULT_LOGGING_FORMAT = '%(asctime)s %(message)s' Uses Python's default attributes as defined https://docs.python.org/3/library/logging.html#logrecord-attributes """ + + +DEFAULT_DEATH_PENALTY_CLASS = 'rq.timeouts.UnixSignalDeathPenalty' +""" The path for the default Death Penalty class to use. +Defaults to the `UnixSignalDeathPenalty` class within the `rq.timeouts` module +""" \ No newline at end of file diff --git a/rq/job.py b/rq/job.py index 006283c..07ec6cb 100644 --- a/rq/job.py +++ b/rq/job.py @@ -1,16 +1,18 @@ import inspect import json +import logging import warnings import zlib import asyncio -from collections.abc import Iterable from datetime import datetime, timedelta, timezone from enum import Enum from redis import WatchError -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, Type from uuid import uuid4 +from .defaults import CALLBACK_TIMEOUT +from .timeouts import JobTimeoutException, BaseDeathPenalty if TYPE_CHECKING: from .results import Result @@ -36,6 +38,8 @@ from .utils import ( utcnow, ) +logger = logging.getLogger("rq.job") + class JobStatus(str, Enum): """The Status of Job within its lifecycle at any given time.""" @@ -153,8 +157,8 @@ class Job: failure_ttl: Optional[int] = None, serializer=None, *, - on_success: Optional[Callable[..., Any]] = None, - on_failure: Optional[Callable[..., Any]] = None + on_success: Optional[Union['Callback', Callable[..., Any]]] = None, + on_failure: Optional[Union['Callback', Callable[..., Any]]] = None ) -> 'Job': """Creates a new Job instance for the given function, arguments, and keyword arguments. @@ -234,14 +238,20 @@ class Job: job._kwargs = kwargs if on_success: - if not inspect.isfunction(on_success) and not inspect.isbuiltin(on_success): - raise ValueError('on_success callback must be a function') - job._success_callback_name = '{0}.{1}'.format(on_success.__module__, on_success.__qualname__) + if not isinstance(on_success, Callback): + warnings.warn('Passing a `Callable` `on_success` is deprecated, pass `Callback` instead', + DeprecationWarning) + on_success = Callback(on_success) # backward compatibility + job._success_callback_name = on_success.name + job._success_callback_timeout = on_success.timeout if on_failure: - if not inspect.isfunction(on_failure) and not inspect.isbuiltin(on_failure): - raise ValueError('on_failure callback must be a function') - job._failure_callback_name = '{0}.{1}'.format(on_failure.__module__, on_failure.__qualname__) + if not isinstance(on_failure, Callback): + warnings.warn('Passing a `Callable` `on_failure` is deprecated, pass `Callback` instead', + DeprecationWarning) + on_failure = Callback(on_failure) # backward compatibility + job._failure_callback_name = on_failure.name + job._failure_callback_timeout = on_failure.timeout # Extra meta data job.description = description or job.get_call_string() @@ -401,6 +411,13 @@ class Job: return self._success_callback + @property + def success_callback_timeout(self) -> int: + if self._success_callback_timeout is None: + return CALLBACK_TIMEOUT + + return self._success_callback_timeout + @property def failure_callback(self): if self._failure_callback is UNEVALUATED: @@ -411,6 +428,13 @@ class Job: return self._failure_callback + @property + def failure_callback_timeout(self) -> int: + if self._failure_callback_timeout is None: + return CALLBACK_TIMEOUT + + return self._failure_callback_timeout + def _deserialize_data(self): """Deserializes the Job `data` into a tuple. This includes the `_func_name`, `_instance`, `_args` and `_kwargs` @@ -580,6 +604,8 @@ class Job: self._result = None self._exc_info = None self.timeout: Optional[float] = None + self._success_callback_timeout: Optional[int] = None + self._failure_callback_timeout: Optional[int] = None self.result_ttl: Optional[int] = None self.failure_ttl: Optional[int] = None self.ttl: Optional[int] = None @@ -867,9 +893,15 @@ class Job: if obj.get('success_callback_name'): self._success_callback_name = obj.get('success_callback_name').decode() + if 'success_callback_timeout' in obj: + self._success_callback_timeout = int(obj.get('success_callback_timeout')) + if obj.get('failure_callback_name'): self._failure_callback_name = obj.get('failure_callback_name').decode() + if 'failure_callback_timeout' in obj: + self._failure_callback_timeout = int(obj.get('failure_callback_timeout')) + dep_ids = obj.get('dependency_ids') dep_id = obj.get('dependency_id') # for backwards compatibility self._dependency_ids = json.loads(dep_ids.decode()) if dep_ids else [dep_id.decode()] if dep_id else [] @@ -947,6 +979,10 @@ class Job: obj['exc_info'] = zlib.compress(str(self._exc_info).encode('utf-8')) if self.timeout is not None: obj['timeout'] = self.timeout + if self._success_callback_timeout is not None: + obj['success_callback_timeout'] = self._success_callback_timeout + if self._failure_callback_timeout is not None: + obj['failure_callback_timeout'] = self._failure_callback_timeout if self.result_ttl is not None: obj['result_ttl'] = self.result_ttl if self.failure_ttl is not None: @@ -1308,6 +1344,35 @@ class Job: self.origin, connection=self.connection, job_class=self.__class__, serializer=self.serializer ) + def execute_success_callback(self, death_penalty_class: Type[BaseDeathPenalty], result: Any): + """Executes success_callback for a job. + with timeout . + + Args: + death_penalty_class (Type[BaseDeathPenalty]): The penalty class to use for timeout + result (Any): The job's result. + """ + if not self.success_callback: + return + + logger.debug('Running success callbacks for %s', self.id) + with death_penalty_class(self.success_callback_timeout, JobTimeoutException, job_id=self.id): + self.success_callback(self, self.connection, result) + + def execute_failure_callback(self, death_penalty_class: Type[BaseDeathPenalty], *exc_info): + """Executes failure_callback with possible timeout + """ + if not self.failure_callback: + return + + logger.debug('Running failure callbacks for %s', self.id) + try: + with death_penalty_class(self.failure_callback_timeout, JobTimeoutException, job_id=self.id): + self.failure_callback(self, self.connection, *exc_info) + except Exception: # noqa + logger.exception(f'Job {self.id}: error while executing failure callback') + raise + def _handle_success(self, result_ttl: int, pipeline: 'Pipeline'): """Saves and cleanup job after successful execution""" # self.log.debug('Setting job %s status to finished', job.id) @@ -1507,3 +1572,16 @@ class Retry: self.max = max self.intervals = intervals + + +class Callback: + def __init__(self, func: Callable[..., Any], timeout: Optional[Any] = None): + if not inspect.isfunction(func) and not inspect.isbuiltin(func): + raise ValueError('Callback func must be a function') + + self.func = func + self.timeout = parse_timeout(timeout) if timeout else CALLBACK_TIMEOUT + + @property + def name(self) -> str: + return '{0}.{1}'.format(self.func.__module__, self.func.__qualname__) diff --git a/rq/queue.py b/rq/queue.py index 985e02e..77a6f3e 100644 --- a/rq/queue.py +++ b/rq/queue.py @@ -10,6 +10,8 @@ from typing import TYPE_CHECKING, Dict, List, Any, Callable, Optional, Tuple, Ty from redis import WatchError +from .timeouts import BaseDeathPenalty, UnixSignalDeathPenalty + if TYPE_CHECKING: from redis import Redis from redis.client import Pipeline @@ -62,13 +64,15 @@ class EnqueueData( @total_ordering class Queue: job_class: Type['Job'] = Job + death_penalty_class: Type[BaseDeathPenalty] = UnixSignalDeathPenalty DEFAULT_TIMEOUT: int = 180 # Default timeout seconds. redis_queue_namespace_prefix: str = 'rq:queue:' redis_queues_keys: str = 'rq:queues' @classmethod def all( - cls, connection: Optional['Redis'] = None, job_class: Optional[Type['Job']] = None, serializer=None + cls, connection: Optional['Redis'] = None, job_class: Optional[Type['Job']] = None, + serializer=None, death_penalty_class: Optional[Type[BaseDeathPenalty]] = None ) -> List['Queue']: """Returns an iterable of all Queues. @@ -76,6 +80,7 @@ class Queue: connection (Optional[Redis], optional): The Redis Connection. Defaults to None. job_class (Optional[Job], optional): The Job class to use. Defaults to None. serializer (optional): The serializer to use. Defaults to None. + death_penalty_class (Optional[Job], optional): The Death Penalty class to use. Defaults to None. Returns: queues (List[Queue]): A list of all queues. @@ -84,7 +89,8 @@ class Queue: def to_queue(queue_key): return cls.from_queue_key( - as_text(queue_key), connection=connection, job_class=job_class, serializer=serializer + as_text(queue_key), connection=connection, job_class=job_class, + serializer=serializer, death_penalty_class=death_penalty_class ) all_registerd_queues = connection.smembers(cls.redis_queues_keys) @@ -96,8 +102,9 @@ class Queue: cls, queue_key: str, connection: Optional['Redis'] = None, - job_class: Optional['Job'] = None, + job_class: Optional[Type['Job']] = None, serializer: Any = None, + death_penalty_class: Optional[Type[BaseDeathPenalty]] = None, ) -> 'Queue': """Returns a Queue instance, based on the naming conventions for naming the internal Redis keys. Can be used to reverse-lookup Queues by their @@ -108,6 +115,7 @@ class Queue: connection (Optional[Redis], optional): Redis connection. Defaults to None. job_class (Optional[Job], optional): Job class. Defaults to None. serializer (Any, optional): Serializer. Defaults to None. + death_penalty_class (Optional[BaseDeathPenalty], optional): Death penalty class. Defaults to None. Raises: ValueError: If the queue_key doesn't start with the defined prefix @@ -119,7 +127,8 @@ class Queue: if not queue_key.startswith(prefix): raise ValueError('Not a valid RQ queue key: {0}'.format(queue_key)) name = queue_key[len(prefix):] - return cls(name, connection=connection, job_class=job_class, serializer=serializer) + return cls(name, connection=connection, job_class=job_class, serializer=serializer, + death_penalty_class=death_penalty_class) def __init__( self, @@ -129,6 +138,7 @@ class Queue: is_async: bool = True, job_class: Union[str, Type['Job'], None] = None, serializer: Any = None, + death_penalty_class: Type[BaseDeathPenalty] = UnixSignalDeathPenalty, **kwargs, ): """Initializes a Queue object. @@ -141,6 +151,7 @@ class Queue: If `is_async` is false, jobs will run on the same process from where it was called. Defaults to True. job_class (Union[str, 'Job', optional): Job class or a string referencing the Job class path. Defaults to None. serializer (Any, optional): Serializer. Defaults to None. + death_penalty_class (Type[BaseDeathPenalty, optional): Job class or a string referencing the Job class path. Defaults to UnixSignalDeathPenalty. """ self.connection = resolve_connection(connection) prefix = self.redis_queue_namespace_prefix @@ -159,6 +170,7 @@ class Queue: if isinstance(job_class, str): job_class = import_attribute(job_class) self.job_class = job_class + self.death_penalty_class = death_penalty_class self.serializer = resolve_serializer(serializer) self.redis_server_version: Optional[Tuple[int, int, int]] = None @@ -1202,6 +1214,7 @@ class Queue: connection: Optional['Redis'] = None, job_class: Optional['Job'] = None, serializer: Any = None, + death_penalty_class: Optional[Type[BaseDeathPenalty]] = None, ) -> Tuple['Job', 'Queue']: """Class method returning the job_class instance at the front of the given set of Queues, where the order of the queues is important. @@ -1217,8 +1230,9 @@ class Queue: queues (List[Queue]): List of queue objects timeout (Optional[int]): Timeout for the LPOP connection (Optional[Redis], optional): Redis Connection. Defaults to None. - job_class (Optional[Job], optional): The job classification. Defaults to None. + job_class (Optional[Type[Job]], optional): The job class. Defaults to None. serializer (Any, optional): Serializer to use. Defaults to None. + death_penalty_class (Optional[Type[BaseDeathPenalty]], optional): The death penalty class. Defaults to None. Raises: e: Any exception @@ -1234,7 +1248,8 @@ class Queue: if result is None: return None queue_key, job_id = map(as_text, result) - queue = cls.from_queue_key(queue_key, connection=connection, job_class=job_class, serializer=serializer) + queue = cls.from_queue_key(queue_key, connection=connection, job_class=job_class, + serializer=serializer, death_penalty_class=death_penalty_class) try: job = job_class.fetch(job_id, connection=connection, serializer=serializer) except NoSuchJobError: diff --git a/rq/registry.py b/rq/registry.py index f7534fd..2bd874c 100644 --- a/rq/registry.py +++ b/rq/registry.py @@ -7,7 +7,7 @@ import time from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, List, Optional, Type, Union -from .timeouts import JobTimeoutException, UnixSignalDeathPenalty +from .timeouts import JobTimeoutException, UnixSignalDeathPenalty, BaseDeathPenalty if TYPE_CHECKING: from redis import Redis @@ -33,6 +33,7 @@ class BaseRegistry: """ job_class = Job + death_penalty_class = UnixSignalDeathPenalty key_template = 'rq:registry:{0}' def __init__( @@ -42,6 +43,7 @@ class BaseRegistry: job_class: Optional[Type['Job']] = None, queue: Optional['Queue'] = None, serializer: Any = None, + death_penalty_class: Optional[Type[BaseDeathPenalty]] = None, ): if queue: self.name = queue.name @@ -54,6 +56,7 @@ class BaseRegistry: self.key = self.key_template.format(self.name) self.job_class = backend_class(self, 'job_class', override=job_class) + self.death_penalty_class = backend_class(self, 'death_penalty_class', override=death_penalty_class) def __len__(self): """Returns the number of jobs in this registry""" @@ -210,7 +213,6 @@ class StartedJobRegistry(BaseRegistry): """ key_template = 'rq:wip:{0}' - death_penalty_class = UnixSignalDeathPenalty def cleanup(self, timestamp: Optional[float] = None): """Remove abandoned jobs from registry and add them to FailedJobRegistry. @@ -235,13 +237,8 @@ class StartedJobRegistry(BaseRegistry): except NoSuchJobError: continue - if job.failure_callback: - try: - with self.death_penalty_class(CALLBACK_TIMEOUT, JobTimeoutException, job_id=job.id): - job.failure_callback(job, self.connection, - AbandonedJobError, AbandonedJobError(), traceback.extract_stack()) - except: # noqa - logger.exception('Registry %s: error while executing failure callback', self.key) + job.execute_failure_callback(self.death_penalty_class, AbandonedJobError, AbandonedJobError(), + traceback.extract_stack()) retry = job.retries_left and job.retries_left > 0 diff --git a/rq/utils.py b/rq/utils.py index 9cd1255..b10e262 100644 --- a/rq/utils.py +++ b/rq/utils.py @@ -349,7 +349,7 @@ def str_to_date(date_str: Optional[str]) -> Union[dt.datetime, Any]: return utcparse(date_str.decode()) -def parse_timeout(timeout: Any): +def parse_timeout(timeout: Union[int, float, str]) -> int: """Transfer all kinds of timeout format to an integer representing seconds""" if not isinstance(timeout, numbers.Integral) and timeout is not None: try: diff --git a/rq/worker.py b/rq/worker.py index 3e37425..80c0384 100644 --- a/rq/worker.py +++ b/rq/worker.py @@ -123,9 +123,9 @@ class Worker: log_result_lifespan = True # `log_job_description` is used to toggle logging an entire jobs description. log_job_description = True - # factor to increase connection_wait_time incase of continous connection failures. + # factor to increase connection_wait_time in case of continuous connection failures. exponential_backoff_factor = 2.0 - # Max Wait time (in seconds) after which exponential_backoff_factor wont be applicable. + # Max Wait time (in seconds) after which exponential_backoff_factor won't be applicable. max_connection_wait_time = 60.0 @classmethod @@ -267,7 +267,8 @@ class Worker: self.serializer = resolve_serializer(serializer) queues = [ - self.queue_class(name=q, connection=connection, job_class=self.job_class, serializer=self.serializer) + self.queue_class(name=q, connection=connection, job_class=self.job_class, + serializer=self.serializer, death_penalty_class=self.death_penalty_class,) if isinstance(q, str) else q for q in ensure_list(queues) @@ -912,6 +913,7 @@ class Worker: connection=self.connection, job_class=self.job_class, serializer=self.serializer, + death_penalty_class=self.death_penalty_class, ) if result is not None: job, queue = result @@ -1113,7 +1115,7 @@ class Worker: job.started_at = utcnow() while True: try: - with UnixSignalDeathPenalty(self.job_monitoring_interval, HorseMonitorTimeoutException): + with self.death_penalty_class(self.job_monitoring_interval, HorseMonitorTimeoutException): retpid, ret_val, rusage = self.wait_for_horse() break except HorseMonitorTimeoutException: @@ -1361,33 +1363,6 @@ class Worker: except redis.exceptions.WatchError: continue - def execute_success_callback(self, job: 'Job', result: Any): - """Executes success_callback for a job. - with timeout . - - Args: - job (Job): The Job - result (Any): The job's result. - """ - self.log.debug('Running success callbacks for %s', job.id) - job.heartbeat(utcnow(), CALLBACK_TIMEOUT) - with self.death_penalty_class(CALLBACK_TIMEOUT, JobTimeoutException, job_id=job.id): - job.success_callback(job, self.connection, result) - - def execute_failure_callback(self, job: 'Job', *exc_info): - """Executes failure_callback with timeout - - Args: - job (Job): The Job - """ - if not job.failure_callback: - return - - self.log.debug('Running failure callbacks for %s', job.id) - job.heartbeat(utcnow(), CALLBACK_TIMEOUT) - with self.death_penalty_class(CALLBACK_TIMEOUT, JobTimeoutException, job_id=job.id): - job.failure_callback(job, self.connection, *exc_info) - def perform_job(self, job: 'Job', queue: 'Queue') -> bool: """Performs the actual work of a job. Will/should only be called inside the work horse's process. @@ -1419,8 +1394,8 @@ class Worker: # to use the same exc handling when pickling fails job._result = rv - if job.success_callback: - self.execute_success_callback(job, rv) + job.heartbeat(utcnow(), job.success_callback_timeout) + job.execute_success_callback(self.death_penalty_class, rv) self.handle_job_success(job=job, queue=queue, started_job_registry=started_job_registry) except: # NOQA @@ -1430,11 +1405,11 @@ class Worker: exc_string = ''.join(traceback.format_exception(*exc_info)) try: - self.execute_failure_callback(job, *exc_info) + job.heartbeat(utcnow(), job.failure_callback_timeout) + job.execute_failure_callback(self.death_penalty_class, *exc_info) except: # noqa exc_info = sys.exc_info() exc_string = ''.join(traceback.format_exception(*exc_info)) - self.log.error('Worker %s: error while executing failure callback', self.key, exc_info=exc_info) self.handle_job_failure( job=job, exc_string=exc_string, queue=queue, started_job_registry=started_job_registry diff --git a/tests/test_cli.py b/tests/test_cli.py index 0cdca78..daa118b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -5,6 +5,7 @@ from uuid import uuid4 import os import json +from click import BadParameter from click.testing import CliRunner from redis import Redis @@ -14,6 +15,7 @@ from rq.cli.helpers import read_config_file, CliConfig, parse_function_arg, pars from rq.job import Job from rq.registry import FailedJobRegistry, ScheduledJobRegistry from rq.serializers import JSONSerializer +from rq.timeouts import UnixSignalDeathPenalty from rq.worker import Worker, WorkerStatus from rq.scheduler import RQScheduler @@ -118,6 +120,23 @@ class TestRQCli(RQTestCase): 'testhost.example.com', ) + def test_death_penalty_class(self): + cli_config = CliConfig() + + self.assertEqual( + UnixSignalDeathPenalty, + cli_config.death_penalty_class + ) + + cli_config = CliConfig(death_penalty_class='rq.job.Job') + self.assertEqual( + Job, + cli_config.death_penalty_class + ) + + with self.assertRaises(BadParameter): + CliConfig(death_penalty_class='rq.abcd') + def test_empty_nothing(self): """rq empty -u """ runner = CliRunner() diff --git a/tests/test_job.py b/tests/test_job.py index 9d1ceae..23bbd11 100644 --- a/tests/test_job.py +++ b/tests/test_job.py @@ -1,4 +1,6 @@ import json + +from rq.defaults import CALLBACK_TIMEOUT from rq.serializers import JSONSerializer import time import queue @@ -9,7 +11,7 @@ from redis import WatchError from rq.utils import as_text from rq.exceptions import DeserializationError, InvalidJobOperation, NoSuchJobError -from rq.job import Job, JobStatus, Dependency, cancel_job, get_current_job +from rq.job import Job, JobStatus, Dependency, cancel_job, get_current_job, Callback from rq.queue import Queue from rq.registry import (CanceledJobRegistry, DeferredJobRegistry, FailedJobRegistry, FinishedJobRegistry, StartedJobRegistry, @@ -209,9 +211,9 @@ class TestJob(RQTestCase): # ... and no other keys are stored self.assertEqual( - set(self.testconn.hkeys(job.key)), {b'created_at', b'data', b'description', b'ended_at', b'last_heartbeat', b'started_at', - b'worker_name', b'success_callback_name', b'failure_callback_name'} + b'worker_name', b'success_callback_name', b'failure_callback_name'}, + set(self.testconn.hkeys(job.key)) ) self.assertEqual(job.last_heartbeat, None) @@ -241,6 +243,31 @@ class TestJob(RQTestCase): self.assertEqual(stored_job.dependency.id, parent_job.id) self.assertEqual(stored_job.dependency, parent_job) + def test_persistence_of_callbacks(self): + """Storing jobs with success and/or failure callbacks.""" + job = Job.create(func=fixtures.some_calculation, + on_success=Callback(fixtures.say_hello, timeout=10), + on_failure=fixtures.say_pid) # deprecated callable + job.save() + stored_job = Job.fetch(job.id) + + self.assertEqual(fixtures.say_hello, stored_job.success_callback) + self.assertEqual(10, stored_job.success_callback_timeout) + self.assertEqual(fixtures.say_pid, stored_job.failure_callback) + self.assertEqual(CALLBACK_TIMEOUT, stored_job.failure_callback_timeout) + + # None(s) + job = Job.create(func=fixtures.some_calculation, + on_failure=None) + job.save() + stored_job = Job.fetch(job.id) + self.assertIsNone(stored_job.success_callback) + self.assertEqual(CALLBACK_TIMEOUT, job.success_callback_timeout) # timeout should be never none + self.assertEqual(CALLBACK_TIMEOUT, stored_job.success_callback_timeout) + self.assertIsNone(stored_job.failure_callback) + self.assertEqual(CALLBACK_TIMEOUT, job.failure_callback_timeout) # timeout should be never none + self.assertEqual(CALLBACK_TIMEOUT, stored_job.failure_callback_timeout) + def test_store_then_fetch(self): """Store, then fetch.""" job = Job.create(func=fixtures.some_calculation, timeout='1h', args=(3, 4), diff --git a/tests/test_registry.py b/tests/test_registry.py index 152e3fd..57584b5 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -164,9 +164,9 @@ class TestRegistry(RQTestCase): self.assertNotIn(job, failed_job_registry) self.assertIn(job, self.registry) - with mock.patch.object(Job, 'failure_callback', PropertyMock()) as mocked: + with mock.patch.object(Job, 'execute_failure_callback') as mocked: self.registry.cleanup() - mocked.return_value.assert_any_call(job, self.testconn, AbandonedJobError, ANY, ANY) + mocked.assert_called_once_with(queue.death_penalty_class, AbandonedJobError, ANY, ANY) self.assertIn(job.id, failed_job_registry) self.assertNotIn(job, self.registry) job.refresh()