diff --git a/.coveragerc b/.coveragerc index bad5381..b78c524 100644 --- a/.coveragerc +++ b/.coveragerc @@ -10,4 +10,7 @@ omit = [report] exclude_lines = - if TYPE_CHECKING: \ No newline at end of file + if TYPE_CHECKING: + pragma: no cover + if __name__ == .__main__.: + \ No newline at end of file diff --git a/.github/workflows/workflow.yml b/.github/workflows/workflow.yml index c2a39bb..f4371a2 100644 --- a/.github/workflows/workflow.yml +++ b/.github/workflows/workflow.yml @@ -13,6 +13,7 @@ jobs: build: name: Python${{ matrix.python-version }}/Redis${{ matrix.redis-version }}/redis-py${{ matrix.redis-py-version }} runs-on: ubuntu-20.04 + timeout-minutes: 10 strategy: matrix: python-version: ["3.6", "3.7", "3.8.3", "3.9", "3.10", "3.11"] diff --git a/docs/docs/workers.md b/docs/docs/workers.md index d690f04..71d2ba7 100644 --- a/docs/docs/workers.md +++ b/docs/docs/workers.md @@ -36,7 +36,7 @@ You should use process managers like [Supervisor](/patterns/supervisor/) or ### Burst Mode By default, workers will start working immediately and will block and wait for -new work when they run out of work. Workers can also be started in _burst +new work when they run out of work. Workers can also be started in _burst mode_ to finish all currently available work and quit as soon as all given queues are emptied. @@ -58,6 +58,7 @@ just to scale up your workers temporarily during peak periods. In addition to `--burst`, `rq worker` also accepts these arguments: * `--url` or `-u`: URL describing Redis connection details (e.g `rq worker --url redis://:secrets@example.com:1234/9` or `rq worker --url unix:///var/run/redis/redis.sock`) +* `--burst` or `-b`: run worker in burst mode (stops after all jobs in queue have been processed). * `--path` or `-P`: multiple import paths are supported (e.g `rq worker --path foo --path bar`) * `--config` or `-c`: path to module containing RQ settings. * `--results-ttl`: job results will be kept for this number of seconds (defaults to 500). @@ -155,8 +156,6 @@ worker = Worker([queue], connection=redis, name='foo') ### Retrieving Worker Information -_Updated in version 0.10.0._ - `Worker` instances store their runtime information in Redis. Here's how to retrieve them: @@ -173,6 +172,10 @@ queue = Queue('queue_name') workers = Worker.all(queue=queue) worker = workers[0] print(worker.name) + +print('Successful jobs: ' + worker.successful_job_count) +print('Failed jobs: ' + worker.failed_job_count) +print('Total working time: '+ worker.total_working_time) # In seconds ``` Aside from `worker.name`, worker also have the following properties: @@ -230,20 +233,6 @@ w = Queue('foo', serializer=JSONSerializer) Queues will now use custom serializer -### Worker Statistics - -If you want to check the utilization of your queues, `Worker` instances -store a few useful information: - -```python -from rq.worker import Worker -worker = Worker.find_by_key('rq:worker:name') - -worker.successful_job_count # Number of jobs finished successfully -worker.failed_job_count # Number of failed jobs processed by this worker -worker.total_working_time # Amount of time spent executing jobs (in seconds) -``` - ## Better worker process title Worker process will have a better title (as displayed by system tools such as ps and top) after you installed a third-party package `setproctitle`: @@ -318,22 +307,19 @@ $ rq worker -w 'path.to.GeventWorker' ``` -## Round Robin and Random strategies for dequeuing jobs from queues +## Strategies for Dequeuing Jobs from Queues -The default worker considers the order of queues as their priority order, -and if a task is pending in a higher priority queue -it will be selected before any other in queues with lower priority (the `default` behavior). -To choose the strategy that should be used, `rq` provides the `--dequeue-strategy / -ds` option. +The default worker considers the order of queues as their priority order. +That's to say if the supplied queues are `rq worker high low`, the worker will +prioritize dequeueing jobs from `high` before `low`. To choose a different strategy, +`rq` provides the `--dequeue-strategy / -ds` option. -In certain circumstances it can be useful that a when a worker is listening to multiple queues, -say `q1`,`q2`,`q3`, the jobs are dequeued using a Round Robin strategy. That is, the 1st -dequeued job is taken from `q1`, the 2nd from `q2`, the 3rd from `q3`, the 4th -from `q1`, the 5th from `q2` and so on. To implement this strategy use `-ds round_robin` argument. +In certain circumstances, you may want to dequeue jobs in a round robin fashion. For example, +when you have `q1`,`q2`,`q3`, the 1st dequeued job is taken from `q1`, the 2nd from `q2`, +the 3rd from `q3`, the 4th from `q1`, the 5th from `q2` and so on. +To implement this strategy use `-ds round_robin` argument. -In other circumstances, it can be useful to pull jobs from the different queues randomly. -To implement this strategy use `-ds random` argument. -In fact, whenever a job is pulled from any queue with the `random` strategy, the list of queues is -shuffled, so that no queue has more priority than the other ones. +To dequeue jobs from the different queues randomly, use `-ds random` argument. Deprecation Warning: Those strategies were formely being implemented by using the custom classes `rq.worker.RoundRobinWorker` and `rq.worker.RandomWorker`. As the `--dequeue-strategy` argument allows for this option to be used with any worker, those worker classes are deprecated and will be removed from future versions. @@ -451,3 +437,31 @@ redis = Redis() # This will raise an exception if job is invalid or not currently executing send_stop_job_command(redis, job_id) ``` + +## Worker Pool + +_New in version 1.14.0._ + +
+ + Note: +

`WorkerPool` is still in beta, use at your own risk!

+
+ +WorkerPool allows you to run multiple workers in a single CLI command. + +Usage: + +```shell +rq worker-pool high default low -n 3 +``` + +Options: +* `-u` or `--url `: as defined in [redis-py's docs](https://redis.readthedocs.io/en/stable/connections.html#redis.Redis.from_url). +* `-w` or `--worker-class `: defaults to `rq.worker.Worker`. `rq.worker.SimpleWorker` is also an option. +* `-n` or `--num-workers `: defaults to 2. +* `-b` or `--burst`: run workers in burst mode (stops after all jobs in queue have been processed). +* `-l` or `--logging-level `: defaults to `INFO`. `DEBUG`, `WARNING`, `ERROR` and `CRITICAL` are supported. +* `-S` or `--serializer `: defaults to `rq.serializers.DefaultSerializer`. `rq.serializers.JSONSerializer` is also included. +* `-P` or `--path `: multiple import paths are supported (e.g `rq worker --path foo --path bar`). +* `-j` or `--job-class `: defaults to `rq.job.Job`. diff --git a/rq/cli/cli.py b/rq/cli/cli.py index 55421f6..bccde97 100755 --- a/rq/cli/cli.py +++ b/rq/cli/cli.py @@ -6,6 +6,8 @@ import os import sys import warnings +from typing import List, Type + import click from redis.exceptions import ConnectionError @@ -21,6 +23,8 @@ from rq.cli.helpers import ( parse_schedule, pass_cli_config, ) + +# from rq.cli.pool import pool from rq.contrib.legacy import cleanup_ghosts from rq.defaults import ( DEFAULT_RESULT_TTL, @@ -31,12 +35,15 @@ from rq.defaults import ( DEFAULT_MAINTENANCE_TASK_INTERVAL, ) from rq.exceptions import InvalidJobOperationError -from rq.job import JobStatus +from rq.job import Job, JobStatus from rq.logutils import blue from rq.registry import FailedJobRegistry, clean_registries +from rq.serializers import DefaultSerializer from rq.suspension import suspend as connection_suspend, resume as connection_resume, is_suspended -from rq.utils import import_attribute, get_call_string +from rq.worker import Worker +from rq.worker_pool import WorkerPool from rq.worker_registration import clean_worker_registry +from rq.utils import import_attribute, get_call_string @click.group() @@ -425,3 +432,82 @@ def enqueue( if not quiet: click.echo('Enqueued %s with job-id \'%s\'.' % (blue(function_string), job.id)) + + +@main.command() +@click.option('--burst', '-b', is_flag=True, help='Run in burst mode (quit after all work is done)') +@click.option('--logging-level', '-l', type=str, default="INFO", help='Set logging level') +@click.option('--sentry-ca-certs', envvar='RQ_SENTRY_CA_CERTS', help='Path to CRT file for Sentry DSN') +@click.option('--sentry-debug', envvar='RQ_SENTRY_DEBUG', help='Enable debug') +@click.option('--sentry-dsn', envvar='RQ_SENTRY_DSN', help='Report exceptions to this Sentry DSN') +@click.option('--verbose', '-v', is_flag=True, help='Show more output') +@click.option('--quiet', '-q', is_flag=True, help='Show less output') +@click.option('--log-format', type=str, default=DEFAULT_LOGGING_FORMAT, help='Set the format of the logs') +@click.option('--date-format', type=str, default=DEFAULT_LOGGING_DATE_FORMAT, help='Set the date format of the logs') +@click.option('--job-class', type=str, default=None, help='Dotted path to a Job class') +@click.argument('queues', nargs=-1) +@click.option('--num-workers', '-n', type=int, default=1, help='Number of workers to start') +@pass_cli_config +def worker_pool( + cli_config, + burst: bool, + logging_level, + queues, + serializer, + sentry_ca_certs, + sentry_debug, + sentry_dsn, + verbose, + quiet, + log_format, + date_format, + worker_class, + job_class, + num_workers, + **options, +): + """Starts a RQ worker pool""" + settings = read_config_file(cli_config.config) if cli_config.config else {} + # Worker specific default arguments + queue_names: List[str] = queues or settings.get('QUEUES', ['default']) + sentry_ca_certs = sentry_ca_certs or settings.get('SENTRY_CA_CERTS') + sentry_debug = sentry_debug or settings.get('SENTRY_DEBUG') + sentry_dsn = sentry_dsn or settings.get('SENTRY_DSN') + + setup_loghandlers_from_args(verbose, quiet, date_format, log_format) + + if serializer: + serializer_class: Type[DefaultSerializer] = import_attribute(serializer) + else: + serializer_class = DefaultSerializer + + if worker_class: + worker_class = import_attribute(worker_class) + else: + worker_class = Worker + + if job_class: + job_class = import_attribute(job_class) + else: + job_class = Job + + pool = WorkerPool( + queue_names, + connection=cli_config.connection, + num_workers=num_workers, + serializer=serializer_class, + worker_class=worker_class, + job_class=job_class, + ) + pool.start(burst=burst, logging_level=logging_level) + + # Should we configure Sentry? + if sentry_dsn: + sentry_opts = {"ca_certs": sentry_ca_certs, "debug": sentry_debug} + from rq.contrib.sentry import register_sentry + + register_sentry(sentry_dsn, **sentry_opts) + + +if __name__ == '__main__': + main() diff --git a/rq/queue.py b/rq/queue.py index ebfdd47..d844104 100644 --- a/rq/queue.py +++ b/rq/queue.py @@ -88,7 +88,7 @@ class Queue: """ connection = resolve_connection(connection) - def to_queue(queue_key): + def to_queue(queue_key: Union[bytes, str]): return cls.from_queue_key( as_text(queue_key), connection=connection, @@ -145,7 +145,7 @@ class Queue: default_timeout: Optional[int] = None, connection: Optional['Redis'] = None, is_async: bool = True, - job_class: Union[str, Type['Job'], None] = None, + job_class: Optional[Union[str, Type['Job']]] = None, serializer: Any = None, death_penalty_class: Type[BaseDeathPenalty] = UnixSignalDeathPenalty, **kwargs, @@ -439,7 +439,7 @@ class Queue: Returns: _type_: _description_ """ - job_id = job_or_id.id if isinstance(job_or_id, self.job_class) else job_or_id + job_id: str = job_or_id.id if isinstance(job_or_id, self.job_class) else job_or_id if pipeline is not None: return pipeline.lrem(self.key, 1, job_id) diff --git a/rq/scheduler.py b/rq/scheduler.py index ec54c1a..069181d 100644 --- a/rq/scheduler.py +++ b/rq/scheduler.py @@ -6,6 +6,7 @@ import traceback from datetime import datetime from enum import Enum from multiprocessing import Process +from typing import List, Set from redis import ConnectionPool, Redis, SSLConnection, UnixDomainSocketConnection @@ -16,7 +17,7 @@ from .logutils import setup_loghandlers from .queue import Queue from .registry import ScheduledJobRegistry from .serializers import resolve_serializer -from .utils import current_timestamp +from .utils import current_timestamp, parse_names SCHEDULER_KEY_TEMPLATE = 'rq:scheduler:%s' SCHEDULER_LOCKING_KEY_TEMPLATE = 'rq:scheduler-lock:%s' @@ -46,14 +47,10 @@ class RQScheduler: serializer=None, ): self._queue_names = set(parse_names(queues)) - self._acquired_locks = set() - self._scheduled_job_registries = [] + self._acquired_locks: Set[str] = set() + self._scheduled_job_registries: List[ScheduledJobRegistry] = [] self.lock_acquisition_time = None - ( - self._connection_class, - self._connection_pool_class, - self._connection_kwargs, - ) = parse_connection(connection) + self._connection_class, self._pool_class, self._connection_kwargs = parse_connection(connection) self.serializer = resolve_serializer(serializer) self._connection = None @@ -74,10 +71,7 @@ class RQScheduler: if self._connection: return self._connection self._connection = self._connection_class( - connection_pool=ConnectionPool( - connection_class=self._connection_pool_class, - **self._connection_kwargs - ) + connection_pool=ConnectionPool(connection_class=self._pool_class, **self._connection_kwargs) ) return self._connection @@ -231,14 +225,3 @@ def run(scheduler): scheduler.log.error('Scheduler [PID %s] raised an exception.\n%s', os.getpid(), traceback.format_exc()) raise scheduler.log.info('Scheduler with PID %d has stopped', os.getpid()) - - -def parse_names(queues_or_names): - """Given a list of strings or queues, returns queue names""" - names = [] - for queue_or_name in queues_or_names: - if isinstance(queue_or_name, Queue): - names.append(queue_or_name.name) - else: - names.append(str(queue_or_name)) - return names diff --git a/rq/serializers.py b/rq/serializers.py index b9b7d9c..96de3f5 100644 --- a/rq/serializers.py +++ b/rq/serializers.py @@ -1,7 +1,7 @@ from functools import partial import pickle import json -from typing import Optional, Union +from typing import Optional, Type, Union from .utils import import_attribute @@ -21,7 +21,7 @@ class JSONSerializer: return json.loads(s.decode('utf-8'), *args, **kwargs) -def resolve_serializer(serializer=None): +def resolve_serializer(serializer: Optional[Union[Type[DefaultSerializer], str]] = None) -> Type[DefaultSerializer]: """This function checks the user defined serializer for ('dumps', 'loads') methods It returns a default pickle serializer if not found else it returns a MySerializer The returned serializer objects implement ('dumps', 'loads') methods diff --git a/rq/timeouts.py b/rq/timeouts.py index 44f01f9..a8a408c 100644 --- a/rq/timeouts.py +++ b/rq/timeouts.py @@ -28,7 +28,7 @@ class HorseMonitorTimeoutException(BaseTimeoutException): class BaseDeathPenalty: """Base class to setup job timeouts.""" - def __init__(self, timeout, exception=JobTimeoutException, **kwargs): + def __init__(self, timeout, exception=BaseTimeoutException, **kwargs): self._timeout = timeout self._exception = exception diff --git a/rq/utils.py b/rq/utils.py index ca51779..db483ab 100644 --- a/rq/utils.py +++ b/rq/utils.py @@ -15,9 +15,9 @@ import datetime as dt from collections.abc import Iterable from typing import TYPE_CHECKING, Dict, List, Optional, Any, Callable, Tuple, Union - if TYPE_CHECKING: from redis import Redis + from .queue import Queue from redis.exceptions import ResponseError @@ -107,15 +107,16 @@ def import_attribute(name: str) -> Callable[..., Any]: attribute_name = '.'.join(attribute_bits) if hasattr(module, attribute_name): return getattr(module, attribute_name) - # staticmethods attribute_name = attribute_bits.pop() attribute_owner_name = '.'.join(attribute_bits) - attribute_owner = getattr(module, attribute_owner_name) + try: + attribute_owner = getattr(module, attribute_owner_name) + except: # noqa + raise ValueError('Invalid attribute name: %s' % attribute_name) if not hasattr(attribute_owner, attribute_name): raise ValueError('Invalid attribute name: %s' % name) - return getattr(attribute_owner, attribute_name) @@ -361,3 +362,16 @@ def get_call_string( args = ', '.join(arg_list) return '{0}({1})'.format(func_name, args) + + +def parse_names(queues_or_names: List[Union[str, 'Queue']]) -> List[str]: + """Given a list of strings or queues, returns queue names""" + from .queue import Queue + + names = [] + for queue_or_name in queues_or_names: + if isinstance(queue_or_name, Queue): + names.append(queue_or_name.name) + else: + names.append(str(queue_or_name)) + return names diff --git a/rq/worker.py b/rq/worker.py index 2b41aad..ade789b 100644 --- a/rq/worker.py +++ b/rq/worker.py @@ -10,10 +10,11 @@ import sys import time import traceback import warnings -from datetime import timedelta +from datetime import datetime, timedelta from enum import Enum from random import shuffle from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Type, Union +from types import FrameType from uuid import uuid4 if TYPE_CHECKING: @@ -38,7 +39,6 @@ from .command import parse_payload, PUBSUB_CHANNEL_TEMPLATE, handle_command from .connections import get_current_connection, push_connection, pop_connection from .defaults import ( - CALLBACK_TIMEOUT, DEFAULT_MAINTENANCE_TASK_INTERVAL, DEFAULT_RESULT_TTL, DEFAULT_WORKER_TTL, @@ -67,7 +67,6 @@ from .utils import ( as_text, ) from .version import VERSION -from .serializers import resolve_serializer try: @@ -133,6 +132,108 @@ class BaseWorker: # Max Wait time (in seconds) after which exponential_backoff_factor won't be applicable. max_connection_wait_time = 60.0 + def __init__( + self, + queues, + name: Optional[str] = None, + default_result_ttl=DEFAULT_RESULT_TTL, + connection: Optional['Redis'] = None, + exc_handler=None, + exception_handlers=None, + default_worker_ttl=DEFAULT_WORKER_TTL, + maintenance_interval: int = DEFAULT_MAINTENANCE_TASK_INTERVAL, + job_class: Optional[Type['Job']] = None, + queue_class: Optional[Type['Queue']] = None, + log_job_description: bool = True, + job_monitoring_interval=DEFAULT_JOB_MONITORING_INTERVAL, + disable_default_exception_handler: bool = False, + prepare_for_work: bool = True, + serializer=None, + work_horse_killed_handler: Optional[Callable[[Job, int, int, 'struct_rusage'], None]] = None, + ): # noqa + self.default_result_ttl = default_result_ttl + self.worker_ttl = default_worker_ttl + self.job_monitoring_interval = job_monitoring_interval + self.maintenance_interval = maintenance_interval + + connection = self._set_connection(connection) + self.connection = connection + self.redis_server_version = None + + self.job_class = backend_class(self, 'job_class', override=job_class) + self.queue_class = backend_class(self, 'queue_class', override=queue_class) + self.version = VERSION + self.python_version = sys.version + self.serializer = resolve_serializer(serializer) + + queues = [ + self.queue_class( + name=q, + connection=connection, + job_class=self.job_class, + serializer=self.serializer, + death_penalty_class=self.death_penalty_class, + ) + if isinstance(q, str) + else q + for q in ensure_list(queues) + ] + + self.name: str = name or uuid4().hex + self.queues = queues + self.validate_queues() + self._ordered_queues = self.queues[:] + self._exc_handlers: List[Callable] = [] + self._work_horse_killed_handler = work_horse_killed_handler + self._shutdown_requested_date: Optional[datetime] = None + + self._state: str = 'starting' + self._is_horse: bool = False + self._horse_pid: int = 0 + self._stop_requested: bool = False + self._stopped_job_id = None + + self.log = logger + self.log_job_description = log_job_description + self.last_cleaned_at = None + self.successful_job_count: int = 0 + self.failed_job_count: int = 0 + self.total_working_time: int = 0 + self.current_job_working_time: float = 0 + self.birth_date = None + self.scheduler: Optional[RQScheduler] = None + self.pubsub = None + self.pubsub_thread = None + self._dequeue_strategy: DequeueStrategy = DequeueStrategy.DEFAULT + + self.disable_default_exception_handler = disable_default_exception_handler + + if prepare_for_work: + self.hostname: Optional[str] = socket.gethostname() + self.pid: Optional[int] = os.getpid() + try: + connection.client_setname(self.name) + except redis.exceptions.ResponseError: + warnings.warn('CLIENT SETNAME command not supported, setting ip_address to unknown', Warning) + self.ip_address = 'unknown' + else: + client_adresses = [client['addr'] for client in connection.client_list() if client['name'] == self.name] + if len(client_adresses) > 0: + self.ip_address = client_adresses[0] + else: + warnings.warn('CLIENT LIST command not supported, setting ip_address to unknown', Warning) + self.ip_address = 'unknown' + else: + self.hostname = None + self.pid = None + self.ip_address = 'unknown' + + if isinstance(exception_handlers, (list, tuple)): + for handler in exception_handlers: + self.push_exc_handler(handler) + elif exception_handlers is not None: + self.push_exc_handler(exception_handlers) + @classmethod def all( cls, @@ -185,50 +286,352 @@ class BaseWorker: Returns: length (int): The queue length. """ - return len(worker_registration.get_keys(queue=queue, connection=connection)) + return len(worker_registration.get_keys(queue=queue, connection=connection)) + + @property + def should_run_maintenance_tasks(self): + """Maintenance tasks should run on first startup or every 10 minutes.""" + if self.last_cleaned_at is None: + return True + if (utcnow() - self.last_cleaned_at) > timedelta(seconds=self.maintenance_interval): + return True + return False + + def get_redis_server_version(self): + """Return Redis server version of connection""" + if not self.redis_server_version: + self.redis_server_version = get_version(self.connection) + return self.redis_server_version + + def validate_queues(self): + """Sanity check for the given queues.""" + for queue in self.queues: + if not isinstance(queue, self.queue_class): + raise TypeError('{0} is not of type {1} or string types'.format(queue, self.queue_class)) + + def queue_names(self) -> List[str]: + """Returns the queue names of this worker's queues. + + Returns: + List[str]: The queue names. + """ + return [queue.name for queue in self.queues] + + def queue_keys(self) -> List[str]: + """Returns the Redis keys representing this worker's queues. + + Returns: + List[str]: The list of strings with queues keys + """ + return [queue.key for queue in self.queues] + + @property + def key(self): + """Returns the worker's Redis hash key.""" + return self.redis_worker_namespace_prefix + self.name + + @property + def pubsub_channel_name(self): + """Returns the worker's Redis hash key.""" + return PUBSUB_CHANNEL_TEMPLATE % self.name + + @property + def supports_redis_streams(self) -> bool: + """Only supported by Redis server >= 5.0 is required.""" + return self.get_redis_server_version() >= (5, 0, 0) + + def _install_signal_handlers(self): + """Installs signal handlers for handling SIGINT and SIGTERM gracefully.""" + signal.signal(signal.SIGINT, self.request_stop) + signal.signal(signal.SIGTERM, self.request_stop) + + def work( + self, + burst: bool = False, + logging_level: str = "INFO", + date_format: str = DEFAULT_LOGGING_DATE_FORMAT, + log_format: str = DEFAULT_LOGGING_FORMAT, + max_jobs: Optional[int] = None, + max_idle_time: Optional[int] = None, + with_scheduler: bool = False, + dequeue_strategy: DequeueStrategy = DequeueStrategy.DEFAULT, + ) -> bool: + """Starts the work loop. + + Pops and performs all jobs on the current list of queues. When all + queues are empty, block and wait for new jobs to arrive on any of the + queues, unless `burst` mode is enabled. + If `max_idle_time` is provided, worker will die when it's idle for more than the provided value. + + The return value indicates whether any jobs were processed. + + Args: + burst (bool, optional): Whether to work on burst mode. Defaults to False. + logging_level (str, optional): Logging level to use. Defaults to "INFO". + date_format (str, optional): Date Format. Defaults to DEFAULT_LOGGING_DATE_FORMAT. + log_format (str, optional): Log Format. Defaults to DEFAULT_LOGGING_FORMAT. + max_jobs (Optional[int], optional): Max number of jobs. Defaults to None. + max_idle_time (Optional[int], optional): Max seconds for worker to be idle. Defaults to None. + with_scheduler (bool, optional): Whether to run the scheduler in a separate process. Defaults to False. + dequeue_strategy (DequeueStrategy, optional): Which strategy to use to dequeue jobs. Defaults to DequeueStrategy.DEFAULT + + Returns: + worked (bool): Will return True if any job was processed, False otherwise. + """ + self.bootstrap(logging_level, date_format, log_format) + self._dequeue_strategy = dequeue_strategy + completed_jobs = 0 + if with_scheduler: + self._start_scheduler(burst, logging_level, date_format, log_format) + + self._install_signal_handlers() + try: + while True: + try: + self.check_for_suspension(burst) + + if self.should_run_maintenance_tasks: + self.run_maintenance_tasks() + + if self._stop_requested: + self.log.info('Worker %s: stopping on request', self.key) + break + + timeout = None if burst else self.dequeue_timeout + result = self.dequeue_job_and_maintain_ttl(timeout, max_idle_time) + if result is None: + if burst: + self.log.info('Worker %s: done, quitting', self.key) + elif max_idle_time is not None: + self.log.info('Worker %s: idle for %d seconds, quitting', self.key, max_idle_time) + break + + job, queue = result + self.execute_job(job, queue) + self.heartbeat() + + completed_jobs += 1 + if max_jobs is not None: + if completed_jobs >= max_jobs: + self.log.info('Worker %s: finished executing %d jobs, quitting', self.key, completed_jobs) + break + + except redis.exceptions.TimeoutError: + self.log.error('Worker %s: Redis connection timeout, quitting...', self.key) + break + + except StopRequested: + break + + except SystemExit: + # Cold shutdown detected + raise + + except: # noqa + self.log.error('Worker %s: found an unhandled exception, quitting...', self.key, exc_info=True) + break + finally: + self.teardown() + return bool(completed_jobs) + + def _start_scheduler( + self, + burst: bool = False, + logging_level: str = "INFO", + date_format: str = DEFAULT_LOGGING_DATE_FORMAT, + log_format: str = DEFAULT_LOGGING_FORMAT, + ): + """Starts the scheduler process. + This is specifically designed to be run by the worker when running the `work()` method. + Instanciates the RQScheduler and tries to acquire a lock. + If the lock is acquired, start scheduler. + If worker is on burst mode just enqueues scheduled jobs and quits, + otherwise, starts the scheduler in a separate process. + + Args: + burst (bool, optional): Whether to work on burst mode. Defaults to False. + logging_level (str, optional): Logging level to use. Defaults to "INFO". + date_format (str, optional): Date Format. Defaults to DEFAULT_LOGGING_DATE_FORMAT. + log_format (str, optional): Log Format. Defaults to DEFAULT_LOGGING_FORMAT. + """ + self.scheduler = RQScheduler( + self.queues, + connection=self.connection, + logging_level=logging_level, + date_format=date_format, + log_format=log_format, + serializer=self.serializer, + ) + self.scheduler.acquire_locks() + if self.scheduler.acquired_locks: + if burst: + self.scheduler.enqueue_scheduled_jobs() + self.scheduler.release_locks() + else: + self.scheduler.start() + + def bootstrap( + self, + logging_level: str = "INFO", + date_format: str = DEFAULT_LOGGING_DATE_FORMAT, + log_format: str = DEFAULT_LOGGING_FORMAT, + ): + """Bootstraps the worker. + Runs the basic tasks that should run when the worker actually starts working. + Used so that new workers can focus on the work loop implementation rather + than the full bootstraping process. + + Args: + logging_level (str, optional): Logging level to use. Defaults to "INFO". + date_format (str, optional): Date Format. Defaults to DEFAULT_LOGGING_DATE_FORMAT. + log_format (str, optional): Log Format. Defaults to DEFAULT_LOGGING_FORMAT. + """ + setup_loghandlers(logging_level, date_format, log_format) + self.register_birth() + self.log.info('Worker %s started with PID %d, version %s', self.key, os.getpid(), VERSION) + self.subscribe() + self.set_state(WorkerStatus.STARTED) + qnames = self.queue_names() + self.log.info('*** Listening on %s...', green(', '.join(qnames))) + + def check_for_suspension(self, burst: bool): + """Check to see if workers have been suspended by `rq suspend`""" + before_state = None + notified = False + + while not self._stop_requested and is_suspended(self.connection, self): + if burst: + self.log.info('Suspended in burst mode, exiting') + self.log.info('Note: There could still be unfinished jobs on the queue') + raise StopRequested + + if not notified: + self.log.info('Worker suspended, run `rq resume` to resume') + before_state = self.get_state() + self.set_state(WorkerStatus.SUSPENDED) + notified = True + time.sleep(1) + + if before_state: + self.set_state(before_state) + + def run_maintenance_tasks(self): + """ + Runs periodic maintenance tasks, these include: + 1. Check if scheduler should be started. This check should not be run + on first run since worker.work() already calls + `scheduler.enqueue_scheduled_jobs()` on startup. + 2. Cleaning registries + + No need to try to start scheduler on first run + """ + if self.last_cleaned_at: + if self.scheduler and (not self.scheduler._process or not self.scheduler._process.is_alive()): + self.scheduler.acquire_locks(auto_start=True) + self.clean_registries() + + def subscribe(self): + """Subscribe to this worker's channel""" + self.log.info('Subscribing to channel %s', self.pubsub_channel_name) + self.pubsub = self.connection.pubsub() + self.pubsub.subscribe(**{self.pubsub_channel_name: self.handle_payload}) + self.pubsub_thread = self.pubsub.run_in_thread(sleep_time=0.2, daemon=True) + + def unsubscribe(self): + """Unsubscribe from pubsub channel""" + if self.pubsub_thread: + self.log.info('Unsubscribing from channel %s', self.pubsub_channel_name) + self.pubsub_thread.stop() + self.pubsub_thread.join() + self.pubsub.unsubscribe() + self.pubsub.close() + + def dequeue_job_and_maintain_ttl( + self, timeout: Optional[int], max_idle_time: Optional[int] = None + ) -> Tuple['Job', 'Queue']: + """Dequeues a job while maintaining the TTL. + + Returns: + result (Tuple[Job, Queue]): A tuple with the job and the queue. + """ + result = None + qnames = ','.join(self.queue_names()) - def get_redis_server_version(self): - """Return Redis server version of connection""" - if not self.redis_server_version: - self.redis_server_version = get_version(self.connection) - return self.redis_server_version + self.set_state(WorkerStatus.IDLE) + self.procline('Listening on ' + qnames) + self.log.debug('*** Listening on %s...', green(qnames)) + connection_wait_time = 1.0 + idle_since = utcnow() + idle_time_left = max_idle_time + while True: + try: + self.heartbeat() - def validate_queues(self): - """Sanity check for the given queues.""" - for queue in self.queues: - if not isinstance(queue, self.queue_class): - raise TypeError('{0} is not of type {1} or string types'.format(queue, self.queue_class)) + if self.should_run_maintenance_tasks: + self.run_maintenance_tasks() - def queue_names(self) -> List[str]: - """Returns the queue names of this worker's queues. + if timeout is not None and idle_time_left is not None: + timeout = min(timeout, idle_time_left) - Returns: - List[str]: The queue names. - """ - return [queue.name for queue in self.queues] + self.log.debug('Dequeueing jobs on queues %s and timeout %s', green(qnames), timeout) + result = self.queue_class.dequeue_any( + self._ordered_queues, + timeout, + connection=self.connection, + job_class=self.job_class, + serializer=self.serializer, + death_penalty_class=self.death_penalty_class, + ) + if result is not None: + job, queue = result + self.reorder_queues(reference_queue=queue) + self.log.debug('Dequeued job %s from %s', blue(job.id), green(queue.name)) + job.redis_server_version = self.get_redis_server_version() + if self.log_job_description: + self.log.info('%s: %s (%s)', green(queue.name), blue(job.description), job.id) + else: + self.log.info('%s: %s', green(queue.name), job.id) - def queue_keys(self) -> List[str]: - """Returns the Redis keys representing this worker's queues. + break + except DequeueTimeout: + if max_idle_time is not None: + idle_for = (utcnow() - idle_since).total_seconds() + idle_time_left = math.ceil(max_idle_time - idle_for) + if idle_time_left <= 0: + break + except redis.exceptions.ConnectionError as conn_err: + self.log.error( + 'Could not connect to Redis instance: %s Retrying in %d seconds...', conn_err, connection_wait_time + ) + time.sleep(connection_wait_time) + connection_wait_time *= self.exponential_backoff_factor + connection_wait_time = min(connection_wait_time, self.max_connection_wait_time) + else: + connection_wait_time = 1.0 - Returns: - List[str]: The list of strings with queues keys - """ - return [queue.key for queue in self.queues] + self.heartbeat() + return result - @property - def key(self): - """Returns the worker's Redis hash key.""" - return self.redis_worker_namespace_prefix + self.name + def heartbeat(self, timeout: Optional[int] = None, pipeline: Optional['Pipeline'] = None): + """Specifies a new worker timeout, typically by extending the + expiration time of the worker, effectively making this a "heartbeat" + to not expire the worker until the timeout passes. - @property - def pubsub_channel_name(self): - """Returns the worker's Redis hash key.""" - return PUBSUB_CHANNEL_TEMPLATE % self.name + The next heartbeat should come before this time, or the worker will + die (at least from the monitoring dashboards). - @property - def supports_redis_streams(self) -> bool: - """Only supported by Redis server >= 5.0 is required.""" - return self.get_redis_server_version() >= (5, 0, 0) + If no timeout is given, the worker_ttl will be used to update + the expiration time of the worker. + + Args: + timeout (Optional[int]): Timeout + pipeline (Optional[Redis]): A Redis pipeline + """ + timeout = timeout or self.worker_ttl + 60 + connection: Union[Redis, 'Pipeline'] = pipeline if pipeline is not None else self.connection + connection.expire(self.key, timeout) + connection.hset(self.key, 'last_heartbeat', utcformat(utcnow())) + self.log.debug('Sent heartbeat to prevent worker timeout. ' 'Next one should arrive in %s seconds.', timeout) class Worker(BaseWorker): @@ -249,7 +652,7 @@ class Worker(BaseWorker): worker_key (str): The worker key connection (Optional[Redis], optional): Redis connection. Defaults to None. job_class (Optional[Type[Job]], optional): The job class if custom class is being used. Defaults to None. - queue_class (Optional[Type[Queue]], optional): The queue class if a custom class is being used. Defaults to None. + queue_class (Optional[Type[Queue]]): The queue class if a custom class is being used. Defaults to None. serializer (Any, optional): The serializer to use. Defaults to None. Raises: @@ -282,107 +685,6 @@ class Worker(BaseWorker): worker.refresh() return worker - def __init__( - self, - queues, - name: Optional[str] = None, - default_result_ttl=DEFAULT_RESULT_TTL, - connection: Optional['Redis'] = None, - exc_handler=None, - exception_handlers=None, - default_worker_ttl=DEFAULT_WORKER_TTL, - maintenance_interval: int = DEFAULT_MAINTENANCE_TASK_INTERVAL, - job_class: Type['Job'] = None, - queue_class=None, - log_job_description: bool = True, - job_monitoring_interval=DEFAULT_JOB_MONITORING_INTERVAL, - disable_default_exception_handler: bool = False, - prepare_for_work: bool = True, - serializer=None, - work_horse_killed_handler: Optional[Callable[[Job, int, int, 'struct_rusage'], None]] = None, - ): # noqa - self.default_result_ttl = default_result_ttl - self.worker_ttl = default_worker_ttl - self.job_monitoring_interval = job_monitoring_interval - self.maintenance_interval = maintenance_interval - - connection = self._set_connection(connection) - self.connection = connection - self.redis_server_version = None - - self.job_class = backend_class(self, 'job_class', override=job_class) - self.queue_class = backend_class(self, 'queue_class', override=queue_class) - self.version = VERSION - self.python_version = sys.version - self.serializer = resolve_serializer(serializer) - - queues = [ - self.queue_class( - name=q, - connection=connection, - job_class=self.job_class, - serializer=self.serializer, - death_penalty_class=self.death_penalty_class, - ) - if isinstance(q, str) - else q - for q in ensure_list(queues) - ] - - self.name: str = name or uuid4().hex - self.queues = queues - self.validate_queues() - self._ordered_queues = self.queues[:] - self._exc_handlers: List[Callable] = [] - self._work_horse_killed_handler = work_horse_killed_handler - - self._state: str = 'starting' - self._is_horse: bool = False - self._horse_pid: int = 0 - self._stop_requested: bool = False - self._stopped_job_id = None - - self.log = logger - self.log_job_description = log_job_description - self.last_cleaned_at = None - self.successful_job_count: int = 0 - self.failed_job_count: int = 0 - self.total_working_time: int = 0 - self.current_job_working_time: float = 0 - self.birth_date = None - self.scheduler: Optional[RQScheduler] = None - self.pubsub = None - self.pubsub_thread = None - self._dequeue_strategy: DequeueStrategy = DequeueStrategy.DEFAULT - - self.disable_default_exception_handler = disable_default_exception_handler - - if prepare_for_work: - self.hostname: Optional[str] = socket.gethostname() - self.pid: Optional[int] = os.getpid() - try: - connection.client_setname(self.name) - except redis.exceptions.ResponseError: - warnings.warn('CLIENT SETNAME command not supported, setting ip_address to unknown', Warning) - self.ip_address = 'unknown' - else: - client_adresses = [client['addr'] for client in connection.client_list() if client['name'] == self.name] - if len(client_adresses) > 0: - self.ip_address = client_adresses[0] - else: - warnings.warn('CLIENT LIST command not supported, setting ip_address to unknown', Warning) - self.ip_address = 'unknown' - else: - self.hostname = None - self.pid = None - self.ip_address = 'unknown' - - if isinstance(exception_handlers, (list, tuple)): - for handler in exception_handlers: - self.push_exc_handler(handler) - elif exception_handlers is not None: - self.push_exc_handler(exception_handlers) - def _set_connection(self, connection: Optional['Redis']) -> 'Redis': """Configures the Redis connection to have a socket timeout. This should timouet the connection in case any specific command hangs at any given time (eg. BLPOP). @@ -473,7 +775,7 @@ class Worker(BaseWorker): def set_shutdown_requested_date(self): """Sets the date on which the worker received a (warm) shutdown request""" - self.connection.hset(self.key, 'shutdown_requested_date', utcformat(utcnow())) + self.connection.hset(self.key, 'shutdown_requested_date', utcformat(self._shutdown_requested_date)) @property def shutdown_requested_date(self): @@ -566,11 +868,6 @@ class Worker(BaseWorker): return None return self.job_class.fetch(job_id, self.connection, self.serializer) - def _install_signal_handlers(self): - """Installs signal handlers for handling SIGINT and SIGTERM gracefully.""" - signal.signal(signal.SIGINT, self.request_stop) - signal.signal(signal.SIGTERM, self.request_stop) - def kill_horse(self, sig: signal.Signals = SIGKILL): """Kill the horse but catch "No such process" error has the horse could already be dead. @@ -596,7 +893,7 @@ class Worker(BaseWorker): pid, stat, rusage = os.wait4(self.horse_pid, 0) return pid, stat, rusage - def request_force_stop(self, signum, frame): + def request_force_stop(self, signum: int, frame: Optional[FrameType]): """Terminates the application (cold shutdown). Args: @@ -606,6 +903,14 @@ class Worker(BaseWorker): Raises: SystemExit: SystemExit """ + # When worker is run through a worker pool, it may receive duplicate signals + # One is sent by the pool when it calls `pool.stop_worker()` and another is sent by the OS + # when user hits Ctrl+C. In this case if we receive the second signal within 1 second, + # we ignore it. + if (utcnow() - self._shutdown_requested_date) < timedelta(seconds=1): # type: ignore + self.log.debug('Shutdown signal ignored, received twice in less than 1 second') + return + self.log.warning('Cold shut down') # Take down the horse with the worker @@ -624,6 +929,7 @@ class Worker(BaseWorker): frame (Any): Frame """ self.log.debug('Got signal %s', signal_name(signum)) + self._shutdown_requested_date = utcnow() signal.signal(signal.SIGINT, self.request_force_stop) signal.signal(signal.SIGTERM, self.request_force_stop) @@ -642,65 +948,13 @@ class Worker(BaseWorker): self.log.debug('Stopping after current horse is finished. ' 'Press Ctrl+C again for a cold shutdown.') if self.scheduler: self.stop_scheduler() - else: - if self.scheduler: - self.stop_scheduler() - raise StopRequested() - - def handle_warm_shutdown_request(self): - self.log.info('Warm shut down requested') - - def check_for_suspension(self, burst: bool): - """Check to see if workers have been suspended by `rq suspend`""" - before_state = None - notified = False - - while not self._stop_requested and is_suspended(self.connection, self): - if burst: - self.log.info('Suspended in burst mode, exiting') - self.log.info('Note: There could still be unfinished jobs on the queue') - raise StopRequested - - if not notified: - self.log.info('Worker suspended, run `rq resume` to resume') - before_state = self.get_state() - self.set_state(WorkerStatus.SUSPENDED) - notified = True - time.sleep(1) - - if before_state: - self.set_state(before_state) - - def run_maintenance_tasks(self): - """ - Runs periodic maintenance tasks, these include: - 1. Check if scheduler should be started. This check should not be run - on first run since worker.work() already calls - `scheduler.enqueue_scheduled_jobs()` on startup. - 2. Cleaning registries - - No need to try to start scheduler on first run - """ - if self.last_cleaned_at: - if self.scheduler and (not self.scheduler._process or not self.scheduler._process.is_alive()): - self.scheduler.acquire_locks(auto_start=True) - self.clean_registries() - - def subscribe(self): - """Subscribe to this worker's channel""" - self.log.info('Subscribing to channel %s', self.pubsub_channel_name) - self.pubsub = self.connection.pubsub() - self.pubsub.subscribe(**{self.pubsub_channel_name: self.handle_payload}) - self.pubsub_thread = self.pubsub.run_in_thread(sleep_time=0.2, daemon=True) + else: + if self.scheduler: + self.stop_scheduler() + raise StopRequested() - def unsubscribe(self): - """Unsubscribe from pubsub channel""" - if self.pubsub_thread: - self.log.info('Unsubscribing from channel %s', self.pubsub_channel_name) - self.pubsub_thread.stop() - self.pubsub_thread.join() - self.pubsub.unsubscribe() - self.pubsub.close() + def handle_warm_shutdown_request(self): + self.log.info('Worker %s [PID %d]: warm shut down requested', self.name, self.pid) def reorder_queues(self, reference_queue: 'Queue'): """Reorder the queues according to the strategy. @@ -727,155 +981,6 @@ class Worker(BaseWorker): shuffle(self._ordered_queues) return - def bootstrap( - self, - logging_level: str = "INFO", - date_format: str = DEFAULT_LOGGING_DATE_FORMAT, - log_format: str = DEFAULT_LOGGING_FORMAT, - ): - """Bootstraps the worker. - Runs the basic tasks that should run when the worker actually starts working. - Used so that new workers can focus on the work loop implementation rather - than the full bootstraping process. - - Args: - logging_level (str, optional): Logging level to use. Defaults to "INFO". - date_format (str, optional): Date Format. Defaults to DEFAULT_LOGGING_DATE_FORMAT. - log_format (str, optional): Log Format. Defaults to DEFAULT_LOGGING_FORMAT. - """ - setup_loghandlers(logging_level, date_format, log_format) - self.register_birth() - self.log.info('Worker %s: started, version %s', self.key, VERSION) - self.subscribe() - self.set_state(WorkerStatus.STARTED) - qnames = self.queue_names() - self.log.info('*** Listening on %s...', green(', '.join(qnames))) - - def _start_scheduler( - self, - burst: bool = False, - logging_level: str = "INFO", - date_format: str = DEFAULT_LOGGING_DATE_FORMAT, - log_format: str = DEFAULT_LOGGING_FORMAT, - ): - """Starts the scheduler process. - This is specifically designed to be run by the worker when running the `work()` method. - Instanciates the RQScheduler and tries to acquire a lock. - If the lock is acquired, start scheduler. - If worker is on burst mode just enqueues scheduled jobs and quits, - otherwise, starts the scheduler in a separate process. - - Args: - burst (bool, optional): Whether to work on burst mode. Defaults to False. - logging_level (str, optional): Logging level to use. Defaults to "INFO". - date_format (str, optional): Date Format. Defaults to DEFAULT_LOGGING_DATE_FORMAT. - log_format (str, optional): Log Format. Defaults to DEFAULT_LOGGING_FORMAT. - """ - self.scheduler = RQScheduler( - self.queues, - connection=self.connection, - logging_level=logging_level, - date_format=date_format, - log_format=log_format, - serializer=self.serializer, - ) - self.scheduler.acquire_locks() - if self.scheduler.acquired_locks: - if burst: - self.scheduler.enqueue_scheduled_jobs() - self.scheduler.release_locks() - else: - self.scheduler.start() - - def work( - self, - burst: bool = False, - logging_level: str = "INFO", - date_format: str = DEFAULT_LOGGING_DATE_FORMAT, - log_format: str = DEFAULT_LOGGING_FORMAT, - max_jobs: Optional[int] = None, - max_idle_time: Optional[int] = None, - with_scheduler: bool = False, - dequeue_strategy: DequeueStrategy = DequeueStrategy.DEFAULT, - ) -> bool: - """Starts the work loop. - - Pops and performs all jobs on the current list of queues. When all - queues are empty, block and wait for new jobs to arrive on any of the - queues, unless `burst` mode is enabled. - If `max_idle_time` is provided, worker will die when it's idle for more than the provided value. - - The return value indicates whether any jobs were processed. - - Args: - burst (bool, optional): Whether to work on burst mode. Defaults to False. - logging_level (str, optional): Logging level to use. Defaults to "INFO". - date_format (str, optional): Date Format. Defaults to DEFAULT_LOGGING_DATE_FORMAT. - log_format (str, optional): Log Format. Defaults to DEFAULT_LOGGING_FORMAT. - max_jobs (Optional[int], optional): Max number of jobs. Defaults to None. - max_idle_time (Optional[int], optional): Max seconds for worker to be idle. Defaults to None. - with_scheduler (bool, optional): Whether to run the scheduler in a separate process. Defaults to False. - dequeue_strategy (DequeueStrategy, optional): Which strategy to use to dequeue jobs. Defaults to DequeueStrategy.DEFAULT - - Returns: - worked (bool): Will return True if any job was processed, False otherwise. - """ - self.bootstrap(logging_level, date_format, log_format) - self._dequeue_strategy = dequeue_strategy - completed_jobs = 0 - if with_scheduler: - self._start_scheduler(burst, logging_level, date_format, log_format) - - self._install_signal_handlers() - try: - while True: - try: - self.check_for_suspension(burst) - - if self.should_run_maintenance_tasks: - self.run_maintenance_tasks() - - if self._stop_requested: - self.log.info('Worker %s: stopping on request', self.key) - break - - timeout = None if burst else self.dequeue_timeout - result = self.dequeue_job_and_maintain_ttl(timeout, max_idle_time) - if result is None: - if burst: - self.log.info('Worker %s: done, quitting', self.key) - elif max_idle_time is not None: - self.log.info('Worker %s: idle for %d seconds, quitting', self.key, max_idle_time) - break - - job, queue = result - self.execute_job(job, queue) - self.heartbeat() - - completed_jobs += 1 - if max_jobs is not None: - if completed_jobs >= max_jobs: - self.log.info('Worker %s: finished executing %d jobs, quitting', self.key, completed_jobs) - break - - except redis.exceptions.TimeoutError: - self.log.error('Worker %s: Redis connection timeout, quitting...', self.key) - break - - except StopRequested: - break - - except SystemExit: - # Cold shutdown detected - raise - - except: # noqa - self.log.error('Worker %s: found an unhandled exception, quitting...', self.key, exc_info=True) - break - finally: - self.teardown() - return bool(completed_jobs) - def teardown(self): if not self.is_horse: if self.scheduler: @@ -896,95 +1001,6 @@ class Worker(BaseWorker): pass self.scheduler._process.join() - def dequeue_job_and_maintain_ttl( - self, timeout: Optional[int], max_idle_time: Optional[int] = None - ) -> Tuple['Job', 'Queue']: - """Dequeues a job while maintaining the TTL. - - Returns: - result (Tuple[Job, Queue]): A tuple with the job and the queue. - """ - result = None - qnames = ','.join(self.queue_names()) - - self.set_state(WorkerStatus.IDLE) - self.procline('Listening on ' + qnames) - self.log.debug('*** Listening on %s...', green(qnames)) - connection_wait_time = 1.0 - idle_since = utcnow() - idle_time_left = max_idle_time - while True: - try: - self.heartbeat() - - if self.should_run_maintenance_tasks: - self.run_maintenance_tasks() - - if timeout is not None and idle_time_left is not None: - timeout = min(timeout, idle_time_left) - - self.log.debug('Dequeueing jobs on queues %s and timeout %d', green(qnames), timeout) - result = self.queue_class.dequeue_any( - self._ordered_queues, - timeout, - connection=self.connection, - job_class=self.job_class, - serializer=self.serializer, - death_penalty_class=self.death_penalty_class, - ) - if result is not None: - job, queue = result - self.reorder_queues(reference_queue=queue) - self.log.debug('Dequeued job %s from %s', blue(job.id), green(queue.name)) - job.redis_server_version = self.get_redis_server_version() - if self.log_job_description: - self.log.info('%s: %s (%s)', green(queue.name), blue(job.description), job.id) - else: - self.log.info('%s: %s', green(queue.name), job.id) - - break - except DequeueTimeout: - if max_idle_time is not None: - idle_for = (utcnow() - idle_since).total_seconds() - idle_time_left = math.ceil(max_idle_time - idle_for) - if idle_time_left <= 0: - break - except redis.exceptions.ConnectionError as conn_err: - self.log.error( - 'Could not connect to Redis instance: %s Retrying in %d seconds...', conn_err, connection_wait_time - ) - time.sleep(connection_wait_time) - connection_wait_time *= self.exponential_backoff_factor - connection_wait_time = min(connection_wait_time, self.max_connection_wait_time) - else: - connection_wait_time = 1.0 - - self.heartbeat() - return result - - def heartbeat(self, timeout: Optional[int] = None, pipeline: Optional['Pipeline'] = None): - """Specifies a new worker timeout, typically by extending the - expiration time of the worker, effectively making this a "heartbeat" - to not expire the worker until the timeout passes. - - The next heartbeat should come before this time, or the worker will - die (at least from the monitoring dashboards). - - If no timeout is given, the worker_ttl will be used to update - the expiration time of the worker. - - Args: - timeout (Optional[int]): Timeout - pipeline (Optional[Redis]): A Redis pipeline - """ - timeout = timeout or self.worker_ttl + 60 - connection = pipeline if pipeline is not None else self.connection - connection.expire(self.key, timeout) - connection.hset(self.key, 'last_heartbeat', utcformat(utcnow())) - self.log.debug( - 'Sent heartbeat to prevent worker timeout. ' 'Next one should arrive within %s seconds.', timeout - ) - def refresh(self): """Refreshes the worker data. It will get the data from the datastore and update the Worker's attributes @@ -1104,7 +1120,7 @@ class Worker(BaseWorker): self._horse_pid = child_pid self.procline('Forked {0} at {1}'.format(child_pid, time.time())) - def get_heartbeat_ttl(self, job: 'Job') -> Union[float, int]: + def get_heartbeat_ttl(self, job: 'Job') -> int: """Get's the TTL for the next heartbeat. Args: @@ -1521,15 +1537,6 @@ class Worker(BaseWorker): worker_registration.clean_worker_registry(queue) self.last_cleaned_at = utcnow() - @property - def should_run_maintenance_tasks(self): - """Maintenance tasks should run on first startup or every 10 minutes.""" - if self.last_cleaned_at is None: - return True - if (utcnow() - self.last_cleaned_at) > timedelta(seconds=self.maintenance_interval): - return True - return False - def handle_payload(self, message): """Handle external commands""" self.log.debug('Received message: %s', message) @@ -1544,7 +1551,7 @@ class SimpleWorker(Worker): self.perform_job(job, queue) self.set_state(WorkerStatus.IDLE) - def get_heartbeat_ttl(self, job: 'Job') -> Union[float, int]: + def get_heartbeat_ttl(self, job: 'Job') -> int: """-1" means that jobs never timeout. In this case, we should _not_ do -1 + 60 = 59. We should just stick to DEFAULT_WORKER_TTL. @@ -1552,7 +1559,7 @@ class SimpleWorker(Worker): job (Job): The Job Returns: - ttl (float | int): TTL + ttl (int): TTL """ if job.timeout == -1: return DEFAULT_WORKER_TTL diff --git a/rq/worker_pool.py b/rq/worker_pool.py new file mode 100644 index 0000000..4bd21bb --- /dev/null +++ b/rq/worker_pool.py @@ -0,0 +1,250 @@ +import contextlib +import errno +import logging +import os +import signal +import time + +from enum import Enum +from multiprocessing import Process +from typing import Dict, List, NamedTuple, Optional, Set, Type, Union +from uuid import uuid4 + +from redis import Redis +from redis import SSLConnection, UnixDomainSocketConnection +from rq.serializers import DefaultSerializer + +from rq.timeouts import HorseMonitorTimeoutException, UnixSignalDeathPenalty + +from .connections import parse_connection +from .defaults import DEFAULT_LOGGING_DATE_FORMAT, DEFAULT_LOGGING_FORMAT +from .job import Job +from .logutils import setup_loghandlers +from .queue import Queue +from .utils import parse_names +from .worker import BaseWorker, Worker + + +class WorkerData(NamedTuple): + name: str + pid: int + process: Process + + +class WorkerPool: + class Status(Enum): + IDLE = 1 + STARTED = 2 + STOPPED = 3 + + def __init__( + self, + queues: List[Union[str, Queue]], + connection: Redis, + num_workers: int = 1, + worker_class: Type[BaseWorker] = Worker, + serializer: Type[DefaultSerializer] = DefaultSerializer, + job_class: Type[Job] = Job, + *args, + **kwargs, + ): + self.num_workers: int = num_workers + self._workers: List[Worker] = [] + setup_loghandlers('INFO', DEFAULT_LOGGING_DATE_FORMAT, DEFAULT_LOGGING_FORMAT, name=__name__) + self.log: logging.Logger = logging.getLogger(__name__) + # self.log: logging.Logger = logger + self._queue_names: List[str] = parse_names(queues) + self.connection = connection + self.name: str = uuid4().hex + self._burst: bool = True + self._sleep: int = 0 + self.status: self.Status = self.Status.IDLE # type: ignore + self.worker_class: Type[BaseWorker] = worker_class + self.serializer: Type[DefaultSerializer] = serializer + self.job_class: Type[Job] = job_class + + # A dictionary of WorkerData keyed by worker name + self.worker_dict: Dict[str, WorkerData] = {} + self._connection_class, _, self._connection_kwargs = parse_connection(connection) + + @property + def queues(self) -> List[Queue]: + """Returns a list of Queue objects""" + return [Queue(name, connection=self.connection) for name in self._queue_names] + + @property + def number_of_active_workers(self) -> int: + """Returns a list of Queue objects""" + return len(self.worker_dict) + + def _install_signal_handlers(self): + """Installs signal handlers for handling SIGINT and SIGTERM + gracefully. + """ + signal.signal(signal.SIGINT, self.request_stop) + signal.signal(signal.SIGTERM, self.request_stop) + + def request_stop(self, signum=None, frame=None): + """Toggle self._stop_requested that's checked on every loop""" + self.log.info('Received SIGINT/SIGTERM, shutting down...') + self.status = self.Status.STOPPED + self.stop_workers() + + def all_workers_have_stopped(self) -> bool: + """Returns True if all workers have stopped.""" + self.reap_workers() + # `bool(self.worker_dict)` sometimes returns True even if the dict is empty + return self.number_of_active_workers == 0 + + def reap_workers(self): + """Removes dead workers from worker_dict""" + self.log.debug('Reaping dead workers') + worker_datas = list(self.worker_dict.values()) + + for data in worker_datas: + data.process.join(0.1) + if data.process.is_alive(): + self.log.debug('Worker %s with pid %d is alive', data.name, data.pid) + else: + self.handle_dead_worker(data) + continue + + # I'm still not sure why this is sometimes needed, temporarily commenting + # this out until I can figure it out. + # with contextlib.suppress(HorseMonitorTimeoutException): + # with UnixSignalDeathPenalty(1, HorseMonitorTimeoutException): + # try: + # # If wait4 returns, the process is dead + # os.wait4(data.process.pid, 0) # type: ignore + # self.handle_dead_worker(data) + # except ChildProcessError: + # # Process is dead + # self.handle_dead_worker(data) + # continue + + def handle_dead_worker(self, worker_data: WorkerData): + """ + Handle a dead worker + """ + self.log.info('Worker %s with pid %d is dead', worker_data.name, worker_data.pid) + with contextlib.suppress(KeyError): + self.worker_dict.pop(worker_data.name) + + def check_workers(self, respawn: bool = True) -> None: + """ + Check whether workers are still alive + """ + self.log.debug('Checking worker processes') + self.reap_workers() + # If we have less number of workers than num_workers, + # respawn the difference + if respawn and self.status != self.Status.STOPPED: + delta = self.num_workers - len(self.worker_dict) + if delta: + for i in range(delta): + self.start_worker(burst=self._burst, _sleep=self._sleep) + + def start_worker( + self, + count: Optional[int] = None, + burst: bool = True, + _sleep: float = 0, + logging_level: str = "INFO", + ): + """ + Starts a worker and adds the data to worker_datas. + * sleep: waits for X seconds before creating worker, for testing purposes + """ + name = uuid4().hex + process = Process( + target=run_worker, + args=(name, self._queue_names, self._connection_class, self._connection_kwargs), + kwargs={ + '_sleep': _sleep, + 'burst': burst, + 'logging_level': logging_level, + 'worker_class': self.worker_class, + 'job_class': self.job_class, + 'serializer': self.serializer, + }, + name=f'Worker {name} (WorkerPool {self.name})', + ) + process.start() + worker_data = WorkerData(name=name, pid=process.pid, process=process) # type: ignore + self.worker_dict[name] = worker_data + self.log.debug('Spawned worker: %s with PID %d', name, process.pid) + + def start_workers(self, burst: bool = True, _sleep: float = 0, logging_level: str = "INFO"): + """ + Run the workers + * sleep: waits for X seconds before creating worker, only for testing purposes + """ + self.log.debug(f'Spawning {self.num_workers} workers') + for i in range(self.num_workers): + self.start_worker(i + 1, burst=burst, _sleep=_sleep, logging_level=logging_level) + + def stop_worker(self, worker_data: WorkerData, sig=signal.SIGINT): + """ + Send stop signal to worker and catch "No such process" error if the worker is already dead. + """ + try: + os.kill(worker_data.pid, sig) + self.log.info('Sent shutdown command to worker with %s', worker_data.pid) + except OSError as e: + if e.errno == errno.ESRCH: + # "No such process" is fine with us + self.log.debug('Horse already dead') + else: + raise + + def stop_workers(self): + """Send SIGINT to all workers""" + self.log.info('Sending stop signal to %s workers', len(self.worker_dict)) + worker_datas = list(self.worker_dict.values()) + for worker_data in worker_datas: + self.stop_worker(worker_data) + + def start(self, burst: bool = False, logging_level: str = "INFO"): + self._burst = burst + respawn = not burst # Don't respawn workers if burst mode is on + setup_loghandlers(logging_level, DEFAULT_LOGGING_DATE_FORMAT, DEFAULT_LOGGING_FORMAT, name=__name__) + self.log.info(f'Starting worker pool {self.name} with pid %d...', os.getpid()) + self.status = self.Status.IDLE + self.start_workers(burst=self._burst, logging_level=logging_level) + self._install_signal_handlers() + while True: + if self.status == self.Status.STOPPED: + if self.all_workers_have_stopped(): + self.log.info('All workers stopped, exiting...') + break + else: + self.log.info('Waiting for workers to shutdown...') + time.sleep(1) + continue + else: + self.check_workers(respawn=respawn) + if burst and self.number_of_active_workers == 0: + self.log.info('All workers stopped, exiting...') + break + + time.sleep(1) + + +def run_worker( + worker_name: str, + queue_names: List[str], + connection_class, + connection_kwargs: dict, + worker_class: Type[BaseWorker] = Worker, + serializer: Type[DefaultSerializer] = DefaultSerializer, + job_class: Type[Job] = Job, + burst: bool = True, + logging_level: str = "INFO", + _sleep: int = 0, +): + connection = connection_class(**connection_kwargs) + queues = [Queue(name, connection=connection) for name in queue_names] + worker = worker_class(queues, name=worker_name, connection=connection, serializer=serializer, job_class=job_class) + worker.log.info("Starting worker started with PID %s", os.getpid()) + time.sleep(_sleep) + worker.work(burst=burst, logging_level=logging_level) diff --git a/tests/__init__.py b/tests/__init__.py index 1f1cc26..36b2bc6 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -32,6 +32,34 @@ def ssl_test(f): return unittest.skipUnless(os.environ.get('RUN_SSL_TESTS'), "SSL tests disabled")(f) +class TestCase(unittest.TestCase): + """Base class to inherit test cases from for RQ. + + It sets up the Redis connection (available via self.connection), turns off + logging to the terminal and flushes the Redis database before and after + running each test. + """ + + @classmethod + def setUpClass(cls): + # Set up connection to Redis + cls.connection = find_empty_redis_database() + # Shut up logging + logging.disable(logging.ERROR) + + def setUp(self): + # Flush beforewards (we like our hygiene) + self.connection.flushdb() + + def tearDown(self): + # Flush afterwards + self.connection.flushdb() + + @classmethod + def tearDownClass(cls): + logging.disable(logging.NOTSET) + + class RQTestCase(unittest.TestCase): """Base class to inherit test cases from for RQ. @@ -65,6 +93,7 @@ class RQTestCase(unittest.TestCase): # Implement assertIsNotNone for Python runtimes < 2.7 or < 3.1 if not hasattr(unittest.TestCase, 'assertIsNotNone'): + def assertIsNotNone(self, value, *args): # noqa self.assertNotEqual(value, None, *args) @@ -74,5 +103,6 @@ class RQTestCase(unittest.TestCase): # Pop the connection to Redis testconn = pop_connection() - assert testconn == cls.testconn, \ - 'Wow, something really nasty happened to the Redis connection stack. Check your setup.' + assert ( + testconn == cls.testconn + ), 'Wow, something really nasty happened to the Redis connection stack. Check your setup.' diff --git a/tests/fixtures.py b/tests/fixtures.py index e75fb68..4536c3c 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -13,7 +13,9 @@ from multiprocessing import Process from redis import Redis from rq import Connection, get_current_job, get_current_connection, Queue +from rq.command import send_kill_horse_command, send_shutdown_command from rq.decorators import job +from rq.job import Job from rq.worker import HerokuWorker, Worker @@ -153,6 +155,7 @@ class ClassWithAStaticMethod: with Connection(): + @job(queue='default') def decorated_job(x, y): return x + y @@ -210,7 +213,7 @@ class DummyQueue: pass -def kill_worker(pid, double_kill, interval=0.5): +def kill_worker(pid: int, double_kill: bool, interval: float = 1.5): # wait for the worker to be started over on the main process time.sleep(interval) os.kill(pid, signal.SIGTERM) @@ -286,3 +289,18 @@ def save_exception(job, connection, type, value, traceback): def erroneous_callback(job): """A callback that's not written properly""" pass + + +def _send_shutdown_command(worker_name, connection_kwargs, delay=0.25): + time.sleep(delay) + send_shutdown_command(Redis(**connection_kwargs), worker_name) + + +def _send_kill_horse_command(worker_name, connection_kwargs, delay=0.25): + """Waits delay before sending kill-horse command""" + time.sleep(delay) + send_kill_horse_command(Redis(**connection_kwargs), worker_name) + + +class CustomJob(Job): + """A custom job class just to test it""" diff --git a/tests/test_cli.py b/tests/test_cli.py index daa118b..79ac12d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -12,7 +12,7 @@ from redis import Redis from rq import Queue from rq.cli import main from rq.cli.helpers import read_config_file, CliConfig, parse_function_arg, parse_schedule -from rq.job import Job +from rq.job import Job, JobStatus from rq.registry import FailedJobRegistry, ScheduledJobRegistry from rq.serializers import JSONSerializer from rq.timeouts import UnixSignalDeathPenalty @@ -25,8 +25,25 @@ from tests import RQTestCase from tests.fixtures import div_by_zero, say_hello -class TestRQCli(RQTestCase): +class CLITestCase(RQTestCase): + def setUp(self): + super().setUp() + db_num = self.testconn.connection_pool.connection_kwargs['db'] + self.redis_url = 'redis://127.0.0.1:6379/%d' % db_num + self.connection = Redis.from_url(self.redis_url) + + def assert_normal_execution(self, result): + if result.exit_code == 0: + return True + else: + print("Non normal execution") + print("Exit Code: {}".format(result.exit_code)) + print("Output: {}".format(result.output)) + print("Exception: {}".format(result.exception)) + self.assertEqual(result.exit_code, 0) + +class TestRQCli(CLITestCase): @pytest.fixture(autouse=True) def set_tmpdir(self, tmpdir): self.tmpdir = tmpdir @@ -42,12 +59,9 @@ class TestRQCli(RQTestCase): self.assertEqual(result.exit_code, 0) """Test rq_cli script""" + def setUp(self): super().setUp() - db_num = self.testconn.connection_pool.connection_kwargs['db'] - self.redis_url = 'redis://127.0.0.1:6379/%d' % db_num - self.connection = Redis.from_url(self.redis_url) - job = Job.create(func=div_by_zero, args=(1, 2, 3)) job.origin = 'fake' job.save() @@ -76,18 +90,9 @@ class TestRQCli(RQTestCase): cli_config.connection.connection_pool.connection_kwargs['host'], 'testhost.example.com', ) - self.assertEqual( - cli_config.connection.connection_pool.connection_kwargs['port'], - 6379 - ) - self.assertEqual( - cli_config.connection.connection_pool.connection_kwargs['db'], - 0 - ) - self.assertEqual( - cli_config.connection.connection_pool.connection_kwargs['password'], - None - ) + self.assertEqual(cli_config.connection.connection_pool.connection_kwargs['port'], 6379) + self.assertEqual(cli_config.connection.connection_pool.connection_kwargs['db'], 0) + self.assertEqual(cli_config.connection.connection_pool.connection_kwargs['password'], None) def test_config_file_default_options_override(self): """""" @@ -97,18 +102,9 @@ class TestRQCli(RQTestCase): cli_config.connection.connection_pool.connection_kwargs['host'], 'testhost.example.com', ) - self.assertEqual( - cli_config.connection.connection_pool.connection_kwargs['port'], - 6378 - ) - self.assertEqual( - cli_config.connection.connection_pool.connection_kwargs['db'], - 2 - ) - self.assertEqual( - cli_config.connection.connection_pool.connection_kwargs['password'], - '123' - ) + self.assertEqual(cli_config.connection.connection_pool.connection_kwargs['port'], 6378) + self.assertEqual(cli_config.connection.connection_pool.connection_kwargs['db'], 2) + self.assertEqual(cli_config.connection.connection_pool.connection_kwargs['password'], '123') def test_config_env_vars(self): os.environ['REDIS_HOST'] = "testhost.example.com" @@ -123,18 +119,12 @@ class TestRQCli(RQTestCase): def test_death_penalty_class(self): cli_config = CliConfig() - self.assertEqual( - UnixSignalDeathPenalty, - cli_config.death_penalty_class - ) + self.assertEqual(UnixSignalDeathPenalty, cli_config.death_penalty_class) cli_config = CliConfig(death_penalty_class='rq.job.Job') - self.assertEqual( - Job, - cli_config.death_penalty_class - ) + self.assertEqual(Job, cli_config.death_penalty_class) - with self.assertRaises(BadParameter): + with self.assertRaises(ValueError): CliConfig(death_penalty_class='rq.abcd') def test_empty_nothing(self): @@ -163,10 +153,7 @@ class TestRQCli(RQTestCase): self.assertIn(job2, registry) self.assertIn(job3, registry) - result = runner.invoke( - main, - ['requeue', '-u', self.redis_url, '--queue', 'requeue', job.id] - ) + result = runner.invoke(main, ['requeue', '-u', self.redis_url, '--queue', 'requeue', job.id]) self.assert_normal_execution(result) # Only the first specified job is requeued @@ -174,10 +161,7 @@ class TestRQCli(RQTestCase): self.assertIn(job2, registry) self.assertIn(job3, registry) - result = runner.invoke( - main, - ['requeue', '-u', self.redis_url, '--queue', 'requeue', '--all'] - ) + result = runner.invoke(main, ['requeue', '-u', self.redis_url, '--queue', 'requeue', '--all']) self.assert_normal_execution(result) # With --all flag, all failed jobs are requeued self.assertNotIn(job2, registry) @@ -203,8 +187,7 @@ class TestRQCli(RQTestCase): self.assertIn(job3, registry) result = runner.invoke( - main, - ['requeue', '-u', self.redis_url, '--queue', 'requeue', '-S', 'rq.serializers.JSONSerializer', job.id] + main, ['requeue', '-u', self.redis_url, '--queue', 'requeue', '-S', 'rq.serializers.JSONSerializer', job.id] ) self.assert_normal_execution(result) @@ -215,7 +198,7 @@ class TestRQCli(RQTestCase): result = runner.invoke( main, - ['requeue', '-u', self.redis_url, '--queue', 'requeue', '-S', 'rq.serializers.JSONSerializer', '--all'] + ['requeue', '-u', self.redis_url, '--queue', 'requeue', '-S', 'rq.serializers.JSONSerializer', '--all'], ) self.assert_normal_execution(result) # With --all flag, all failed jobs are requeued @@ -257,8 +240,7 @@ class TestRQCli(RQTestCase): self.assert_normal_execution(result) self.assertIn('0 workers, 0 queue', result.output) - result = runner.invoke(main, ['info', '--by-queue', - '-u', self.redis_url, '--only-workers']) + result = runner.invoke(main, ['info', '--by-queue', '-u', self.redis_url, '--only-workers']) self.assert_normal_execution(result) self.assertIn('0 workers, 0 queue', result.output) @@ -288,14 +270,12 @@ class TestRQCli(RQTestCase): worker_2.register_birth() worker_2.set_state(WorkerStatus.BUSY) - result = runner.invoke(main, ['info', 'foo', 'bar', - '-u', self.redis_url, '--only-workers']) + result = runner.invoke(main, ['info', 'foo', 'bar', '-u', self.redis_url, '--only-workers']) self.assert_normal_execution(result) self.assertIn('2 workers, 2 queues', result.output) - result = runner.invoke(main, ['info', 'foo', 'bar', '--by-queue', - '-u', self.redis_url, '--only-workers']) + result = runner.invoke(main, ['info', 'foo', 'bar', '--by-queue', '-u', self.redis_url, '--only-workers']) self.assert_normal_execution(result) # Ensure both queues' workers are shown @@ -374,15 +354,13 @@ class TestRQCli(RQTestCase): # If disable-default-exception-handler is given, job is not moved to FailedJobRegistry job = q.enqueue(div_by_zero) - runner.invoke(main, ['worker', '-u', self.redis_url, '-b', - '--disable-default-exception-handler']) + runner.invoke(main, ['worker', '-u', self.redis_url, '-b', '--disable-default-exception-handler']) registry = FailedJobRegistry(queue=q) self.assertFalse(job in registry) # Both default and custom exception handler is run job = q.enqueue(div_by_zero) - runner.invoke(main, ['worker', '-u', self.redis_url, '-b', - '--exception-handler', 'tests.fixtures.add_meta']) + runner.invoke(main, ['worker', '-u', self.redis_url, '-b', '--exception-handler', 'tests.fixtures.add_meta']) registry = FailedJobRegistry(queue=q) self.assertTrue(job in registry) job.refresh() @@ -390,9 +368,18 @@ class TestRQCli(RQTestCase): # Only custom exception handler is run job = q.enqueue(div_by_zero) - runner.invoke(main, ['worker', '-u', self.redis_url, '-b', - '--exception-handler', 'tests.fixtures.add_meta', - '--disable-default-exception-handler']) + runner.invoke( + main, + [ + 'worker', + '-u', + self.redis_url, + '-b', + '--exception-handler', + 'tests.fixtures.add_meta', + '--disable-default-exception-handler', + ], + ) registry = FailedJobRegistry(queue=q) self.assertFalse(job in registry) job.refresh() @@ -400,8 +387,8 @@ class TestRQCli(RQTestCase): def test_suspend_and_resume(self): """rq suspend -u - rq worker -u -b - rq resume -u + rq worker -u -b + rq resume -u """ runner = CliRunner() result = runner.invoke(main, ['suspend', '-u', self.redis_url]) @@ -409,24 +396,19 @@ class TestRQCli(RQTestCase): result = runner.invoke(main, ['worker', '-u', self.redis_url, '-b']) self.assertEqual(result.exit_code, 1) - self.assertEqual( - result.output.strip(), - 'RQ is currently suspended, to resume job execution run "rq resume"' - ) + self.assertEqual(result.output.strip(), 'RQ is currently suspended, to resume job execution run "rq resume"') result = runner.invoke(main, ['resume', '-u', self.redis_url]) self.assert_normal_execution(result) def test_suspend_with_ttl(self): - """rq suspend -u --duration=2 - """ + """rq suspend -u --duration=2""" runner = CliRunner() result = runner.invoke(main, ['suspend', '-u', self.redis_url, '--duration', 1]) self.assert_normal_execution(result) def test_suspend_with_invalid_ttl(self): - """rq suspend -u --duration=0 - """ + """rq suspend -u --duration=0""" runner = CliRunner() result = runner.invoke(main, ['suspend', '-u', self.redis_url, '--duration', 0]) @@ -439,8 +421,7 @@ class TestRQCli(RQTestCase): q = Queue('default', connection=connection, serializer=JSONSerializer) runner = CliRunner() job = q.enqueue(say_hello) - runner.invoke(main, ['worker', '-u', self.redis_url, - '--serializer rq.serializer.JSONSerializer']) + runner.invoke(main, ['worker', '-u', self.redis_url, '--serializer rq.serializer.JSONSerializer']) self.assertIn(job.id, q.job_ids) def test_cli_enqueue(self): @@ -458,7 +439,7 @@ class TestRQCli(RQTestCase): self.assertTrue(result.output.startswith(prefix)) self.assertTrue(result.output.endswith(suffix)) - job_id = result.output[len(prefix):-len(suffix)] + job_id = result.output[len(prefix) : -len(suffix)] queue_key = 'rq:queue:default' self.assertEqual(self.connection.llen(queue_key), 1) self.assertEqual(self.connection.lrange(queue_key, 0, -1)[0].decode('ascii'), job_id) @@ -473,7 +454,9 @@ class TestRQCli(RQTestCase): self.assertTrue(queue.is_empty()) runner = CliRunner() - result = runner.invoke(main, ['enqueue', '-u', self.redis_url, '-S', 'rq.serializers.JSONSerializer', 'tests.fixtures.say_hello']) + result = runner.invoke( + main, ['enqueue', '-u', self.redis_url, '-S', 'rq.serializers.JSONSerializer', 'tests.fixtures.say_hello'] + ) self.assert_normal_execution(result) prefix = 'Enqueued tests.fixtures.say_hello() with job-id \'' @@ -482,7 +465,7 @@ class TestRQCli(RQTestCase): self.assertTrue(result.output.startswith(prefix)) self.assertTrue(result.output.endswith(suffix)) - job_id = result.output[len(prefix):-len(suffix)] + job_id = result.output[len(prefix) : -len(suffix)] queue_key = 'rq:queue:default' self.assertEqual(self.connection.llen(queue_key), 1) self.assertEqual(self.connection.lrange(queue_key, 0, -1)[0].decode('ascii'), job_id) @@ -497,9 +480,22 @@ class TestRQCli(RQTestCase): self.assertTrue(queue.is_empty()) runner = CliRunner() - result = runner.invoke(main, ['enqueue', '-u', self.redis_url, 'tests.fixtures.echo', 'hello', - ':[1, {"key": "value"}]', ':@tests/test.json', '%1, 2', 'json:=[3.0, true]', - 'nojson=abc', 'file=@tests/test.json']) + result = runner.invoke( + main, + [ + 'enqueue', + '-u', + self.redis_url, + 'tests.fixtures.echo', + 'hello', + ':[1, {"key": "value"}]', + ':@tests/test.json', + '%1, 2', + 'json:=[3.0, true]', + 'nojson=abc', + 'file=@tests/test.json', + ], + ) self.assert_normal_execution(result) job_id = self.connection.lrange('rq:queue:default', 0, -1)[0].decode('ascii') @@ -523,8 +519,9 @@ class TestRQCli(RQTestCase): self.assertTrue(len(registry) == 0) runner = CliRunner() - result = runner.invoke(main, ['enqueue', '-u', self.redis_url, 'tests.fixtures.say_hello', - '--schedule-in', '10s']) + result = runner.invoke( + main, ['enqueue', '-u', self.redis_url, 'tests.fixtures.say_hello', '--schedule-in', '10s'] + ) self.assert_normal_execution(result) scheduler.acquire_locks() @@ -559,8 +556,9 @@ class TestRQCli(RQTestCase): self.assertTrue(len(registry) == 0) runner = CliRunner() - result = runner.invoke(main, ['enqueue', '-u', self.redis_url, 'tests.fixtures.say_hello', - '--schedule-at', '2021-01-01T00:00:00']) + result = runner.invoke( + main, ['enqueue', '-u', self.redis_url, 'tests.fixtures.say_hello', '--schedule-at', '2021-01-01T00:00:00'] + ) self.assert_normal_execution(result) scheduler.acquire_locks() @@ -578,8 +576,9 @@ class TestRQCli(RQTestCase): self.assertTrue(len(queue) == 0) self.assertTrue(len(registry) == 0) - result = runner.invoke(main, ['enqueue', '-u', self.redis_url, 'tests.fixtures.say_hello', - '--schedule-at', '2100-01-01T00:00:00']) + result = runner.invoke( + main, ['enqueue', '-u', self.redis_url, 'tests.fixtures.say_hello', '--schedule-at', '2100-01-01T00:00:00'] + ) self.assert_normal_execution(result) self.assertTrue(len(queue) == 0) @@ -599,12 +598,28 @@ class TestRQCli(RQTestCase): self.assertTrue(queue.is_empty()) runner = CliRunner() - result = runner.invoke(main, ['enqueue', '-u', self.redis_url, 'tests.fixtures.say_hello', '--retry-max', '3', - '--retry-interval', '10', '--retry-interval', '20', '--retry-interval', '40']) + result = runner.invoke( + main, + [ + 'enqueue', + '-u', + self.redis_url, + 'tests.fixtures.say_hello', + '--retry-max', + '3', + '--retry-interval', + '10', + '--retry-interval', + '20', + '--retry-interval', + '40', + ], + ) self.assert_normal_execution(result) - job = Job.fetch(self.connection.lrange('rq:queue:default', 0, -1)[0].decode('ascii'), - connection=self.connection) + job = Job.fetch( + self.connection.lrange('rq:queue:default', 0, -1)[0].decode('ascii'), connection=self.connection + ) self.assertEqual(job.retries_left, 3) self.assertEqual(job.retry_intervals, [10, 20, 40]) @@ -627,8 +642,9 @@ class TestRQCli(RQTestCase): self.assertNotEqual(result.exit_code, 0) self.assertIn('Unable to parse 1. non keyword argument as JSON.', result.output) - result = runner.invoke(main, ['enqueue', '-u', self.redis_url, 'tests.fixtures.echo', - '%invalid_eval_statement']) + result = runner.invoke( + main, ['enqueue', '-u', self.redis_url, 'tests.fixtures.echo', '%invalid_eval_statement'] + ) self.assertNotEqual(result.exit_code, 0) self.assertIn('Unable to eval 1. non keyword argument as Python object.', result.output) @@ -636,8 +652,19 @@ class TestRQCli(RQTestCase): self.assertNotEqual(result.exit_code, 0) self.assertIn('You can\'t specify multiple values for the same keyword.', result.output) - result = runner.invoke(main, ['enqueue', '-u', self.redis_url, 'tests.fixtures.echo', '--schedule-in', '1s', - '--schedule-at', '2000-01-01T00:00:00']) + result = runner.invoke( + main, + [ + 'enqueue', + '-u', + self.redis_url, + 'tests.fixtures.echo', + '--schedule-in', + '1s', + '--schedule-at', + '2000-01-01T00:00:00', + ], + ) self.assertNotEqual(result.exit_code, 0) self.assertIn('You can\'t specify both --schedule-in and --schedule-at', result.output) @@ -678,19 +705,25 @@ class TestRQCli(RQTestCase): self.assertEqual((job.args, job.kwargs), (['abc'], {})) id = str(uuid4()) - result = runner.invoke(main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', 'abc=def']) + result = runner.invoke( + main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', 'abc=def'] + ) self.assert_normal_execution(result) job = Job.fetch(id) self.assertEqual((job.args, job.kwargs), ([], {'abc': 'def'})) id = str(uuid4()) - result = runner.invoke(main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', ':{"json": "abc"}']) + result = runner.invoke( + main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', ':{"json": "abc"}'] + ) self.assert_normal_execution(result) job = Job.fetch(id) self.assertEqual((job.args, job.kwargs), ([{'json': 'abc'}], {})) id = str(uuid4()) - result = runner.invoke(main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', 'key:={"json": "abc"}']) + result = runner.invoke( + main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', 'key:={"json": "abc"}'] + ) self.assert_normal_execution(result) job = Job.fetch(id) self.assertEqual((job.args, job.kwargs), ([], {'key': {'json': 'abc'}})) @@ -714,37 +747,99 @@ class TestRQCli(RQTestCase): self.assertEqual((job.args, job.kwargs), ([True], {})) id = str(uuid4()) - result = runner.invoke(main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', 'key%=(1, 2)']) + result = runner.invoke( + main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', 'key%=(1, 2)'] + ) self.assert_normal_execution(result) job = Job.fetch(id) self.assertEqual((job.args, job.kwargs), ([], {'key': (1, 2)})) id = str(uuid4()) - result = runner.invoke(main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', 'key%={"foo": True}']) + result = runner.invoke( + main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', 'key%={"foo": True}'] + ) self.assert_normal_execution(result) job = Job.fetch(id) self.assertEqual((job.args, job.kwargs), ([], {'key': {"foo": True}})) id = str(uuid4()) - result = runner.invoke(main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', '@tests/test.json']) + result = runner.invoke( + main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', '@tests/test.json'] + ) self.assert_normal_execution(result) job = Job.fetch(id) self.assertEqual((job.args, job.kwargs), ([open('tests/test.json', 'r').read()], {})) id = str(uuid4()) - result = runner.invoke(main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', 'key=@tests/test.json']) + result = runner.invoke( + main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', 'key=@tests/test.json'] + ) self.assert_normal_execution(result) job = Job.fetch(id) self.assertEqual((job.args, job.kwargs), ([], {'key': open('tests/test.json', 'r').read()})) id = str(uuid4()) - result = runner.invoke(main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', ':@tests/test.json']) + result = runner.invoke( + main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', ':@tests/test.json'] + ) self.assert_normal_execution(result) job = Job.fetch(id) self.assertEqual((job.args, job.kwargs), ([json.loads(open('tests/test.json', 'r').read())], {})) id = str(uuid4()) - result = runner.invoke(main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', 'key:=@tests/test.json']) + result = runner.invoke( + main, ['enqueue', '-u', self.redis_url, '--job-id', id, 'tests.fixtures.echo', 'key:=@tests/test.json'] + ) self.assert_normal_execution(result) job = Job.fetch(id) self.assertEqual((job.args, job.kwargs), ([], {'key': json.loads(open('tests/test.json', 'r').read())})) + + +class WorkerPoolCLITestCase(CLITestCase): + def test_worker_pool_burst_and_num_workers(self): + """rq worker-pool -u -b -n 3""" + runner = CliRunner() + result = runner.invoke(main, ['worker-pool', '-u', self.redis_url, '-b', '-n', '3']) + self.assert_normal_execution(result) + + def test_serializer_and_queue_argument(self): + """rq worker-pool foo bar -u -b""" + queue = Queue('foo', connection=self.connection, serializer=JSONSerializer) + job = queue.enqueue(say_hello, 'Hello') + queue = Queue('bar', connection=self.connection, serializer=JSONSerializer) + job_2 = queue.enqueue(say_hello, 'Hello') + runner = CliRunner() + result = runner.invoke( + main, + ['worker-pool', 'foo', 'bar', '-u', self.redis_url, '-b', '--serializer', 'rq.serializers.JSONSerializer'], + ) + self.assertEqual(job.get_status(refresh=True), JobStatus.FINISHED) + self.assertEqual(job_2.get_status(refresh=True), JobStatus.FINISHED) + + def test_worker_class_argument(self): + """rq worker-pool -u -b --worker-class rq.Worker""" + runner = CliRunner() + result = runner.invoke(main, ['worker-pool', '-u', self.redis_url, '-b', '--worker-class', 'rq.Worker']) + self.assert_normal_execution(result) + result = runner.invoke( + main, ['worker-pool', '-u', self.redis_url, '-b', '--worker-class', 'rq.worker.SimpleWorker'] + ) + self.assert_normal_execution(result) + + # This one fails because the worker class doesn't exist + result = runner.invoke( + main, ['worker-pool', '-u', self.redis_url, '-b', '--worker-class', 'rq.worker.NonExistantWorker'] + ) + self.assertNotEqual(result.exit_code, 0) + + def test_job_class_argument(self): + """rq worker-pool -u -b --job-class rq.job.Job""" + runner = CliRunner() + result = runner.invoke(main, ['worker-pool', '-u', self.redis_url, '-b', '--job-class', 'rq.job.Job']) + self.assert_normal_execution(result) + + # This one fails because Job class doesn't exist + result = runner.invoke( + main, ['worker-pool', '-u', self.redis_url, '-b', '--job-class', 'rq.job.NonExistantJob'] + ) + self.assertNotEqual(result.exit_code, 0) diff --git a/tests/test_commands.py b/tests/test_commands.py index 786fc26..f98a0ec 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -5,7 +5,7 @@ from multiprocessing import Process from redis import Redis from tests import RQTestCase -from tests.fixtures import long_running_job +from tests.fixtures import long_running_job, _send_kill_horse_command, _send_shutdown_command from rq import Queue, Worker from rq.command import send_command, send_kill_horse_command, send_shutdown_command, send_stop_job_command @@ -14,41 +14,25 @@ from rq.serializers import JSONSerializer from rq.worker import WorkerStatus -def _send_shutdown_command(worker_name, connection_kwargs): - time.sleep(0.25) - send_shutdown_command(Redis(**connection_kwargs), worker_name) - - -def _send_kill_horse_command(worker_name, connection_kwargs): - """Waits 0.25 seconds before sending kill-horse command""" - time.sleep(0.25) - send_kill_horse_command(Redis(**connection_kwargs), worker_name) - - def start_work(queue_name, worker_name, connection_kwargs): - worker = Worker(queue_name, name=worker_name, - connection=Redis(**connection_kwargs)) + worker = Worker(queue_name, name=worker_name, connection=Redis(**connection_kwargs)) worker.work() def start_work_burst(queue_name, worker_name, connection_kwargs): - worker = Worker(queue_name, name=worker_name, - connection=Redis(**connection_kwargs), - serializer=JSONSerializer) + worker = Worker(queue_name, name=worker_name, connection=Redis(**connection_kwargs), serializer=JSONSerializer) worker.work(burst=True) - class TestCommands(RQTestCase): - def test_shutdown_command(self): """Ensure that shutdown command works properly.""" connection = self.testconn worker = Worker('foo', connection=connection) - p = Process(target=_send_shutdown_command, - args=(worker.name, - connection.connection_pool.connection_kwargs.copy())) + p = Process( + target=_send_shutdown_command, args=(worker.name, connection.connection_pool.connection_kwargs.copy()) + ) p.start() worker.work() p.join(1) @@ -60,18 +44,16 @@ class TestCommands(RQTestCase): job = queue.enqueue(long_running_job, 4) worker = Worker('foo', connection=connection) - p = Process(target=_send_kill_horse_command, - args=(worker.name, - connection.connection_pool.connection_kwargs.copy())) + p = Process( + target=_send_kill_horse_command, args=(worker.name, connection.connection_pool.connection_kwargs.copy()) + ) p.start() worker.work(burst=True) p.join(1) job.refresh() self.assertTrue(job.id in queue.failed_job_registry) - p = Process(target=start_work, - args=('foo', worker.name, - connection.connection_pool.connection_kwargs.copy())) + p = Process(target=start_work, args=('foo', worker.name, connection.connection_pool.connection_kwargs.copy())) p.start() p.join(2) @@ -97,9 +79,9 @@ class TestCommands(RQTestCase): with self.assertRaises(NoSuchJobError): send_stop_job_command(connection, job_id='1', serializer=JSONSerializer) - p = Process(target=start_work_burst, - args=('foo', worker.name, - connection.connection_pool.connection_kwargs.copy())) + p = Process( + target=start_work_burst, args=('foo', worker.name, connection.connection_pool.connection_kwargs.copy()) + ) p.start() p.join(1) diff --git a/tests/test_connection.py b/tests/test_connection.py index 4b4ba8e..0b64d2b 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -29,7 +29,7 @@ class TestConnectionInheritance(RQTestCase): def test_connection_pass_thru(self): """Connection passed through from queues to jobs.""" - q1 = Queue() + q1 = Queue(connection=self.testconn) with Connection(new_connection()): q2 = Queue() job1 = q1.enqueue(do_nothing) diff --git a/tests/test_job.py b/tests/test_job.py index 23bbd11..318c41b 100644 --- a/tests/test_job.py +++ b/tests/test_job.py @@ -13,9 +13,14 @@ from rq.utils import as_text from rq.exceptions import DeserializationError, InvalidJobOperation, NoSuchJobError from rq.job import Job, JobStatus, Dependency, cancel_job, get_current_job, Callback from rq.queue import Queue -from rq.registry import (CanceledJobRegistry, DeferredJobRegistry, FailedJobRegistry, - FinishedJobRegistry, StartedJobRegistry, - ScheduledJobRegistry) +from rq.registry import ( + CanceledJobRegistry, + DeferredJobRegistry, + FailedJobRegistry, + FinishedJobRegistry, + StartedJobRegistry, + ScheduledJobRegistry, +) from rq.utils import utcformat, utcnow from rq.worker import Worker from tests import RQTestCase, fixtures @@ -164,10 +169,10 @@ class TestJob(RQTestCase): def test_fetch(self): """Fetching jobs.""" # Prepare test - self.testconn.hset('rq:job:some_id', 'data', - "(S'tests.fixtures.some_calculation'\nN(I3\nI4\nt(dp1\nS'z'\nI2\nstp2\n.") - self.testconn.hset('rq:job:some_id', 'created_at', - '2012-02-07T22:13:24.123456Z') + self.testconn.hset( + 'rq:job:some_id', 'data', "(S'tests.fixtures.some_calculation'\nN(I3\nI4\nt(dp1\nS'z'\nI2\nstp2\n." + ) + self.testconn.hset('rq:job:some_id', 'created_at', '2012-02-07T22:13:24.123456Z') # Fetch returns a job job = Job.fetch('some_id') @@ -211,9 +216,18 @@ class TestJob(RQTestCase): # ... and no other keys are stored self.assertEqual( - {b'created_at', b'data', b'description', b'ended_at', b'last_heartbeat', b'started_at', - b'worker_name', b'success_callback_name', b'failure_callback_name'}, - set(self.testconn.hkeys(job.key)) + { + b'created_at', + b'data', + b'description', + b'ended_at', + b'last_heartbeat', + b'started_at', + b'worker_name', + b'success_callback_name', + b'failure_callback_name', + }, + set(self.testconn.hkeys(job.key)), ) self.assertEqual(job.last_heartbeat, None) @@ -245,9 +259,11 @@ class TestJob(RQTestCase): def test_persistence_of_callbacks(self): """Storing jobs with success and/or failure callbacks.""" - job = Job.create(func=fixtures.some_calculation, - on_success=Callback(fixtures.say_hello, timeout=10), - on_failure=fixtures.say_pid) # deprecated callable + job = Job.create( + func=fixtures.some_calculation, + on_success=Callback(fixtures.say_hello, timeout=10), + on_failure=fixtures.say_pid, + ) # deprecated callable job.save() stored_job = Job.fetch(job.id) @@ -257,8 +273,7 @@ class TestJob(RQTestCase): self.assertEqual(CALLBACK_TIMEOUT, stored_job.failure_callback_timeout) # None(s) - job = Job.create(func=fixtures.some_calculation, - on_failure=None) + job = Job.create(func=fixtures.some_calculation, on_failure=None) job.save() stored_job = Job.fetch(job.id) self.assertIsNone(stored_job.success_callback) @@ -270,8 +285,7 @@ class TestJob(RQTestCase): def test_store_then_fetch(self): """Store, then fetch.""" - job = Job.create(func=fixtures.some_calculation, timeout='1h', args=(3, 4), - kwargs=dict(z=2)) + job = Job.create(func=fixtures.some_calculation, timeout='1h', args=(3, 4), kwargs=dict(z=2)) job.save() job2 = Job.fetch(job.id) @@ -291,8 +305,7 @@ class TestJob(RQTestCase): def test_fetching_unreadable_data(self): """Fetching succeeds on unreadable data, but lazy props fail.""" # Set up - job = Job.create(func=fixtures.some_calculation, args=(3, 4), - kwargs=dict(z=2)) + job = Job.create(func=fixtures.some_calculation, args=(3, 4), kwargs=dict(z=2)) job.save() # Just replace the data hkey with some random noise @@ -317,7 +330,7 @@ class TestJob(RQTestCase): self.testconn.hset(job.key, 'data', zlib.compress(unimportable_data)) job.refresh() - with self.assertRaises(AttributeError): + with self.assertRaises(ValueError): job.func # accessing the func property should fail def test_compressed_exc_info_handling(self): @@ -330,10 +343,7 @@ class TestJob(RQTestCase): # exc_info is stored in compressed format exc_info = self.testconn.hget(job.key, 'exc_info') - self.assertEqual( - as_text(zlib.decompress(exc_info)), - exception_string - ) + self.assertEqual(as_text(zlib.decompress(exc_info)), exception_string) job.refresh() self.assertEqual(job.exc_info, exception_string) @@ -352,10 +362,7 @@ class TestJob(RQTestCase): # Job data is stored in compressed format job_data = job.data - self.assertEqual( - zlib.compress(job_data), - self.testconn.hget(job.key, 'data') - ) + self.assertEqual(zlib.compress(job_data), self.testconn.hget(job.key, 'data')) self.testconn.hset(job.key, 'data', job_data) job.refresh() @@ -415,10 +422,7 @@ class TestJob(RQTestCase): job._result = queue.Queue() job.save() - self.assertEqual( - self.testconn.hget(job.key, 'result').decode('utf-8'), - 'Unserializable return value' - ) + self.assertEqual(self.testconn.hget(job.key, 'result').decode('utf-8'), 'Unserializable return value') job = Job.fetch(job.id) self.assertEqual(job.result, 'Unserializable return value') @@ -449,8 +453,7 @@ class TestJob(RQTestCase): def test_description_is_persisted(self): """Ensure that job's custom description is set properly""" - job = Job.create(func=fixtures.say_hello, args=('Lionel',), - description='Say hello!') + job = Job.create(func=fixtures.say_hello, args=('Lionel',), description='Say hello!') job.save() Job.fetch(job.id, connection=self.testconn) self.assertEqual(job.description, 'Say hello!') @@ -606,7 +609,6 @@ class TestJob(RQTestCase): self.assertRaises(NoSuchJobError, Job.fetch, job.id, self.testconn) def test_cleanup_expires_dependency_keys(self): - dependency_job = Job.create(func=fixtures.say_hello) dependency_job.save() @@ -653,8 +655,13 @@ class TestJob(RQTestCase): def test_job_delete_removes_itself_from_registries(self): """job.delete() should remove itself from job registries""" - job = Job.create(func=fixtures.say_hello, status=JobStatus.FAILED, - connection=self.testconn, origin='default', serializer=JSONSerializer) + job = Job.create( + func=fixtures.say_hello, + status=JobStatus.FAILED, + connection=self.testconn, + origin='default', + serializer=JSONSerializer, + ) job.save() registry = FailedJobRegistry(connection=self.testconn, serializer=JSONSerializer) registry.add(job, 500) @@ -662,8 +669,13 @@ class TestJob(RQTestCase): job.delete() self.assertFalse(job in registry) - job = Job.create(func=fixtures.say_hello, status=JobStatus.STOPPED, - connection=self.testconn, origin='default', serializer=JSONSerializer) + job = Job.create( + func=fixtures.say_hello, + status=JobStatus.STOPPED, + connection=self.testconn, + origin='default', + serializer=JSONSerializer, + ) job.save() registry = FailedJobRegistry(connection=self.testconn, serializer=JSONSerializer) registry.add(job, 500) @@ -671,8 +683,13 @@ class TestJob(RQTestCase): job.delete() self.assertFalse(job in registry) - job = Job.create(func=fixtures.say_hello, status=JobStatus.FINISHED, - connection=self.testconn, origin='default', serializer=JSONSerializer) + job = Job.create( + func=fixtures.say_hello, + status=JobStatus.FINISHED, + connection=self.testconn, + origin='default', + serializer=JSONSerializer, + ) job.save() registry = FinishedJobRegistry(connection=self.testconn, serializer=JSONSerializer) @@ -681,8 +698,13 @@ class TestJob(RQTestCase): job.delete() self.assertFalse(job in registry) - job = Job.create(func=fixtures.say_hello, status=JobStatus.STARTED, - connection=self.testconn, origin='default', serializer=JSONSerializer) + job = Job.create( + func=fixtures.say_hello, + status=JobStatus.STARTED, + connection=self.testconn, + origin='default', + serializer=JSONSerializer, + ) job.save() registry = StartedJobRegistry(connection=self.testconn, serializer=JSONSerializer) @@ -691,8 +713,13 @@ class TestJob(RQTestCase): job.delete() self.assertFalse(job in registry) - job = Job.create(func=fixtures.say_hello, status=JobStatus.DEFERRED, - connection=self.testconn, origin='default', serializer=JSONSerializer) + job = Job.create( + func=fixtures.say_hello, + status=JobStatus.DEFERRED, + connection=self.testconn, + origin='default', + serializer=JSONSerializer, + ) job.save() registry = DeferredJobRegistry(connection=self.testconn, serializer=JSONSerializer) @@ -701,8 +728,13 @@ class TestJob(RQTestCase): job.delete() self.assertFalse(job in registry) - job = Job.create(func=fixtures.say_hello, status=JobStatus.SCHEDULED, - connection=self.testconn, origin='default', serializer=JSONSerializer) + job = Job.create( + func=fixtures.say_hello, + status=JobStatus.SCHEDULED, + connection=self.testconn, + origin='default', + serializer=JSONSerializer, + ) job.save() registry = ScheduledJobRegistry(connection=self.testconn, serializer=JSONSerializer) @@ -764,7 +796,6 @@ class TestJob(RQTestCase): self.assertNotIn(job.id, queue.get_job_ids()) def test_dependent_job_creates_dependencies_key(self): - queue = Queue(connection=self.testconn) dependency_job = queue.enqueue(fixtures.say_hello) dependent_job = Job.create(func=fixtures.say_hello, depends_on=dependency_job) @@ -818,8 +849,7 @@ class TestJob(RQTestCase): """test call string with unicode keyword arguments""" queue = Queue(connection=self.testconn) - job = queue.enqueue(fixtures.echo, - arg_with_unicode=fixtures.UnicodeStringObject()) + job = queue.enqueue(fixtures.echo, arg_with_unicode=fixtures.UnicodeStringObject()) self.assertIsNotNone(job.get_call_string()) job.perform() @@ -875,10 +905,7 @@ class TestJob(RQTestCase): # Second cancel should fail self.assertRaisesRegex( - InvalidJobOperation, - r'Cannot cancel already canceled job: fake_job_id', - cancel_job, - job.id + InvalidJobOperation, r'Cannot cancel already canceled job: fake_job_id', cancel_job, job.id ) def test_create_and_cancel_job_enqueue_dependents(self): @@ -1030,12 +1057,7 @@ class TestJob(RQTestCase): dependency_job.delete() - self.assertNotIn( - dependent_job.id, - [job.id for job in dependent_job.fetch_dependencies( - pipeline=self.testconn - )] - ) + self.assertNotIn(dependent_job.id, [job.id for job in dependent_job.fetch_dependencies(pipeline=self.testconn)]) def test_fetch_dependencies_watches(self): queue = Queue(connection=self.testconn) @@ -1046,10 +1068,7 @@ class TestJob(RQTestCase): dependent_job.save() with self.testconn.pipeline() as pipeline: - dependent_job.fetch_dependencies( - watch=True, - pipeline=pipeline - ) + dependent_job.fetch_dependencies(watch=True, pipeline=pipeline) pipeline.multi() @@ -1061,10 +1080,7 @@ class TestJob(RQTestCase): def test_dependencies_finished_returns_false_if_dependencies_queued(self): queue = Queue(connection=self.testconn) - dependency_job_ids = [ - queue.enqueue(fixtures.say_hello).id - for _ in range(5) - ] + dependency_job_ids = [queue.enqueue(fixtures.say_hello).id for _ in range(5)] dependent_job = Job.create(func=fixtures.say_hello) dependent_job._dependency_ids = dependency_job_ids @@ -1083,10 +1099,7 @@ class TestJob(RQTestCase): self.assertTrue(dependencies_finished) def test_dependencies_finished_returns_true_if_all_dependencies_finished(self): - dependency_jobs = [ - Job.create(fixtures.say_hello) - for _ in range(5) - ] + dependency_jobs = [Job.create(fixtures.say_hello) for _ in range(5)] dependent_job = Job.create(func=fixtures.say_hello) dependent_job._dependency_ids = [job.id for job in dependency_jobs] diff --git a/tests/test_queue.py b/tests/test_queue.py index 2143486..d352736 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -5,17 +5,18 @@ from unittest.mock import patch from rq import Retry, Queue from rq.job import Job, JobStatus -from rq.registry import (CanceledJobRegistry, DeferredJobRegistry, FailedJobRegistry, - FinishedJobRegistry, ScheduledJobRegistry, - StartedJobRegistry) +from rq.registry import ( + CanceledJobRegistry, + DeferredJobRegistry, + FailedJobRegistry, + FinishedJobRegistry, + ScheduledJobRegistry, + StartedJobRegistry, +) from rq.worker import Worker from tests import RQTestCase -from tests.fixtures import echo, say_hello - - -class CustomJob(Job): - pass +from tests.fixtures import CustomJob, echo, say_hello class MultipleDependencyJob(Job): @@ -23,6 +24,7 @@ class MultipleDependencyJob(Job): Allows for the patching of `_dependency_ids` to simulate multi-dependency support without modifying the public interface of `Job` """ + create_job = Job.create @classmethod @@ -193,9 +195,7 @@ class TestQueue(RQTestCase): # Inspect data inside Redis q_key = 'rq:queue:default' self.assertEqual(self.testconn.llen(q_key), 1) - self.assertEqual( - self.testconn.lrange(q_key, 0, -1)[0].decode('ascii'), - job_id) + self.assertEqual(self.testconn.lrange(q_key, 0, -1)[0].decode('ascii'), job_id) def test_enqueue_sets_metadata(self): """Enqueueing job onto queues modifies meta data.""" @@ -246,19 +246,13 @@ class TestQueue(RQTestCase): self.assertEqual(queue, fooq) self.assertEqual(job.func, say_hello) self.assertEqual(job.origin, fooq.name) - self.assertEqual( - job.args[0], 'for Foo', - 'Foo should be dequeued first.' - ) + self.assertEqual(job.args[0], 'for Foo', 'Foo should be dequeued first.') job, queue = Queue.dequeue_any([fooq, barq], None) self.assertEqual(queue, barq) self.assertEqual(job.func, say_hello) self.assertEqual(job.origin, barq.name) - self.assertEqual( - job.args[0], 'for Bar', - 'Bar should be dequeued second.' - ) + self.assertEqual(job.args[0], 'for Bar', 'Bar should be dequeued second.') def test_dequeue_any_ignores_nonexisting_jobs(self): """Dequeuing (from any queue) silently ignores non-existing jobs.""" @@ -269,10 +263,7 @@ class TestQueue(RQTestCase): # Dequeue simply ignores the missing job and returns None self.assertEqual(q.count, 1) - self.assertEqual( - Queue.dequeue_any([Queue(), Queue('low')], None), # noqa - None - ) + self.assertEqual(Queue.dequeue_any([Queue(), Queue('low')], None), None) # noqa self.assertEqual(q.count, 0) def test_enqueue_with_ttl(self): @@ -352,10 +343,7 @@ class TestQueue(RQTestCase): job = q.enqueue(echo, 1, job_timeout=1, result_ttl=1, bar='baz') self.assertEqual(job.timeout, 1) self.assertEqual(job.result_ttl, 1) - self.assertEqual( - job.perform(), - ((1,), {'bar': 'baz'}) - ) + self.assertEqual(job.perform(), ((1,), {'bar': 'baz'})) # Explicit kwargs mode kwargs = { @@ -365,20 +353,14 @@ class TestQueue(RQTestCase): job = q.enqueue(echo, job_timeout=2, result_ttl=2, args=[1], kwargs=kwargs) self.assertEqual(job.timeout, 2) self.assertEqual(job.result_ttl, 2) - self.assertEqual( - job.perform(), - ((1,), {'timeout': 1, 'result_ttl': 1}) - ) + self.assertEqual(job.perform(), ((1,), {'timeout': 1, 'result_ttl': 1})) # Explicit args and kwargs should also work with enqueue_at time = datetime.now(timezone.utc) + timedelta(seconds=10) job = q.enqueue_at(time, echo, job_timeout=2, result_ttl=2, args=[1], kwargs=kwargs) self.assertEqual(job.timeout, 2) self.assertEqual(job.result_ttl, 2) - self.assertEqual( - job.perform(), - ((1,), {'timeout': 1, 'result_ttl': 1}) - ) + self.assertEqual(job.perform(), ((1,), {'timeout': 1, 'result_ttl': 1})) # Positional arguments is not allowed if explicit args and kwargs are used self.assertRaises(Exception, q.enqueue, echo, 1, kwargs=kwargs) @@ -447,10 +429,7 @@ class TestQueue(RQTestCase): parent_job.set_status(JobStatus.FINISHED) - self.assertEqual( - set(registry.get_job_ids()), - set([job_1.id, job_2.id]) - ) + self.assertEqual(set(registry.get_job_ids()), set([job_1.id, job_2.id])) # After dependents is enqueued, job_1 and job_2 should be in queue self.assertEqual(q.job_ids, []) q.enqueue_dependents(parent_job) @@ -472,18 +451,12 @@ class TestQueue(RQTestCase): # Each queue has its own DeferredJobRegistry registry_1 = DeferredJobRegistry(q_1.name, connection=self.testconn) - self.assertEqual( - set(registry_1.get_job_ids()), - set([job_1.id]) - ) + self.assertEqual(set(registry_1.get_job_ids()), set([job_1.id])) registry_2 = DeferredJobRegistry(q_2.name, connection=self.testconn) parent_job.set_status(JobStatus.FINISHED) - self.assertEqual( - set(registry_2.get_job_ids()), - set([job_2.id]) - ) + self.assertEqual(set(registry_2.get_job_ids()), set([job_2.id])) # After dependents is enqueued, job_1 on queue_1 and # job_2 should be in queue_2 @@ -569,21 +542,9 @@ class TestQueue(RQTestCase): (but at_front still applies)""" # Job with unfinished dependency is not immediately enqueued q = Queue() - job_1_data = Queue.prepare_data( - say_hello, - job_id='fake_job_id_1', - at_front=False - ) - job_2_data = Queue.prepare_data( - say_hello, - job_id='fake_job_id_2', - at_front=False - ) - job_3_data = Queue.prepare_data( - say_hello, - job_id='fake_job_id_3', - at_front=True - ) + job_1_data = Queue.prepare_data(say_hello, job_id='fake_job_id_1', at_front=False) + job_2_data = Queue.prepare_data(say_hello, job_id='fake_job_id_2', at_front=False) + job_3_data = Queue.prepare_data(say_hello, job_id='fake_job_id_3', at_front=True) jobs = q.enqueue_many( [job_1_data, job_2_data, job_3_data], ) @@ -599,25 +560,10 @@ class TestQueue(RQTestCase): # Job with unfinished dependency is not immediately enqueued q = Queue() with q.connection.pipeline() as pipe: - job_1_data = Queue.prepare_data( - say_hello, - job_id='fake_job_id_1', - at_front=False - ) - job_2_data = Queue.prepare_data( - say_hello, - job_id='fake_job_id_2', - at_front=False - ) - job_3_data = Queue.prepare_data( - say_hello, - job_id='fake_job_id_3', - at_front=True - ) - jobs = q.enqueue_many( - [job_1_data, job_2_data, job_3_data], - pipeline=pipe - ) + job_1_data = Queue.prepare_data(say_hello, job_id='fake_job_id_1', at_front=False) + job_2_data = Queue.prepare_data(say_hello, job_id='fake_job_id_2', at_front=False) + job_3_data = Queue.prepare_data(say_hello, job_id='fake_job_id_3', at_front=True) + jobs = q.enqueue_many([job_1_data, job_2_data, job_3_data], pipeline=pipe) self.assertEqual(q.job_ids, []) for job in jobs: self.assertEqual(job.get_status(refresh=False), JobStatus.QUEUED) @@ -660,7 +606,6 @@ class TestQueue(RQTestCase): self.assertEqual(job.timeout, 123) def test_enqueue_job_with_multiple_queued_dependencies(self): - parent_jobs = [Job.create(func=say_hello) for _ in range(2)] for job in parent_jobs: @@ -669,14 +614,12 @@ class TestQueue(RQTestCase): q = Queue() with patch('rq.queue.Job.create', new=MultipleDependencyJob.create): - job = q.enqueue(say_hello, depends_on=parent_jobs[0], - _dependency_ids=[job.id for job in parent_jobs]) + job = q.enqueue(say_hello, depends_on=parent_jobs[0], _dependency_ids=[job.id for job in parent_jobs]) self.assertEqual(job.get_status(), JobStatus.DEFERRED) self.assertEqual(q.job_ids, []) self.assertEqual(job.fetch_dependencies(), parent_jobs) def test_enqueue_job_with_multiple_finished_dependencies(self): - parent_jobs = [Job.create(func=say_hello) for _ in range(2)] for job in parent_jobs: @@ -685,16 +628,13 @@ class TestQueue(RQTestCase): q = Queue() with patch('rq.queue.Job.create', new=MultipleDependencyJob.create): - job = q.enqueue(say_hello, depends_on=parent_jobs[0], - _dependency_ids=[job.id for job in parent_jobs]) + job = q.enqueue(say_hello, depends_on=parent_jobs[0], _dependency_ids=[job.id for job in parent_jobs]) self.assertEqual(job.get_status(), JobStatus.QUEUED) self.assertEqual(q.job_ids, [job.id]) self.assertEqual(job.fetch_dependencies(), parent_jobs) def test_enqueues_dependent_if_other_dependencies_finished(self): - - parent_jobs = [Job.create(func=say_hello) for _ in - range(3)] + parent_jobs = [Job.create(func=say_hello) for _ in range(3)] parent_jobs[0]._status = JobStatus.STARTED parent_jobs[0].save() @@ -706,11 +646,11 @@ class TestQueue(RQTestCase): parent_jobs[2].save() q = Queue() - with patch('rq.queue.Job.create', - new=MultipleDependencyJob.create): + with patch('rq.queue.Job.create', new=MultipleDependencyJob.create): # dependent job deferred, b/c parent_job 0 is still 'started' - dependent_job = q.enqueue(say_hello, depends_on=parent_jobs[0], - _dependency_ids=[job.id for job in parent_jobs]) + dependent_job = q.enqueue( + say_hello, depends_on=parent_jobs[0], _dependency_ids=[job.id for job in parent_jobs] + ) self.assertEqual(dependent_job.get_status(), JobStatus.DEFERRED) # now set parent job 0 to 'finished' @@ -721,7 +661,6 @@ class TestQueue(RQTestCase): self.assertEqual(q.job_ids, [dependent_job.id]) def test_does_not_enqueue_dependent_if_other_dependencies_not_finished(self): - started_dependency = Job.create(func=say_hello, status=JobStatus.STARTED) started_dependency.save() @@ -730,8 +669,11 @@ class TestQueue(RQTestCase): q = Queue() with patch('rq.queue.Job.create', new=MultipleDependencyJob.create): - dependent_job = q.enqueue(say_hello, depends_on=[started_dependency], - _dependency_ids=[started_dependency.id, queued_dependency.id]) + dependent_job = q.enqueue( + say_hello, + depends_on=[started_dependency], + _dependency_ids=[started_dependency.id, queued_dependency.id], + ) self.assertEqual(dependent_job.get_status(), JobStatus.DEFERRED) q.enqueue_dependents(started_dependency) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index c417554..96cde1c 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -30,7 +30,6 @@ class CustomRedisConnection(redis.Connection): class TestScheduledJobRegistry(RQTestCase): - def test_get_jobs_to_enqueue(self): """Getting job ids to enqueue from ScheduledJobRegistry.""" queue = Queue(connection=self.testconn) @@ -42,8 +41,7 @@ class TestScheduledJobRegistry(RQTestCase): self.testconn.zadd(registry.key, {'baz': timestamp + 30}) self.assertEqual(registry.get_jobs_to_enqueue(), ['foo']) - self.assertEqual(registry.get_jobs_to_enqueue(timestamp + 20), - ['foo', 'bar']) + self.assertEqual(registry.get_jobs_to_enqueue(timestamp + 20), ['foo', 'bar']) def test_get_jobs_to_schedule_with_chunk_size(self): """Max amount of jobs returns by get_jobs_to_schedule() equal to chunk_size""" @@ -55,10 +53,8 @@ class TestScheduledJobRegistry(RQTestCase): for index in range(0, chunk_size * 2): self.testconn.zadd(registry.key, {'foo_{}'.format(index): 1}) - self.assertEqual(len(registry.get_jobs_to_schedule(timestamp, chunk_size)), - chunk_size) - self.assertEqual(len(registry.get_jobs_to_schedule(timestamp, chunk_size * 2)), - chunk_size * 2) + self.assertEqual(len(registry.get_jobs_to_schedule(timestamp, chunk_size)), chunk_size) + self.assertEqual(len(registry.get_jobs_to_schedule(timestamp, chunk_size * 2)), chunk_size * 2) def test_get_scheduled_time(self): """get_scheduled_time() returns job's scheduled datetime""" @@ -106,8 +102,9 @@ class TestScheduledJobRegistry(RQTestCase): mock_atz = mock.patch('time.altzone', 14400) with mock_tz, mock_day, mock_atz: registry.schedule(job, datetime(2019, 1, 1)) - self.assertEqual(self.testconn.zscore(registry.key, job.id), - 1546300800 + 18000) # 2019-01-01 UTC in Unix timestamp + self.assertEqual( + self.testconn.zscore(registry.key, job.id), 1546300800 + 18000 + ) # 2019-01-01 UTC in Unix timestamp # second, time.daylight != 0 (in DST) # mock the sitatuoin for American/New_York not in DST (UTC - 4) @@ -119,20 +116,19 @@ class TestScheduledJobRegistry(RQTestCase): mock_atz = mock.patch('time.altzone', 14400) with mock_tz, mock_day, mock_atz: registry.schedule(job, datetime(2019, 1, 1)) - self.assertEqual(self.testconn.zscore(registry.key, job.id), - 1546300800 + 14400) # 2019-01-01 UTC in Unix timestamp + self.assertEqual( + self.testconn.zscore(registry.key, job.id), 1546300800 + 14400 + ) # 2019-01-01 UTC in Unix timestamp # Score is always stored in UTC even if datetime is in a different tz tz = timezone(timedelta(hours=7)) job = Job.create('myfunc', connection=self.testconn) job.save() registry.schedule(job, datetime(2019, 1, 1, 7, tzinfo=tz)) - self.assertEqual(self.testconn.zscore(registry.key, job.id), - 1546300800) # 2019-01-01 UTC in Unix timestamp + self.assertEqual(self.testconn.zscore(registry.key, job.id), 1546300800) # 2019-01-01 UTC in Unix timestamp class TestScheduler(RQTestCase): - def test_init(self): """Scheduler can be instantiated with queues or queue names""" foo_queue = Queue('foo', connection=self.testconn) @@ -209,7 +205,12 @@ class TestScheduler(RQTestCase): def test_queue_scheduler_pid(self): queue = Queue(connection=self.testconn) - scheduler = RQScheduler([queue, ], connection=self.testconn) + scheduler = RQScheduler( + [ + queue, + ], + connection=self.testconn, + ) scheduler.acquire_locks() assert queue.scheduler_pid == os.getpid() @@ -276,12 +277,11 @@ class TestScheduler(RQTestCase): scheduler.prepare_registries([foo_queue.name, bar_queue.name]) self.assertEqual( scheduler._scheduled_job_registries, - [ScheduledJobRegistry(queue=foo_queue), ScheduledJobRegistry(queue=bar_queue)] + [ScheduledJobRegistry(queue=foo_queue), ScheduledJobRegistry(queue=bar_queue)], ) class TestWorker(RQTestCase): - def test_work_burst(self): """worker.work() with scheduler enabled works properly""" queue = Queue(connection=self.testconn) @@ -363,10 +363,7 @@ class TestWorker(RQTestCase): p = Process(target=kill_worker, args=(os.getpid(), False, 5)) p.start() - queue.enqueue_at( - datetime(2019, 1, 1, tzinfo=timezone.utc), - say_hello, meta={'foo': 'bar'} - ) + queue.enqueue_at(datetime(2019, 1, 1, tzinfo=timezone.utc), say_hello, meta={'foo': 'bar'}) worker.work(burst=False, with_scheduler=True) p.join(1) self.assertIsNotNone(worker.scheduler) @@ -375,7 +372,6 @@ class TestWorker(RQTestCase): class TestQueue(RQTestCase): - def test_enqueue_at(self): """queue.enqueue_at() puts job in the scheduled""" queue = Queue(connection=self.testconn) @@ -398,7 +394,7 @@ class TestQueue(RQTestCase): def test_enqueue_at_at_front(self): """queue.enqueue_at() accepts at_front argument. When true, job will be put at position 0 - of the queue when the time comes for the job to be scheduled""" + of the queue when the time comes for the job to be scheduled""" queue = Queue(connection=self.testconn) registry = ScheduledJobRegistry(queue=queue) scheduler = RQScheduler([queue], connection=self.testconn) @@ -432,12 +428,10 @@ class TestQueue(RQTestCase): now = datetime.now(timezone.utc) scheduled_time = registry.get_scheduled_time(job) # Ensure that job is scheduled roughly 30 seconds from now - self.assertTrue( - now + timedelta(seconds=28) < scheduled_time < now + timedelta(seconds=32) - ) + self.assertTrue(now + timedelta(seconds=28) < scheduled_time < now + timedelta(seconds=32)) def test_enqueue_in_with_retry(self): - """ Ensure that the retry parameter is passed + """Ensure that the retry parameter is passed to the enqueue_at function from enqueue_in. """ queue = Queue(connection=self.testconn) diff --git a/tests/test_sentry.py b/tests/test_sentry.py index e63a5f6..f52f7db 100644 --- a/tests/test_sentry.py +++ b/tests/test_sentry.py @@ -19,7 +19,6 @@ class FakeSentry: class TestSentry(RQTestCase): - def setUp(self): super().setUp() db_num = self.testconn.connection_pool.connection_kwargs['db'] @@ -35,13 +34,16 @@ class TestSentry(RQTestCase): """rq worker -u -b --exception-handler """ # connection = Redis.from_url(self.redis_url) runner = CliRunner() - runner.invoke(main, ['worker', '-u', self.redis_url, '-b', - '--sentry-dsn', 'https://1@sentry.io/1']) + runner.invoke(main, ['worker', '-u', self.redis_url, '-b', '--sentry-dsn', 'https://1@sentry.io/1']) self.assertEqual(mocked.call_count, 1) + runner.invoke(main, ['worker-pool', '-u', self.redis_url, '-b', '--sentry-dsn', 'https://1@sentry.io/1']) + self.assertEqual(mocked.call_count, 2) + def test_failure_capture(self): """Test failure is captured by Sentry SDK""" from sentry_sdk import Hub + hub = Hub.current self.assertIsNone(hub.last_event_id()) queue = Queue(connection=self.testconn) diff --git a/tests/test_utils.py b/tests/test_utils.py index 64b3f64..b71e67e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,9 +5,22 @@ from unittest.mock import Mock from redis import Redis from tests import RQTestCase, fixtures -from rq.utils import backend_class, ensure_list, first, get_version, is_nonstring_iterable, parse_timeout, utcparse, \ - split_list, ceildiv, get_call_string, truncate_long_string from rq.exceptions import TimeoutFormatError +from rq.utils import ( + backend_class, + ceildiv, + ensure_list, + first, + get_call_string, + get_version, + import_attribute, + is_nonstring_iterable, + parse_timeout, + utcparse, + split_list, + truncate_long_string, +) +from rq.worker import SimpleWorker class TestUtils(RQTestCase): @@ -66,8 +79,9 @@ class TestUtils(RQTestCase): self.assertEqual(fixtures.DummyQueue, backend_class(fixtures, 'DummyQueue')) self.assertNotEqual(fixtures.say_pid, backend_class(fixtures, 'DummyQueue')) self.assertEqual(fixtures.DummyQueue, backend_class(fixtures, 'DummyQueue', override=fixtures.DummyQueue)) - self.assertEqual(fixtures.DummyQueue, - backend_class(fixtures, 'DummyQueue', override='tests.fixtures.DummyQueue')) + self.assertEqual( + fixtures.DummyQueue, backend_class(fixtures, 'DummyQueue', override='tests.fixtures.DummyQueue') + ) def test_get_redis_version(self): """Ensure get_version works properly""" @@ -78,12 +92,14 @@ class TestUtils(RQTestCase): class DummyRedis(Redis): def info(*args): return {'redis_version': '4.0.8'} + self.assertEqual(get_version(DummyRedis()), (4, 0, 8)) # Parses 3 digit version numbers correctly class DummyRedis(Redis): def info(*args): return {'redis_version': '3.0.7.9'} + self.assertEqual(get_version(DummyRedis()), (3, 0, 7)) def test_get_redis_version_gets_cached(self): @@ -95,6 +111,13 @@ class TestUtils(RQTestCase): self.assertEqual(get_version(redis), (4, 0, 8)) redis.info.assert_called_once() + def test_import_attribute(self): + """Ensure get_version works properly""" + self.assertEqual(import_attribute('rq.utils.get_version'), get_version) + self.assertEqual(import_attribute('rq.worker.SimpleWorker'), SimpleWorker) + self.assertRaises(ValueError, import_attribute, 'non.existent.module') + self.assertRaises(ValueError, import_attribute, 'rq.worker.WrongWorker') + def test_ceildiv_even(self): """When a number is evenly divisible by another ceildiv returns the quotient""" dividend = 12 diff --git a/tests/test_worker.py b/tests/test_worker.py index dfa0f1d..d9b6121 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -22,10 +22,23 @@ from unittest.mock import Mock from rq.defaults import DEFAULT_MAINTENANCE_TASK_INTERVAL from tests import RQTestCase, slow from tests.fixtures import ( - access_self, create_file, create_file_after_timeout, create_file_after_timeout_and_setsid, div_by_zero, do_nothing, - kill_worker, long_running_job, modify_self, modify_self_and_error, - run_dummy_heroku_worker, save_key_ttl, say_hello, say_pid, raise_exc_mock, - launch_process_within_worker_and_store_pid + access_self, + create_file, + create_file_after_timeout, + create_file_after_timeout_and_setsid, + CustomJob, + div_by_zero, + do_nothing, + kill_worker, + long_running_job, + modify_self, + modify_self_and_error, + run_dummy_heroku_worker, + save_key_ttl, + say_hello, + say_pid, + raise_exc_mock, + launch_process_within_worker_and_store_pid, ) from rq import Queue, SimpleWorker, Worker, get_current_connection @@ -40,16 +53,11 @@ from rq.worker import HerokuWorker, WorkerStatus, RoundRobinWorker, RandomWorker from rq.serializers import JSONSerializer -class CustomJob(Job): - pass - - class CustomQueue(Queue): pass class TestWorker(RQTestCase): - def test_create_worker(self): """Worker creation using various inputs.""" @@ -96,31 +104,19 @@ class TestWorker(RQTestCase): """Worker processes work, then quits.""" fooq, barq = Queue('foo'), Queue('bar') w = Worker([fooq, barq]) - self.assertEqual( - w.work(burst=True), False, - 'Did not expect any work on the queue.' - ) + self.assertEqual(w.work(burst=True), False, 'Did not expect any work on the queue.') fooq.enqueue(say_hello, name='Frank') - self.assertEqual( - w.work(burst=True), True, - 'Expected at least some work done.' - ) + self.assertEqual(w.work(burst=True), True, 'Expected at least some work done.') def test_work_and_quit_custom_serializer(self): """Worker processes work, then quits.""" fooq, barq = Queue('foo', serializer=JSONSerializer), Queue('bar', serializer=JSONSerializer) w = Worker([fooq, barq], serializer=JSONSerializer) - self.assertEqual( - w.work(burst=True), False, - 'Did not expect any work on the queue.' - ) + self.assertEqual(w.work(burst=True), False, 'Did not expect any work on the queue.') fooq.enqueue(say_hello, name='Frank') - self.assertEqual( - w.work(burst=True), True, - 'Expected at least some work done.' - ) + self.assertEqual(w.work(burst=True), True, 'Expected at least some work done.') def test_worker_all(self): """Worker.all() works properly""" @@ -132,10 +128,7 @@ class TestWorker(RQTestCase): w2 = Worker([foo_queue], name='w2') w2.register_birth() - self.assertEqual( - set(Worker.all(connection=foo_queue.connection)), - set([w1, w2]) - ) + self.assertEqual(set(Worker.all(connection=foo_queue.connection)), set([w1, w2])) self.assertEqual(set(Worker.all(queue=foo_queue)), set([w1, w2])) self.assertEqual(set(Worker.all(queue=bar_queue)), set([w1])) @@ -176,10 +169,7 @@ class TestWorker(RQTestCase): q = Queue('foo') w = Worker([q]) job = q.enqueue('tests.fixtures.say_hello', name='Frank') - self.assertEqual( - w.work(burst=True), True, - 'Expected at least some work done.' - ) + self.assertEqual(w.work(burst=True), True, 'Expected at least some work done.') expected_result = 'Hi there, Frank!' self.assertEqual(job.result, expected_result) # Only run if Redis server supports streams @@ -197,25 +187,13 @@ class TestWorker(RQTestCase): self.assertIsNotNone(job.enqueued_at) self.assertIsNone(job.started_at) self.assertIsNone(job.ended_at) - self.assertEqual( - w.work(burst=True), True, - 'Expected at least some work done.' - ) + self.assertEqual(w.work(burst=True), True, 'Expected at least some work done.') self.assertEqual(job.result, 'Hi there, Stranger!') after = utcnow() job.refresh() - self.assertTrue( - before <= job.enqueued_at <= after, - 'Not %s <= %s <= %s' % (before, job.enqueued_at, after) - ) - self.assertTrue( - before <= job.started_at <= after, - 'Not %s <= %s <= %s' % (before, job.started_at, after) - ) - self.assertTrue( - before <= job.ended_at <= after, - 'Not %s <= %s <= %s' % (before, job.ended_at, after) - ) + self.assertTrue(before <= job.enqueued_at <= after, 'Not %s <= %s <= %s' % (before, job.enqueued_at, after)) + self.assertTrue(before <= job.started_at <= after, 'Not %s <= %s <= %s' % (before, job.started_at, after)) + self.assertTrue(before <= job.ended_at <= after, 'Not %s <= %s <= %s' % (before, job.ended_at, after)) def test_work_is_unreadable(self): """Unreadable jobs are put on the failed job registry.""" @@ -241,7 +219,7 @@ class TestWorker(RQTestCase): # All set, we're going to process it w = Worker([q]) - w.work(burst=True) # should silently pass + w.work(burst=True) # should silently pass self.assertEqual(q.count, 0) failed_job_registry = FailedJobRegistry(queue=q) @@ -286,8 +264,7 @@ class TestWorker(RQTestCase): w.register_birth() self.assertEqual(str(w.pid), as_text(self.testconn.hget(w.key, 'pid'))) - self.assertEqual(w.hostname, - as_text(self.testconn.hget(w.key, 'hostname'))) + self.assertEqual(w.hostname, as_text(self.testconn.hget(w.key, 'hostname'))) last_heartbeat = self.testconn.hget(w.key, 'last_heartbeat') self.assertIsNotNone(self.testconn.hget(w.key, 'birth')) self.assertTrue(last_heartbeat is not None) @@ -346,10 +323,7 @@ class TestWorker(RQTestCase): w = Worker([q], job_monitoring_interval=5) for timeout, expected_heartbeats in [(2, 0), (7, 1), (12, 2)]: - job = q.enqueue(long_running_job, - args=(timeout,), - job_timeout=30, - result_ttl=-1) + job = q.enqueue(long_running_job, args=(timeout,), job_timeout=30, result_ttl=-1) with mock.patch.object(w, 'heartbeat', wraps=w.heartbeat) as mocked: w.execute_job(job, q) self.assertEqual(mocked.call_count, expected_heartbeats) @@ -573,8 +547,7 @@ class TestWorker(RQTestCase): self.assertTrue(job.meta['second_handler']) job = q.enqueue(div_by_zero) - w = Worker([q], exception_handlers=[first_handler, black_hole, - second_handler]) + w = Worker([q], exception_handlers=[first_handler, black_hole, second_handler]) w.work(burst=True) # second_handler is not run since it's interrupted by black_hole @@ -621,17 +594,17 @@ class TestWorker(RQTestCase): # idle for 3 seconds now = utcnow() self.assertIsNone(w.dequeue_job_and_maintain_ttl(1, max_idle_time=3)) - self.assertLess((utcnow()-now).total_seconds(), 5) # 5 for some buffer + self.assertLess((utcnow() - now).total_seconds(), 5) # 5 for some buffer # idle for 2 seconds because idle_time is less than timeout now = utcnow() self.assertIsNone(w.dequeue_job_and_maintain_ttl(3, max_idle_time=2)) - self.assertLess((utcnow()-now).total_seconds(), 4) # 4 for some buffer + self.assertLess((utcnow() - now).total_seconds(), 4) # 4 for some buffer # idle for 3 seconds because idle_time is less than two rounds of timeout now = utcnow() self.assertIsNone(w.dequeue_job_and_maintain_ttl(2, max_idle_time=3)) - self.assertLess((utcnow()-now).total_seconds(), 5) # 5 for some buffer + self.assertLess((utcnow() - now).total_seconds(), 5) # 5 for some buffer @slow # noqa def test_timeouts(self): @@ -642,9 +615,7 @@ class TestWorker(RQTestCase): w = Worker([q]) # Put it on the queue with a timeout value - res = q.enqueue(create_file_after_timeout, - args=(sentinel_file, 4), - job_timeout=1) + res = q.enqueue(create_file_after_timeout, args=(sentinel_file, 4), job_timeout=1) try: os.unlink(sentinel_file) @@ -738,10 +709,7 @@ class TestWorker(RQTestCase): self.assertEqual(self.testconn.hget(worker.key, 'current_job'), None) worker.set_current_job_id(job.id) - self.assertEqual( - worker.get_current_job_id(), - as_text(self.testconn.hget(worker.key, 'current_job')) - ) + self.assertEqual(worker.get_current_job_id(), as_text(self.testconn.hget(worker.key, 'current_job'))) self.assertEqual(worker.get_current_job(), job) def test_custom_job_class(self): @@ -781,14 +749,11 @@ class TestWorker(RQTestCase): then returns.""" fooq, barq = Queue('foo'), Queue('bar') w = SimpleWorker([fooq, barq]) - self.assertEqual(w.work(burst=True), False, - 'Did not expect any work on the queue.') + self.assertEqual(w.work(burst=True), False, 'Did not expect any work on the queue.') job = fooq.enqueue(say_pid) - self.assertEqual(w.work(burst=True), True, - 'Expected at least some work done.') - self.assertEqual(job.result, os.getpid(), - 'PID mismatch, fork() is not supposed to happen here') + self.assertEqual(w.work(burst=True), True, 'Expected at least some work done.') + self.assertEqual(job.result, os.getpid(), 'PID mismatch, fork() is not supposed to happen here') def test_simpleworker_heartbeat_ttl(self): """SimpleWorker's key must last longer than job.timeout when working""" @@ -823,10 +788,8 @@ class TestWorker(RQTestCase): """Worker processes work with unicode description, then quits.""" q = Queue('foo') w = Worker([q]) - job = q.enqueue('tests.fixtures.say_hello', name='Adam', - description='你好 世界!') - self.assertEqual(w.work(burst=True), True, - 'Expected at least some work done.') + job = q.enqueue('tests.fixtures.say_hello', name='Adam', description='你好 世界!') + self.assertEqual(w.work(burst=True), True, 'Expected at least some work done.') self.assertEqual(job.result, 'Hi there, Adam!') self.assertEqual(job.description, '你好 世界!') @@ -836,13 +799,11 @@ class TestWorker(RQTestCase): q = Queue("foo") w = Worker([q]) - job = q.enqueue('tests.fixtures.say_hello', name='阿达姆', - description='你好 世界!') + job = q.enqueue('tests.fixtures.say_hello', name='阿达姆', description='你好 世界!') w.work(burst=True) self.assertEqual(job.get_status(), JobStatus.FINISHED) - job = q.enqueue('tests.fixtures.say_hello_unicode', name='阿达姆', - description='你好 世界!') + job = q.enqueue('tests.fixtures.say_hello_unicode', name='阿达姆', description='你好 世界!') w.work(burst=True) self.assertEqual(job.get_status(), JobStatus.FINISHED) @@ -1023,8 +984,7 @@ class TestWorker(RQTestCase): q = Queue() # Also make sure that previously existing metadata # persists properly - job = q.enqueue(modify_self, meta={'foo': 'bar', 'baz': 42}, - args=[{'baz': 10, 'newinfo': 'waka'}]) + job = q.enqueue(modify_self, meta={'foo': 'bar', 'baz': 42}, args=[{'baz': 10, 'newinfo': 'waka'}]) w = Worker([q]) w.work(burst=True) @@ -1041,8 +1001,7 @@ class TestWorker(RQTestCase): q = Queue() # Also make sure that previously existing metadata # persists properly - job = q.enqueue(modify_self_and_error, meta={'foo': 'bar', 'baz': 42}, - args=[{'baz': 10, 'newinfo': 'waka'}]) + job = q.enqueue(modify_self_and_error, meta={'foo': 'bar', 'baz': 42}, args=[{'baz': 10, 'newinfo': 'waka'}]) w = Worker([q]) w.work(burst=True) @@ -1166,6 +1125,20 @@ class TestWorker(RQTestCase): expected_ser.sort() self.assertEqual(sorted_ids, expected_ser) + def test_request_force_stop_ignores_consecutive_signals(self): + """Ignore signals sent within 1 second of the last signal""" + queue = Queue(connection=self.testconn) + worker = Worker([queue]) + worker._horse_pid = 1 + worker._shutdown_requested_date = utcnow() + with mock.patch.object(worker, 'kill_horse') as mocked: + worker.request_force_stop(1, frame=None) + self.assertEqual(mocked.call_count, 0) + # If signal is sent a few seconds after, kill_horse() is called + worker._shutdown_requested_date = utcnow() - timedelta(seconds=2) + with mock.patch.object(worker, 'kill_horse') as mocked: + self.assertRaises(SystemExit, worker.request_force_stop, 1, frame=None) + def test_dequeue_round_robin(self): qs = [Queue('q%d' % i) for i in range(5)] @@ -1183,9 +1156,23 @@ class TestWorker(RQTestCase): start_times.append(('q%d_%d' % (i, j), job.started_at)) sorted_by_time = sorted(start_times, key=lambda tup: tup[1]) sorted_ids = [tup[0] for tup in sorted_by_time] - expected = ['q0_0', 'q1_0', 'q2_0', 'q3_0', 'q4_0', - 'q0_1', 'q1_1', 'q2_1', 'q3_1', 'q4_1', - 'q0_2', 'q1_2', 'q2_2', 'q3_2', 'q4_2'] + expected = [ + 'q0_0', + 'q1_0', + 'q2_0', + 'q3_0', + 'q4_0', + 'q0_1', + 'q1_1', + 'q2_1', + 'q3_1', + 'q4_1', + 'q0_2', + 'q1_2', + 'q2_2', + 'q3_2', + 'q4_2', + ] self.assertEqual(expected, sorted_ids) @@ -1250,8 +1237,12 @@ class WorkerShutdownTestCase(TimeoutTestCase, RQTestCase): """Busy worker shuts down immediately on double SIGTERM signal""" fooq = Queue('foo') w = Worker(fooq) + sentinel_file = '/tmp/.rq_sentinel_cold' - fooq.enqueue(create_file_after_timeout, sentinel_file, 2) + self.assertFalse( + os.path.exists(sentinel_file), '{sentinel_file} file should not exist yet, delete that file and try again.' + ) + fooq.enqueue(create_file_after_timeout, sentinel_file, 5) self.assertFalse(w._stop_requested) p = Process(target=kill_worker, args=(os.getpid(), True)) p.start() @@ -1454,7 +1445,6 @@ class HerokuWorkerShutdownTestCase(TimeoutTestCase, RQTestCase): class TestExceptionHandlerMessageEncoding(RQTestCase): - def setUp(self): super().setUp() self.worker = Worker("foo") @@ -1476,8 +1466,7 @@ class TestRoundRobinWorker(RQTestCase): for i in range(5): for j in range(3): - qs[i].enqueue(say_pid, - job_id='q%d_%d' % (i, j)) + qs[i].enqueue(say_pid, job_id='q%d_%d' % (i, j)) w = RoundRobinWorker(qs) w.work(burst=True) @@ -1488,9 +1477,23 @@ class TestRoundRobinWorker(RQTestCase): start_times.append(('q%d_%d' % (i, j), job.started_at)) sorted_by_time = sorted(start_times, key=lambda tup: tup[1]) sorted_ids = [tup[0] for tup in sorted_by_time] - expected = ['q0_0', 'q1_0', 'q2_0', 'q3_0', 'q4_0', - 'q0_1', 'q1_1', 'q2_1', 'q3_1', 'q4_1', - 'q0_2', 'q1_2', 'q2_2', 'q3_2', 'q4_2'] + expected = [ + 'q0_0', + 'q1_0', + 'q2_0', + 'q3_0', + 'q4_0', + 'q0_1', + 'q1_1', + 'q2_1', + 'q3_1', + 'q4_1', + 'q0_2', + 'q1_2', + 'q2_2', + 'q3_2', + 'q4_2', + ] self.assertEqual(expected, sorted_ids) @@ -1500,8 +1503,7 @@ class TestRandomWorker(RQTestCase): for i in range(5): for j in range(3): - qs[i].enqueue(say_pid, - job_id='q%d_%d' % (i, j)) + qs[i].enqueue(say_pid, job_id='q%d_%d' % (i, j)) w = RandomWorker(qs) w.work(burst=True) diff --git a/tests/test_worker_pool.py b/tests/test_worker_pool.py new file mode 100644 index 0000000..c836309 --- /dev/null +++ b/tests/test_worker_pool.py @@ -0,0 +1,138 @@ +import os +import signal + +from multiprocessing import Process +from time import sleep +from rq.job import JobStatus + +from tests import TestCase +from tests.fixtures import CustomJob, _send_shutdown_command, long_running_job, say_hello + +from rq.queue import Queue +from rq.serializers import JSONSerializer +from rq.worker import SimpleWorker +from rq.worker_pool import run_worker, WorkerPool + + +def wait_and_send_shutdown_signal(pid, time_to_wait=0.0): + sleep(time_to_wait) + os.kill(pid, signal.SIGTERM) + + +class TestWorkerPool(TestCase): + def test_queues(self): + """Test queue parsing""" + pool = WorkerPool(['default', 'foo'], connection=self.connection) + self.assertEqual( + set(pool.queues), {Queue('default', connection=self.connection), Queue('foo', connection=self.connection)} + ) + + # def test_spawn_workers(self): + # """Test spawning workers""" + # pool = WorkerPool(['default', 'foo'], connection=self.connection, num_workers=2) + # pool.start_workers(burst=False) + # self.assertEqual(len(pool.worker_dict.keys()), 2) + # pool.stop_workers() + + def test_check_workers(self): + """Test check_workers()""" + pool = WorkerPool(['default'], connection=self.connection, num_workers=2) + pool.start_workers(burst=False) + + # There should be two workers + pool.check_workers() + self.assertEqual(len(pool.worker_dict.keys()), 2) + + worker_data = list(pool.worker_dict.values())[0] + _send_shutdown_command(worker_data.name, self.connection.connection_pool.connection_kwargs.copy(), delay=0) + # 1 worker should be dead since we sent a shutdown command + sleep(0.2) + pool.check_workers(respawn=False) + self.assertEqual(len(pool.worker_dict.keys()), 1) + + # If we call `check_workers` with `respawn=True`, the worker should be respawned + pool.check_workers(respawn=True) + self.assertEqual(len(pool.worker_dict.keys()), 2) + + pool.stop_workers() + + def test_reap_workers(self): + """Dead workers are removed from worker_dict""" + pool = WorkerPool(['default'], connection=self.connection, num_workers=2) + pool.start_workers(burst=False) + + # There should be two workers + pool.reap_workers() + self.assertEqual(len(pool.worker_dict.keys()), 2) + + worker_data = list(pool.worker_dict.values())[0] + _send_shutdown_command(worker_data.name, self.connection.connection_pool.connection_kwargs.copy(), delay=0) + # 1 worker should be dead since we sent a shutdown command + sleep(0.2) + pool.reap_workers() + self.assertEqual(len(pool.worker_dict.keys()), 1) + pool.stop_workers() + + def test_start(self): + """Test start()""" + pool = WorkerPool(['default'], connection=self.connection, num_workers=2) + + p = Process(target=wait_and_send_shutdown_signal, args=(os.getpid(), 0.5)) + p.start() + pool.start() + self.assertEqual(pool.status, pool.Status.STOPPED) + self.assertTrue(pool.all_workers_have_stopped()) + # We need this line so the test doesn't hang + pool.stop_workers() + + def test_pool_ignores_consecutive_shutdown_signals(self): + """If two shutdown signals are sent within one second, only the first one is processed""" + # Send two shutdown signals within one second while the worker is + # working on a long running job. The job should still complete (not killed) + pool = WorkerPool(['foo'], connection=self.connection, num_workers=2) + + process_1 = Process(target=wait_and_send_shutdown_signal, args=(os.getpid(), 0.5)) + process_1.start() + process_2 = Process(target=wait_and_send_shutdown_signal, args=(os.getpid(), 0.5)) + process_2.start() + + queue = Queue('foo', connection=self.connection) + job = queue.enqueue(long_running_job, 1) + pool.start(burst=True) + + self.assertEqual(job.get_status(refresh=True), JobStatus.FINISHED) + # We need this line so the test doesn't hang + pool.stop_workers() + + def test_run_worker(self): + """Ensure run_worker() properly spawns a Worker""" + queue = Queue('foo', connection=self.connection) + queue.enqueue(say_hello) + run_worker( + 'test-worker', ['foo'], self.connection.__class__, self.connection.connection_pool.connection_kwargs.copy() + ) + # Worker should have processed the job + self.assertEqual(len(queue), 0) + + def test_worker_pool_arguments(self): + """Ensure arguments are properly used to create the right workers""" + queue = Queue('foo', connection=self.connection) + job = queue.enqueue(say_hello) + pool = WorkerPool([queue], connection=self.connection, num_workers=2, worker_class=SimpleWorker) + pool.start(burst=True) + # Worker should have processed the job + self.assertEqual(job.get_status(refresh=True), JobStatus.FINISHED) + + queue = Queue('json', connection=self.connection, serializer=JSONSerializer) + job = queue.enqueue(say_hello, 'Hello') + pool = WorkerPool( + [queue], connection=self.connection, num_workers=2, worker_class=SimpleWorker, serializer=JSONSerializer + ) + pool.start(burst=True) + # Worker should have processed the job + self.assertEqual(job.get_status(refresh=True), JobStatus.FINISHED) + + pool = WorkerPool([queue], connection=self.connection, num_workers=2, job_class=CustomJob) + pool.start(burst=True) + # Worker should have processed the job + self.assertEqual(job.get_status(refresh=True), JobStatus.FINISHED)