diff --git a/.gitignore b/.gitignore index f07314b..477ffbe 100644 --- a/.gitignore +++ b/.gitignore @@ -11,5 +11,5 @@ .vagrant Vagrantfile .idea/ -.coverage.* +.coverage* /.cache diff --git a/rq/cli/cli.py b/rq/cli/cli.py index 27dd78c..4624009 100755 --- a/rq/cli/cli.py +++ b/rq/cli/cli.py @@ -5,24 +5,25 @@ RQ command line tool from __future__ import (absolute_import, division, print_function, unicode_literals) +from functools import update_wrapper import os import sys import click -from redis import StrictRedis from redis.exceptions import ConnectionError -from rq import Connection, get_failed_queue, Queue +from rq import Connection, get_failed_queue, __version__ as version +from rq.cli.helpers import (read_config_file, refresh, + setup_loghandlers_from_args, + show_both, show_queues, show_workers, CliConfig) from rq.contrib.legacy import cleanup_ghosts +from rq.defaults import (DEFAULT_CONNECTION_CLASS, DEFAULT_JOB_CLASS, + DEFAULT_QUEUE_CLASS, DEFAULT_WORKER_CLASS) from rq.exceptions import InvalidJobOperationError from rq.utils import import_attribute from rq.suspension import (suspend as connection_suspend, resume as connection_resume, is_suspended) -from .helpers import (get_redis_from_config, read_config_file, refresh, - setup_loghandlers_from_args, show_both, show_queues, - show_workers) - # Disable the warning that Click displays (as of Click version 5.0) when users # use unicode_literals in Python 2. @@ -30,42 +31,72 @@ from .helpers import (get_redis_from_config, read_config_file, refresh, click.disable_unicode_literals_warning = True -url_option = click.option('--url', '-u', envvar='RQ_REDIS_URL', - help='URL describing Redis connection details.') - -config_option = click.option('--config', '-c', - help='Module containing RQ settings.') - - -def connect(url, config=None, connection_class=StrictRedis): - if url: - return connection_class.from_url(url) - - settings = read_config_file(config) if config else {} - return get_redis_from_config(settings, connection_class) +shared_options = [ + click.option('--url', '-u', + envvar='RQ_REDIS_URL', + help='URL describing Redis connection details.'), + click.option('--config', '-c', + envvar='RQ_CONFIG', + help='Module containing RQ settings.'), + click.option('--worker-class', '-w', + envvar='RQ_WORKER_CLASS', + default=DEFAULT_WORKER_CLASS, + help='RQ Worker class to use'), + click.option('--job-class', '-j', + envvar='RQ_JOB_CLASS', + default=DEFAULT_JOB_CLASS, + help='RQ Job class to use'), + click.option('--queue-class', + envvar='RQ_QUEUE_CLASS', + default=DEFAULT_QUEUE_CLASS, + help='RQ Queue class to use'), + click.option('--connection-class', + envvar='RQ_CONNECTION_CLASS', + default=DEFAULT_CONNECTION_CLASS, + help='Redis client class to use'), +] + + +def pass_cli_config(func): + # add all the shared options to the command + for option in shared_options: + func = option(func) + + # pass the cli config object into the command + def wrapper(*args, **kwargs): + ctx = click.get_current_context() + cli_config = CliConfig(**kwargs) + return ctx.invoke(func, cli_config, *args[1:], **kwargs) + + return update_wrapper(wrapper, func) @click.group() +@click.version_option(version) def main(): """RQ command line tool.""" pass @main.command() -@url_option @click.option('--all', '-a', is_flag=True, help='Empty all queues') @click.argument('queues', nargs=-1) -def empty(url, all, queues): +@pass_cli_config +def empty(cli_config, all, queues, **options): """Empty given queues.""" - conn = connect(url) if all: - queues = Queue.all(connection=conn) + queues = cli_config.queue_class.all(connection=cli_config.connection, + job_class=cli_config.job_class) else: - queues = [Queue(queue, connection=conn) for queue in queues] + queues = [cli_config.queue_class(queue, + connection=cli_config.connection, + job_class=cli_config.job_class) + for queue in queues] if not queues: click.echo('Nothing to do') + sys.exit(0) for queue in queues: num_jobs = queue.empty() @@ -73,13 +104,14 @@ def empty(url, all, queues): @main.command() -@url_option @click.option('--all', '-a', is_flag=True, help='Requeue all failed jobs') @click.argument('job_ids', nargs=-1) -def requeue(url, all, job_ids): +@pass_cli_config +def requeue(cli_config, all, job_class, job_ids, **options): """Requeue failed jobs.""" - conn = connect(url) - failed_queue = get_failed_queue(connection=conn) + + failed_queue = get_failed_queue(connection=cli_config.connection, + job_class=cli_config.job_class) if all: job_ids = failed_queue.job_ids @@ -102,8 +134,6 @@ def requeue(url, all, job_ids): @main.command() -@url_option -@config_option @click.option('--path', '-P', default='.', help='Specify the import path.') @click.option('--interval', '-i', type=float, help='Updates stats every N seconds (default: don\'t poll)') @click.option('--raw', '-r', is_flag=True, help='Print only the raw numbers, no bar charts') @@ -111,7 +141,9 @@ def requeue(url, all, job_ids): @click.option('--only-workers', '-W', is_flag=True, help='Show only worker info') @click.option('--by-queue', '-R', is_flag=True, help='Shows workers by queue') @click.argument('queues', nargs=-1) -def info(url, config, path, interval, raw, only_queues, only_workers, by_queue, queues): +@pass_cli_config +def info(cli_config, path, interval, raw, only_queues, only_workers, by_queue, queues, + **options): """RQ command-line monitor.""" if path: @@ -125,8 +157,9 @@ def info(url, config, path, interval, raw, only_queues, only_workers, by_queue, func = show_both try: - with Connection(connect(url, config)): - refresh(interval, func, queues, raw, by_queue) + with Connection(cli_config.connection): + refresh(interval, func, queues, raw, by_queue, + cli_config.queue_class, cli_config.worker_class) except ConnectionError as e: click.echo(e) sys.exit(1) @@ -136,14 +169,8 @@ def info(url, config, path, interval, raw, only_queues, only_workers, by_queue, @main.command() -@url_option -@config_option @click.option('--burst', '-b', is_flag=True, help='Run in burst mode (quit after all work is done)') @click.option('--name', '-n', help='Specify a different name') -@click.option('--worker-class', '-w', default='rq.Worker', help='RQ Worker class to use') -@click.option('--job-class', '-j', default='rq.job.Job', help='RQ Job class to use') -@click.option('--queue-class', default='rq.Queue', help='RQ Queue class to use') -@click.option('--connection-class', default='redis.StrictRedis', help='Redis client class to use') @click.option('--path', '-P', default='.', help='Specify the import path.') @click.option('--results-ttl', type=int, help='Default results timeout to be used') @click.option('--worker-ttl', type=int, help='Default worker timeout to be used') @@ -153,14 +180,16 @@ def info(url, config, path, interval, raw, only_queues, only_workers, by_queue, @click.option('--exception-handler', help='Exception handler(s) to use', multiple=True) @click.option('--pid', help='Write the process ID number to a file at the specified path') @click.argument('queues', nargs=-1) -def worker(url, config, burst, name, worker_class, job_class, queue_class, connection_class, path, results_ttl, - worker_ttl, verbose, quiet, sentry_dsn, exception_handler, pid, queues): +@pass_cli_config +def worker(cli_config, burst, name, path, results_ttl, + worker_ttl, verbose, quiet, sentry_dsn, exception_handler, + pid, queues, **options): """Starts an RQ worker.""" if path: sys.path = path.split(':') + sys.path - settings = read_config_file(config) if config else {} + settings = read_config_file(cli_config.config) if cli_config.config else {} # Worker specific default arguments queues = queues or settings.get('QUEUES', ['default']) sentry_dsn = sentry_dsn or settings.get('SENTRY_DSN') @@ -171,30 +200,29 @@ def worker(url, config, burst, name, worker_class, job_class, queue_class, conne setup_loghandlers_from_args(verbose, quiet) - connection_class = import_attribute(connection_class) - conn = connect(url, config, connection_class) - cleanup_ghosts(conn) - worker_class = import_attribute(worker_class) - queue_class = import_attribute(queue_class) - exception_handlers = [] - for h in exception_handler: - exception_handlers.append(import_attribute(h)) - - if is_suspended(conn): - click.secho('RQ is currently suspended, to resume job execution run "rq resume"', fg='red') - sys.exit(1) - try: - queues = [queue_class(queue, connection=conn) for queue in queues] - w = worker_class(queues, - name=name, - connection=conn, - default_worker_ttl=worker_ttl, - default_result_ttl=results_ttl, - job_class=job_class, - queue_class=queue_class, - exception_handlers=exception_handlers or None) + cleanup_ghosts(cli_config.connection) + exception_handlers = [] + for h in exception_handler: + exception_handlers.append(import_attribute(h)) + + if is_suspended(cli_config.connection): + click.secho('RQ is currently suspended, to resume job execution run "rq resume"', fg='red') + sys.exit(1) + + queues = [cli_config.queue_class(queue, + connection=cli_config.connection, + job_class=cli_config.job_class) + for queue in queues] + worker = cli_config.worker_class(queues, + name=name, + connection=cli_config.connection, + default_worker_ttl=worker_ttl, + default_result_ttl=results_ttl, + job_class=cli_config.job_class, + queue_class=cli_config.queue_class, + exception_handlers=exception_handlers or None) # Should we configure Sentry? if sentry_dsn: @@ -202,26 +230,25 @@ def worker(url, config, burst, name, worker_class, job_class, queue_class, conne from raven.transport.http import HTTPTransport from rq.contrib.sentry import register_sentry client = Client(sentry_dsn, transport=HTTPTransport) - register_sentry(client, w) + register_sentry(client, worker) - w.work(burst=burst) + worker.work(burst=burst) except ConnectionError as e: print(e) sys.exit(1) @main.command() -@url_option -@config_option @click.option('--duration', help='Seconds you want the workers to be suspended. Default is forever.', type=int) -def suspend(url, config, duration): +@pass_cli_config +def suspend(cli_config, duration, **options): """Suspends all workers, to resume run `rq resume`""" + if duration is not None and duration < 1: click.echo("Duration must be an integer greater than 1") sys.exit(1) - connection = connect(url, config) - connection_suspend(connection, duration) + connection_suspend(cli_config.connection, duration) if duration: msg = """Suspending workers for {0} seconds. No new jobs will be started during that time, but then will @@ -232,10 +259,8 @@ def suspend(url, config, duration): @main.command() -@url_option -@config_option -def resume(url, config): +@pass_cli_config +def resume(cli_config, **options): """Resumes processing of queues, that where suspended with `rq suspend`""" - connection = connect(url, config) - connection_resume(connection) + connection_resume(cli_config.connection) click.echo("Resuming workers.") diff --git a/rq/cli/helpers.py b/rq/cli/helpers.py index da7ce7f..c271174 100644 --- a/rq/cli/helpers.py +++ b/rq/cli/helpers.py @@ -9,8 +9,10 @@ from functools import partial import click import redis from redis import StrictRedis -from rq import Queue, Worker +from rq.defaults import (DEFAULT_CONNECTION_CLASS, DEFAULT_JOB_CLASS, + DEFAULT_QUEUE_CLASS, DEFAULT_WORKER_CLASS) from rq.logutils import setup_loghandlers +from rq.utils import import_attribute from rq.worker import WorkerStatus red = partial(click.style, fg='red') @@ -81,11 +83,11 @@ def state_symbol(state): return state -def show_queues(queues, raw, by_queue): +def show_queues(queues, raw, by_queue, queue_class, worker_class): if queues: - qs = list(map(Queue, queues)) + qs = list(map(queue_class, queues)) else: - qs = Queue.all() + qs = queue_class.all() num_jobs = 0 termwidth, _ = click.get_terminal_size() @@ -116,9 +118,9 @@ def show_queues(queues, raw, by_queue): click.echo('%d queues, %d jobs total' % (len(qs), num_jobs)) -def show_workers(queues, raw, by_queue): +def show_workers(queues, raw, by_queue, queue_class, worker_class): if queues: - qs = list(map(Queue, queues)) + qs = list(map(queue_class, queues)) def any_matching_queue(worker): def queue_matches(q): @@ -126,14 +128,14 @@ def show_workers(queues, raw, by_queue): return any(map(queue_matches, worker.queues)) # Filter out workers that don't match the queue filter - ws = [w for w in Worker.all() if any_matching_queue(w)] + ws = [w for w in worker_class.all() if any_matching_queue(w)] def filter_queues(queue_names): - return [qname for qname in queue_names if Queue(qname) in qs] + return [qname for qname in queue_names if queue_class(qname) in qs] else: - qs = Queue.all() - ws = Worker.all() + qs = queue_class.all() + ws = worker_class.all() filter_queues = (lambda x: x) if not by_queue: @@ -164,11 +166,11 @@ def show_workers(queues, raw, by_queue): click.echo('%d workers, %d queues' % (len(ws), len(qs))) -def show_both(queues, raw, by_queue): - show_queues(queues, raw, by_queue) +def show_both(queues, raw, by_queue, queue_class, worker_class): + show_queues(queues, raw, by_queue, queue_class, worker_class) if not raw: click.echo('') - show_workers(queues, raw, by_queue) + show_workers(queues, raw, by_queue, queue_class, worker_class) if not raw: click.echo('') import datetime @@ -197,3 +199,43 @@ def setup_loghandlers_from_args(verbose, quiet): else: level = 'INFO' setup_loghandlers(level) + + +class CliConfig(object): + """A helper class to be used with click commands, to handle shared options""" + def __init__(self, url=None, config=None, worker_class=DEFAULT_WORKER_CLASS, + job_class=DEFAULT_JOB_CLASS, queue_class=DEFAULT_QUEUE_CLASS, + connection_class=DEFAULT_CONNECTION_CLASS, *args, **kwargs): + self._connection = None + self.url = url + self.config = config + + try: + self.worker_class = import_attribute(worker_class) + except (ImportError, AttributeError) as exc: + raise click.BadParameter(str(exc), param_hint='--worker-class') + try: + self.job_class = import_attribute(job_class) + except (ImportError, AttributeError) as exc: + raise click.BadParameter(str(exc), param_hint='--job-class') + + try: + self.queue_class = import_attribute(queue_class) + except (ImportError, AttributeError) as exc: + raise click.BadParameter(str(exc), param_hint='--queue-class') + + try: + self.connection_class = import_attribute(connection_class) + except (ImportError, AttributeError) as exc: + raise click.BadParameter(str(exc), param_hint='--connection-class') + + @property + def connection(self): + if self._connection is None: + if self.url: + self._connection = self.connection_class.from_url(self.url) + else: + settings = read_config_file(self.config) if self.config else {} + self._connection = get_redis_from_config(settings, + self.connection_class) + return self._connection diff --git a/rq/decorators.py b/rq/decorators.py index 22c3860..f96ef93 100644 --- a/rq/decorators.py +++ b/rq/decorators.py @@ -8,11 +8,15 @@ from rq.compat import string_types from .defaults import DEFAULT_RESULT_TTL from .queue import Queue +from .utils import backend_class class job(object): + queue_class = Queue + def __init__(self, queue, connection=None, timeout=None, - result_ttl=DEFAULT_RESULT_TTL, ttl=None): + result_ttl=DEFAULT_RESULT_TTL, ttl=None, + queue_class=None): """A decorator that adds a ``delay`` method to the decorated function, which in turn creates a RQ job when called. Accepts a required ``queue`` argument that can be either a ``Queue`` instance or a string @@ -25,6 +29,7 @@ class job(object): simple_add.delay(1, 2) # Puts simple_add function into queue """ self.queue = queue + self.queue_class = backend_class(self, 'queue_class', override=queue_class) self.connection = connection self.timeout = timeout self.result_ttl = result_ttl @@ -34,7 +39,8 @@ class job(object): @wraps(f) def delay(*args, **kwargs): if isinstance(self.queue, string_types): - queue = Queue(name=self.queue, connection=self.connection) + queue = self.queue_class(name=self.queue, + connection=self.connection) else: queue = self.queue depends_on = kwargs.pop('depends_on', None) diff --git a/rq/defaults.py b/rq/defaults.py index 6cf2fc3..002c44b 100644 --- a/rq/defaults.py +++ b/rq/defaults.py @@ -1,2 +1,6 @@ +DEFAULT_JOB_CLASS = 'rq.job.Job' +DEFAULT_QUEUE_CLASS = 'rq.Queue' +DEFAULT_WORKER_CLASS = 'rq.Worker' +DEFAULT_CONNECTION_CLASS = 'redis.StrictRedis' DEFAULT_WORKER_TTL = 420 DEFAULT_RESULT_TTL = 500 diff --git a/rq/job.py b/rq/job.py index b820f79..8fad8e8 100644 --- a/rq/job.py +++ b/rq/job.py @@ -16,7 +16,7 @@ from .utils import enum, import_attribute, utcformat, utcnow, utcparse try: import cPickle as pickle -except ImportError: # noqa +except ImportError: # noqa # pragma: no cover import pickle # Serialize pickle dumps using the highest pickle protocol (binary, default @@ -61,24 +61,25 @@ def cancel_job(job_id, connection=None): Job.fetch(job_id, connection=connection).cancel() -def requeue_job(job_id, connection=None): +def requeue_job(job_id, connection=None, job_class=None): """Requeues the job with the given job ID. If no such job exists, just remove the job ID from the failed queue, otherwise the job ID should refer to a failed job (i.e. it should be on the failed queue). """ from .queue import get_failed_queue - fq = get_failed_queue(connection=connection) - fq.requeue(job_id) + failed_queue = get_failed_queue(connection=connection, job_class=job_class) + return failed_queue.requeue(job_id) -def get_current_job(connection=None): +def get_current_job(connection=None, job_class=None): """Returns the Job instance that is currently being executed. If this function is invoked from outside a job context, None is returned. """ + job_class = job_class or Job job_id = _job_stack.top if job_id is None: return None - return Job.fetch(job_id, connection=connection) + return job_class.fetch(job_id, connection=connection) class Job(object): @@ -123,7 +124,7 @@ class Job(object): job._instance = func job._func_name = '__call__' else: - raise TypeError('Expected a callable or a string, but got: {}'.format(func)) + raise TypeError('Expected a callable or a string, but got: {0}'.format(func)) job._args = args job._kwargs = kwargs @@ -189,7 +190,7 @@ class Job(object): return None if hasattr(self, '_dependency'): return self._dependency - job = Job.fetch(self._dependency_id, connection=self.connection) + job = self.fetch(self._dependency_id, connection=self.connection) job.refresh() self._dependency = job return job @@ -317,8 +318,22 @@ class Job(object): self._dependency_id = None self.meta = {} - def __repr__(self): # noqa - return 'Job({0!r}, enqueued_at={1!r})'.format(self._id, self.enqueued_at) + def __repr__(self): # noqa # pragma: no cover + return '{0}({1!r}, enqueued_at={2!r})'.format(self.__class__.__name__, + self._id, + self.enqueued_at) + + def __str__(self): + return '<{0} {1}: {2}>'.format(self.__class__.__name__, + self.id, + self.description) + + # Job equality + def __eq__(self, other): # noqa + return isinstance(other, self.__class__) and self.id == other.id + + def __hash__(self): # pragma: no cover + return hash(self.id) # Data access def get_id(self): # noqa @@ -476,7 +491,8 @@ class Job(object): from .queue import Queue, get_failed_queue pipeline = self.connection._pipeline() if self.origin: - q = (get_failed_queue(connection=self.connection) + q = (get_failed_queue(connection=self.connection, + job_class=self.__class__) if self.is_failed else Queue(name=self.origin, connection=self.connection)) q.remove(self, pipeline=pipeline) @@ -563,21 +579,13 @@ class Job(object): """ from .registry import DeferredJobRegistry - registry = DeferredJobRegistry(self.origin, connection=self.connection) + registry = DeferredJobRegistry(self.origin, + connection=self.connection, + job_class=self.__class__) registry.add(self, pipeline=pipeline) connection = pipeline if pipeline is not None else self.connection - connection.sadd(Job.dependents_key_for(self._dependency_id), self.id) - - def __str__(self): - return ''.format(self.id, self.description) - - # Job equality - def __eq__(self, other): # noqa - return isinstance(other, self.__class__) and self.id == other.id - - def __hash__(self): - return hash(self.id) + connection.sadd(self.dependents_key_for(self._dependency_id), self.id) _job_stack = LocalStack() diff --git a/rq/queue.py b/rq/queue.py index 417141a..36ddd7b 100644 --- a/rq/queue.py +++ b/rq/queue.py @@ -12,12 +12,12 @@ from .defaults import DEFAULT_RESULT_TTL from .exceptions import (DequeueTimeout, InvalidJobDependency, InvalidJobOperationError, NoSuchJobError, UnpickleError) from .job import Job, JobStatus -from .utils import import_attribute, utcnow +from .utils import backend_class, import_attribute, utcnow -def get_failed_queue(connection=None): +def get_failed_queue(connection=None, job_class=None): """Returns a handle to the special failed queue.""" - return FailedQueue(connection=connection) + return FailedQueue(connection=connection, job_class=job_class) def compact(lst): @@ -32,18 +32,21 @@ class Queue(object): redis_queues_keys = 'rq:queues' @classmethod - def all(cls, connection=None): + def all(cls, connection=None, job_class=None): """Returns an iterable of all Queues. """ connection = resolve_connection(connection) def to_queue(queue_key): return cls.from_queue_key(as_text(queue_key), - connection=connection) - return [to_queue(rq_key) for rq_key in connection.smembers(cls.redis_queues_keys) if rq_key] + connection=connection, + job_class=job_class) + return [to_queue(rq_key) + for rq_key in connection.smembers(cls.redis_queues_keys) + if rq_key] @classmethod - def from_queue_key(cls, queue_key, connection=None): + def from_queue_key(cls, queue_key, connection=None, job_class=None): """Returns a Queue instance, based on the naming conventions for naming the internal Redis keys. Can be used to reverse-lookup Queues by their Redis keys. @@ -52,7 +55,7 @@ class Queue(object): if not queue_key.startswith(prefix): raise ValueError('Not a valid RQ queue key: {0}'.format(queue_key)) name = queue_key[len(prefix):] - return cls(name, connection=connection) + return cls(name, connection=connection, job_class=job_class) def __init__(self, name='default', default_timeout=None, connection=None, async=True, job_class=None): @@ -63,6 +66,7 @@ class Queue(object): self._default_timeout = default_timeout self._async = async + # override class attribute job_class if one was passed if job_class is not None: if isinstance(job_class, string_types): job_class = import_attribute(job_class) @@ -201,7 +205,8 @@ class Queue(object): # modifying the dependency. In this case we simply retry if depends_on is not None: if not isinstance(depends_on, self.job_class): - depends_on = Job(id=depends_on, connection=self.connection) + depends_on = self.job_class(id=depends_on, + connection=self.connection) with self.connection._pipeline() as pipe: while True: try: @@ -324,7 +329,9 @@ class Queue(object): pipe.multi() for dependent in dependent_jobs: - registry = DeferredJobRegistry(dependent.origin, self.connection) + registry = DeferredJobRegistry(dependent.origin, + self.connection, + job_class=self.job_class) registry.remove(dependent, pipeline=pipe) if dependent.origin == self.name: self.enqueue_job(dependent, pipeline=pipe) @@ -404,7 +411,7 @@ class Queue(object): return job @classmethod - def dequeue_any(cls, queues, timeout, connection=None): + def dequeue_any(cls, queues, timeout, connection=None, job_class=None): """Class method returning the job_class instance at the front of the given set of Queues, where the order of the queues is important. @@ -415,15 +422,19 @@ class Queue(object): See the documentation of cls.lpop for the interpretation of timeout. """ + job_class = backend_class(cls, 'job_class', override=job_class) + while True: queue_keys = [q.key for q in queues] result = cls.lpop(queue_keys, timeout, connection=connection) if result is None: return None queue_key, job_id = map(as_text, result) - queue = cls.from_queue_key(queue_key, connection=connection) + queue = cls.from_queue_key(queue_key, + connection=connection, + job_class=job_class) try: - job = cls.job_class.fetch(job_id, connection=connection) + job = job_class.fetch(job_id, connection=connection) except NoSuchJobError: # Silently pass on jobs that don't exist (anymore), # and continue in the look @@ -449,19 +460,21 @@ class Queue(object): raise TypeError('Cannot compare queues to other objects') return self.name < other.name - def __hash__(self): + def __hash__(self): # pragma: no cover return hash(self.name) - def __repr__(self): # noqa - return 'Queue({0!r})'.format(self.name) + def __repr__(self): # noqa # pragma: no cover + return '{0}({1!r})'.format(self.__class__.__name__, self.name) def __str__(self): - return ''.format(self.name) + return '<{0} {1}>'.format(self.__class__.__name__, self.name) class FailedQueue(Queue): - def __init__(self, connection=None): - super(FailedQueue, self).__init__(JobStatus.FAILED, connection=connection) + def __init__(self, connection=None, job_class=None): + super(FailedQueue, self).__init__(JobStatus.FAILED, + connection=connection, + job_class=job_class) def quarantine(self, job, exc_info): """Puts the given Job in quarantine (i.e. put it on the failed @@ -496,5 +509,7 @@ class FailedQueue(Queue): job.set_status(JobStatus.QUEUED) job.exc_info = None - q = Queue(job.origin, connection=self.connection) - q.enqueue_job(job) + queue = Queue(job.origin, + connection=self.connection, + job_class=self.job_class) + return queue.enqueue_job(job) diff --git a/rq/registry.py b/rq/registry.py index 7e0fcaf..16de87d 100644 --- a/rq/registry.py +++ b/rq/registry.py @@ -3,7 +3,7 @@ from .connections import resolve_connection from .exceptions import NoSuchJobError from .job import Job, JobStatus from .queue import FailedQueue -from .utils import current_timestamp +from .utils import backend_class, current_timestamp class BaseRegistry(object): @@ -12,12 +12,14 @@ class BaseRegistry(object): Each job is stored as a key in the registry, scored by expiration time (unix timestamp). """ + job_class = Job key_template = 'rq:registry:{0}' - def __init__(self, name='default', connection=None): + def __init__(self, name='default', connection=None, job_class=None): self.name = name self.key = self.key_template.format(name) self.connection = resolve_connection(connection) + self.job_class = backend_class(self, 'job_class', override=job_class) def __len__(self): """Returns the number of jobs in this registry""" @@ -81,12 +83,14 @@ class StartedJobRegistry(BaseRegistry): job_ids = self.get_expired_job_ids(score) if job_ids: - failed_queue = FailedQueue(connection=self.connection) + failed_queue = FailedQueue(connection=self.connection, + job_class=self.job_class) with self.connection.pipeline() as pipeline: for job_id in job_ids: try: - job = Job.fetch(job_id, connection=self.connection) + job = self.job_class.fetch(job_id, + connection=self.connection) job.set_status(JobStatus.FAILED) job.save(pipeline=pipeline) failed_queue.push_job_id(job_id, pipeline=pipeline) @@ -132,7 +136,11 @@ class DeferredJobRegistry(BaseRegistry): def clean_registries(queue): """Cleans StartedJobRegistry and FinishedJobRegistry of a queue.""" - registry = FinishedJobRegistry(name=queue.name, connection=queue.connection) + registry = FinishedJobRegistry(name=queue.name, + connection=queue.connection, + job_class=queue.job_class) registry.cleanup() - registry = StartedJobRegistry(name=queue.name, connection=queue.connection) + registry = StartedJobRegistry(name=queue.name, + connection=queue.connection, + job_class=queue.job_class) registry.cleanup() diff --git a/rq/utils.py b/rq/utils.py index fd94abd..7558100 100644 --- a/rq/utils.py +++ b/rq/utils.py @@ -232,3 +232,13 @@ def enum(name, *sequential, **named): # On Python 3 it does not matter, so we'll use str(), which acts as # a no-op. return type(str(name), (), values) + + +def backend_class(holder, default_name, override=None): + """Get a backend class using its default attribute name or an override""" + if override is None: + return getattr(holder, default_name) + elif isinstance(override, string_types): + return import_attribute(override) + else: + return override diff --git a/rq/worker.py b/rq/worker.py index 4549b55..69ebae0 100644 --- a/rq/worker.py +++ b/rq/worker.py @@ -28,8 +28,8 @@ from .queue import Queue, get_failed_queue from .registry import FinishedJobRegistry, StartedJobRegistry, clean_registries from .suspension import is_suspended from .timeouts import UnixSignalDeathPenalty -from .utils import (ensure_list, enum, import_attribute, make_colorizer, - utcformat, utcnow, utcparse) +from .utils import (backend_class, ensure_list, enum, + make_colorizer, utcformat, utcnow, utcparse) from .version import VERSION try: @@ -93,18 +93,22 @@ class Worker(object): job_class = Job @classmethod - def all(cls, connection=None): + def all(cls, connection=None, job_class=None, queue_class=None): """Returns an iterable of all Workers. """ if connection is None: connection = get_current_connection() reported_working = connection.smembers(cls.redis_workers_keys) - workers = [cls.find_by_key(as_text(key), connection) + workers = [cls.find_by_key(as_text(key), + connection=connection, + job_class=job_class, + queue_class=queue_class) for key in reported_working] return compact(workers) @classmethod - def find_by_key(cls, worker_key, connection=None): + def find_by_key(cls, worker_key, connection=None, job_class=None, + queue_class=None): """Returns a Worker instance, based on the naming conventions for naming the internal Redis keys. Can be used to reverse-lookup Workers by their Redis keys. @@ -120,12 +124,18 @@ class Worker(object): return None name = worker_key[len(prefix):] - worker = cls([], name, connection=connection) + worker = cls([], + name, + connection=connection, + job_class=job_class, + queue_class=queue_class) queues = as_text(connection.hget(worker.key, 'queues')) worker._state = as_text(connection.hget(worker.key, 'state') or '?') worker._job_id = connection.hget(worker.key, 'current_job') or None if queues: - worker.queues = [cls.queue_class(queue, connection=connection) + worker.queues = [worker.queue_class(queue, + connection=connection, + job_class=job_class) for queue in queues.split(',')] return worker @@ -137,17 +147,12 @@ class Worker(object): connection = get_current_connection() self.connection = connection - if job_class is not None: - if isinstance(job_class, string_types): - job_class = import_attribute(job_class) - self.job_class = job_class + self.job_class = backend_class(self, 'job_class', override=job_class) + self.queue_class = backend_class(self, 'queue_class', override=queue_class) - if queue_class is not None: - if isinstance(queue_class, string_types): - queue_class = import_attribute(queue_class) - self.queue_class = queue_class - - queues = [self.queue_class(name=q, connection=connection) + queues = [self.queue_class(name=q, + connection=connection, + job_class=self.job_class) if isinstance(q, string_types) else q for q in ensure_list(queues)] self._name = name @@ -168,7 +173,8 @@ class Worker(object): self._horse_pid = 0 self._stop_requested = False self.log = logger - self.failed_queue = get_failed_queue(connection=self.connection) + self.failed_queue = get_failed_queue(connection=self.connection, + job_class=self.job_class) self.last_cleaned_at = None # By default, push the "move-to-failed-queue" exception handler onto @@ -488,7 +494,8 @@ class Worker(object): try: result = self.queue_class.dequeue_any(self.queues, timeout, - connection=self.connection) + connection=self.connection, + job_class=self.job_class) if result is not None: job, queue = result self.log.info('{0}: {1} ({2})'.format(green(queue.name), @@ -544,9 +551,7 @@ class Worker(object): # Job completed and its ttl has expired break if job_status not in [JobStatus.FINISHED, JobStatus.FAILED]: - self.handle_job_failure( - job=job - ) + self.handle_job_failure(job=job) # Unhandled failure: move the job to the failed queue self.log.warning( @@ -620,7 +625,9 @@ class Worker(object): self.set_state(WorkerStatus.BUSY, pipeline=pipeline) self.set_current_job_id(job.id, pipeline=pipeline) self.heartbeat(timeout, pipeline=pipeline) - registry = StartedJobRegistry(job.origin, self.connection) + registry = StartedJobRegistry(job.origin, + self.connection, + job_class=self.job_class) registry.add(job, timeout, pipeline=pipeline) job.set_status(JobStatus.STARTED, pipeline=pipeline) self.connection._hset(job.key, 'started_at', @@ -630,11 +637,7 @@ class Worker(object): msg = 'Processing {0} from {1} since {2}' self.procline(msg.format(job.func_name, job.origin, time.time())) - def handle_job_failure( - self, - job, - started_job_registry=None - ): + def handle_job_failure(self, job, started_job_registry=None): """Handles the failure or an executing job by: 1. Setting the job status to failed 2. Removing the job from the started_job_registry @@ -643,10 +646,9 @@ class Worker(object): with self.connection._pipeline() as pipeline: if started_job_registry is None: - started_job_registry = StartedJobRegistry( - job.origin, - self.connection - ) + started_job_registry = StartedJobRegistry(job.origin, + self.connection, + job_class=self.job_class) job.set_status(JobStatus.FAILED, pipeline=pipeline) started_job_registry.remove(job, pipeline=pipeline) self.set_current_job_id(None, pipeline=pipeline) @@ -657,12 +659,7 @@ class Worker(object): # even if Redis is down pass - def handle_job_success( - self, - job, - queue, - started_job_registry - ): + def handle_job_success(self, job, queue, started_job_registry): with self.connection._pipeline() as pipeline: while True: try: @@ -680,7 +677,8 @@ class Worker(object): job.save(pipeline=pipeline) finished_job_registry = FinishedJobRegistry(job.origin, - self.connection) + self.connection, + job_class=self.job_class) finished_job_registry.add(job, result_ttl, pipeline) job.cleanup(result_ttl, pipeline=pipeline, @@ -700,7 +698,9 @@ class Worker(object): push_connection(self.connection) - started_job_registry = StartedJobRegistry(job.origin, self.connection) + started_job_registry = StartedJobRegistry(job.origin, + self.connection, + job_class=self.job_class) try: with self.death_penalty_class(job.timeout or self.queue_class.DEFAULT_TIMEOUT): @@ -712,16 +712,12 @@ class Worker(object): # to use the same exc handling when pickling fails job._result = rv - self.handle_job_success( - job=job, - queue=queue, - started_job_registry=started_job_registry - ) + self.handle_job_success(job=job, + queue=queue, + started_job_registry=started_job_registry) except Exception: - self.handle_job_failure( - job=job, - started_job_registry=started_job_registry - ) + self.handle_job_failure(job=job, + started_job_registry=started_job_registry) self.handle_exception(job, *sys.exc_info()) return False diff --git a/tests/test_cli.py b/tests/test_cli.py index 36036b7..dc11502 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -7,7 +7,8 @@ from rq import get_failed_queue, Queue from rq.compat import is_python_version from rq.job import Job from rq.cli import main -from rq.cli.helpers import read_config_file +from rq.cli.helpers import read_config_file, CliConfig +import pytest from tests import RQTestCase from tests.fixtures import div_by_zero @@ -18,15 +19,12 @@ else: from unittest2 import TestCase # noqa -class TestCommandLine(TestCase): - def test_config_file(self): - settings = read_config_file("tests.dummy_settings") - self.assertIn("REDIS_HOST", settings) - self.assertEqual(settings['REDIS_HOST'], "testhost.example.com") - - class TestRQCli(RQTestCase): + @pytest.fixture(autouse=True) + def set_tmpdir(self, tmpdir): + self.tmpdir = tmpdir + def assert_normal_execution(self, result): if result.exit_code == 0: return True @@ -48,13 +46,43 @@ class TestRQCli(RQTestCase): job.save() get_failed_queue().quarantine(job, Exception('Some fake error')) # noqa - def test_empty(self): + def test_config_file(self): + settings = read_config_file('tests.dummy_settings') + self.assertIn('REDIS_HOST', settings) + self.assertEqual(settings['REDIS_HOST'], 'testhost.example.com') + + def test_config_file_option(self): + """""" + cli_config = CliConfig(config='tests.dummy_settings') + self.assertEqual( + cli_config.connection.connection_pool.connection_kwargs['host'], + 'testhost.example.com', + ) + runner = CliRunner() + result = runner.invoke(main, ['info', '--config', cli_config.config]) + self.assertEqual(result.exit_code, 1) + + def test_empty_nothing(self): + """rq empty -u """ + runner = CliRunner() + result = runner.invoke(main, ['empty', '-u', self.redis_url]) + self.assert_normal_execution(result) + self.assertEqual(result.output.strip(), 'Nothing to do') + + def test_empty_failed(self): """rq empty -u failed""" runner = CliRunner() result = runner.invoke(main, ['empty', '-u', self.redis_url, 'failed']) self.assert_normal_execution(result) self.assertEqual(result.output.strip(), '1 jobs removed from failed queue') + def test_empty_all(self): + """rq empty -u failed --all""" + runner = CliRunner() + result = runner.invoke(main, ['empty', '-u', self.redis_url, '--all']) + self.assert_normal_execution(result) + self.assertEqual(result.output.strip(), '1 jobs removed from failed queue') + def test_requeue(self): """rq requeue -u --all""" runner = CliRunner() @@ -62,6 +90,10 @@ class TestRQCli(RQTestCase): self.assert_normal_execution(result) self.assertEqual(result.output.strip(), 'Requeueing 1 jobs from failed queue') + result = runner.invoke(main, ['requeue', '-u', self.redis_url, '--all']) + self.assert_normal_execution(result) + self.assertEqual(result.output.strip(), 'Nothing to do') + def test_info(self): """rq info -u """ runner = CliRunner() @@ -69,12 +101,34 @@ class TestRQCli(RQTestCase): self.assert_normal_execution(result) self.assertIn('1 queues, 1 jobs total', result.output) + def test_info_only_queues(self): + """rq info -u --only-queues (-Q)""" + runner = CliRunner() + result = runner.invoke(main, ['info', '-u', self.redis_url, '--only-queues']) + self.assert_normal_execution(result) + self.assertIn('1 queues, 1 jobs total', result.output) + + def test_info_only_workers(self): + """rq info -u --only-workers (-W)""" + runner = CliRunner() + result = runner.invoke(main, ['info', '-u', self.redis_url, '--only-workers']) + self.assert_normal_execution(result) + self.assertIn('0 workers, 1 queues', result.output) + def test_worker(self): """rq worker -u -b""" runner = CliRunner() result = runner.invoke(main, ['worker', '-u', self.redis_url, '-b']) self.assert_normal_execution(result) + def test_worker_pid(self): + """rq worker -u /tmp/..""" + pid = self.tmpdir.join('rq.pid') + runner = CliRunner() + result = runner.invoke(main, ['worker', '-u', self.redis_url, '-b', '--pid', str(pid)]) + self.assertTrue(len(pid.read()) > 0) + self.assert_normal_execution(result) + def test_exception_handlers(self): """rq worker -u -b --exception-handler """ q = Queue() @@ -96,12 +150,20 @@ class TestRQCli(RQTestCase): def test_suspend_and_resume(self): """rq suspend -u + rq worker -u -b rq resume -u """ runner = CliRunner() result = runner.invoke(main, ['suspend', '-u', self.redis_url]) self.assert_normal_execution(result) + result = runner.invoke(main, ['worker', '-u', self.redis_url, '-b']) + self.assertEqual(result.exit_code, 1) + self.assertEqual( + result.output.strip(), + 'RQ is currently suspended, to resume job execution run "rq resume"' + ) + result = runner.invoke(main, ['resume', '-u', self.redis_url]) self.assert_normal_execution(result) diff --git a/tests/test_decorator.py b/tests/test_decorator.py index c5319c3..4fb1c1b 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -7,6 +7,7 @@ from redis import StrictRedis from rq.decorators import job from rq.job import Job from rq.worker import DEFAULT_RESULT_TTL +from rq.queue import Queue from tests import RQTestCase from tests.fixtures import decorated_job @@ -110,3 +111,39 @@ class TestDecorator(RQTestCase): foo.delay() self.assertEqual(resolve_connection.call_count, 1) + + def test_decorator_custom_queue_class(self): + """Ensure that a custom queue class can be passed to the job decorator""" + class CustomQueue(Queue): + pass + CustomQueue.enqueue_call = mock.MagicMock( + spec=lambda *args, **kwargs: None, + name='enqueue_call' + ) + + custom_decorator = job(queue='default', queue_class=CustomQueue) + self.assertIs(custom_decorator.queue_class, CustomQueue) + + @custom_decorator + def custom_queue_class_job(x, y): + return x + y + + custom_queue_class_job.delay(1, 2) + self.assertEqual(CustomQueue.enqueue_call.call_count, 1) + + def test_decorate_custom_queue(self): + """Ensure that a custom queue instance can be passed to the job decorator""" + class CustomQueue(Queue): + pass + CustomQueue.enqueue_call = mock.MagicMock( + spec=lambda *args, **kwargs: None, + name='enqueue_call' + ) + queue = CustomQueue() + + @job(queue=queue) + def custom_queue_job(x, y): + return x + y + + custom_queue_job.delay(1, 2) + self.assertEqual(queue.enqueue_call.call_count, 1) diff --git a/tests/test_job.py b/tests/test_job.py index 1b2affe..a41a199 100644 --- a/tests/test_job.py +++ b/tests/test_job.py @@ -10,7 +10,7 @@ from tests.helpers import strip_microseconds from rq.compat import PY2, as_text from rq.exceptions import NoSuchJobError, UnpickleError -from rq.job import Job, get_current_job, JobStatus, cancel_job +from rq.job import Job, get_current_job, JobStatus, cancel_job, requeue_job from rq.queue import Queue, get_failed_queue from rq.registry import DeferredJobRegistry from rq.utils import utcformat @@ -46,10 +46,12 @@ class TestJob(RQTestCase): def test_create_empty_job(self): """Creation of new empty jobs.""" job = Job() + job.description = 'test job' # Jobs have a random UUID and a creation date self.assertIsNotNone(job.id) self.assertIsNotNone(job.created_at) + self.assertEqual(str(job), "" % job.id) # ...and nothing else self.assertIsNone(job.origin) @@ -68,6 +70,12 @@ class TestJob(RQTestCase): with self.assertRaises(ValueError): job.kwargs + def test_create_param_errors(self): + """Creation of jobs may result in errors""" + self.assertRaises(TypeError, Job.create, fixtures.say_hello, args="string") + self.assertRaises(TypeError, Job.create, fixtures.say_hello, kwargs="string") + self.assertRaises(TypeError, Job.create, func=42) + def test_create_typical_job(self): """Creation of jobs for function calls.""" job = Job.create(func=fixtures.some_calculation, args=(3, 4), kwargs=dict(z=2)) @@ -439,9 +447,25 @@ class TestJob(RQTestCase): def test_create_failed_and_cancel_job(self): """test creating and using cancel_job deletes job properly""" - failed = get_failed_queue(connection=self.testconn) - job = failed.enqueue(fixtures.say_hello) + failed_queue = get_failed_queue(connection=self.testconn) + job = failed_queue.enqueue(fixtures.say_hello) job.set_status(JobStatus.FAILED) - self.assertEqual(1, len(failed.get_jobs())) + self.assertEqual(1, len(failed_queue.get_jobs())) cancel_job(job.id) - self.assertEqual(0, len(failed.get_jobs())) + self.assertEqual(0, len(failed_queue.get_jobs())) + + def test_create_and_requeue_job(self): + """Requeueing existing jobs.""" + job = Job.create(func=fixtures.div_by_zero, args=(1, 2, 3)) + job.origin = 'fake' + job.save() + get_failed_queue().quarantine(job, Exception('Some fake error')) # noqa + + self.assertEqual(Queue.all(), [get_failed_queue()]) # noqa + self.assertEqual(get_failed_queue().count, 1) + + requeued_job = requeue_job(job.id) + + self.assertEqual(get_failed_queue().count, 0) + self.assertEqual(Queue('fake').count, 1) + self.assertEqual(requeued_job.origin, job.origin) diff --git a/tests/test_queue.py b/tests/test_queue.py index 75f59df..8081313 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -22,6 +22,7 @@ class TestQueue(RQTestCase): """Creating queues.""" q = Queue('my-queue') self.assertEqual(q.name, 'my-queue') + self.assertEqual(str(q), '') def test_create_default_queue(self): """Instantiating the default queue.""" @@ -38,6 +39,9 @@ class TestQueue(RQTestCase): self.assertEqual(q2, q1) self.assertNotEqual(q1, q3) self.assertNotEqual(q2, q3) + self.assertGreater(q1, q3) + self.assertRaises(TypeError, lambda: q1 == 'some string') + self.assertRaises(TypeError, lambda: q1 < 'some string') def test_empty_queue(self): """Emptying queues.""" @@ -338,6 +342,28 @@ class TestQueue(RQTestCase): # Queue.all() should still report the empty queues self.assertEqual(len(Queue.all()), 3) + def test_all_custom_job(self): + class CustomJob(Job): + pass + + q = Queue('all-queue') + q.enqueue(say_hello) + queues = Queue.all(job_class=CustomJob) + self.assertEqual(len(queues), 1) + self.assertIs(queues[0].job_class, CustomJob) + + def test_from_queue_key(self): + """Ensure being able to get a Queue instance manually from Redis""" + q = Queue() + key = Queue.redis_queue_namespace_prefix + 'default' + reverse_q = Queue.from_queue_key(key) + self.assertEqual(q, reverse_q) + + def test_from_queue_key_error(self): + """Ensure that an exception is raised if the queue prefix is wrong""" + key = 'some:weird:prefix:' + 'default' + self.assertRaises(ValueError, Queue.from_queue_key, key) + def test_enqueue_dependents(self): """Enqueueing dependent jobs pushes all jobs in the depends set to the queue and removes them from DeferredJobQueue.""" @@ -490,6 +516,16 @@ class TestQueue(RQTestCase): class TestFailedQueue(RQTestCase): + def test_get_failed_queue(self): + """Use custom job class""" + class CustomJob(Job): + pass + failed_queue = get_failed_queue(job_class=CustomJob) + self.assertIs(failed_queue.job_class, CustomJob) + + failed_queue = get_failed_queue(job_class='rq.job.Job') + self.assertIsNot(failed_queue.job_class, CustomJob) + def test_requeue_job(self): """Requeueing existing jobs.""" job = Job.create(func=div_by_zero, args=(1, 2, 3)) @@ -500,10 +536,11 @@ class TestFailedQueue(RQTestCase): self.assertEqual(Queue.all(), [get_failed_queue()]) # noqa self.assertEqual(get_failed_queue().count, 1) - get_failed_queue().requeue(job.id) + requeued_job = get_failed_queue().requeue(job.id) self.assertEqual(get_failed_queue().count, 0) self.assertEqual(Queue('fake').count, 1) + self.assertEqual(requeued_job.origin, job.origin) def test_get_job_on_failed_queue(self): default_queue = Queue() diff --git a/tests/test_registry.py b/tests/test_registry.py index c3bbb7c..726f79b 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -13,6 +13,10 @@ from tests import RQTestCase from tests.fixtures import div_by_zero, say_hello +class CustomJob(Job): + """A custom job class just to test it""" + + class TestRegistry(RQTestCase): def setUp(self): @@ -22,6 +26,10 @@ class TestRegistry(RQTestCase): def test_key(self): self.assertEqual(self.registry.key, 'rq:wip:default') + def test_custom_job_class(self): + registry = StartedJobRegistry(job_class=CustomJob) + self.assertFalse(registry.job_class == self.registry.job_class) + def test_add_and_remove(self): """Adding and removing job to StartedJobRegistry.""" timestamp = current_timestamp() @@ -86,12 +94,15 @@ class TestRegistry(RQTestCase): worker = Worker([queue]) job = queue.enqueue(say_hello) + self.assertTrue(job.is_queued) worker.prepare_job_execution(job) self.assertIn(job.id, registry.get_job_ids()) + self.assertTrue(job.is_started) worker.perform_job(job, queue) self.assertNotIn(job.id, registry.get_job_ids()) + self.assertTrue(job.is_finished) # Job that fails job = queue.enqueue(div_by_zero) diff --git a/tests/test_worker.py b/tests/test_worker.py index 4449aa8..cf117b7 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -626,7 +626,7 @@ class TimeoutTestCase: def setUp(self): # we want tests to fail if signal are ignored and the work remain # running, so set a signal to kill them after X seconds - self.killtimeout = 10 + self.killtimeout = 15 signal.signal(signal.SIGALRM, self._timeout) signal.alarm(self.killtimeout)