diff --git a/rq/__init__.py b/rq/__init__.py index 6451efb..f70e950 100644 --- a/rq/__init__.py +++ b/rq/__init__.py @@ -1,5 +1,6 @@ -from redis import Redis -from .proxy import conn +from .connections import get_current_connection +from .connections import use_connection, push_connection, pop_connection +from .connections import Connection from .queue import Queue, get_failed_queue from .job import cancel_job, requeue_job from .worker import Worker @@ -7,21 +8,12 @@ from .version import VERSION def use_redis(redis=None): - """Pushes the given Redis connection (a redis.Redis instance) onto the - connection stack. This is merely a helper function to easily start - using RQ without having to know or understand the RQ connection stack. + use_connection(redis) - When given None as an argument, a (default) Redis connection to - redis://localhost:6379 is set up. - """ - if redis is None: - redis = Redis() - elif not isinstance(redis, Redis): - raise TypeError('Argument redis should be a Redis instance.') - conn.push(redis) __all__ = [ - 'conn', 'use_redis', + 'use_connection', 'get_current_connection', + 'push_connection', 'pop_connection', 'Connection', 'Queue', 'get_failed_queue', 'Worker', 'cancel_job', 'requeue_job'] __version__ = VERSION diff --git a/rq/connections.py b/rq/connections.py new file mode 100644 index 0000000..4087d41 --- /dev/null +++ b/rq/connections.py @@ -0,0 +1,82 @@ +from contextlib import contextmanager +from redis import Redis + + +class NoRedisConnectionException(Exception): + pass + + +class _RedisConnectionStack(object): + def __init__(self): + self.stack = [] + + def _get_current_object(self): + try: + return self.stack[-1] + except IndexError: + msg = 'No Redis connection configured.' + raise NoRedisConnectionException(msg) + + def pop(self): + return self.stack.pop() + + def push(self, connection): + self.stack.append(connection) + + def empty(self): + del self.stack[:] + + def depth(self): + return len(self.stack) + + def __getattr__(self, name): + return getattr(self._get_current_object(), name) + + +_connection_stack = _RedisConnectionStack() + + +@contextmanager +def Connection(connection=None): + if connection is None: + connection = Redis() + _connection_stack.push(connection) + try: + yield + finally: + popped = _connection_stack.pop() + assert popped == connection, \ + 'Unexpected Redis connection was popped off the stack. ' \ + 'Check your Redis connection setup.' + + +def push_connection(redis): + """Pushes the given connection on the stack.""" + _connection_stack.push(redis) + + +def pop_connection(): + """Pops the topmost connection from the stack.""" + return _connection_stack.pop() + + +def use_connection(redis): + """Clears the stack and uses the given connection. Protects against mixed + use of use_connection() and stacked connection contexts. + """ + assert _connection_stack.depth() <= 1, \ + 'You should not mix Connection contexts with use_connection().' + _connection_stack.empty() + push_connection(redis) + + +def get_current_connection(): + """Returns the current Redis connection (i.e. the topmost on the + connection stack). + """ + return _connection_stack._get_current_object() + + +__all__ = ['Connection', + 'get_current_connection', 'push_connection', 'pop_connection', + 'use_connection'] diff --git a/rq/job.py b/rq/job.py index ecd84bb..856890f 100644 --- a/rq/job.py +++ b/rq/job.py @@ -2,7 +2,7 @@ import importlib import times from uuid import uuid4 from cPickle import loads, dumps, UnpicklingError -from .proxy import conn +from .connections import get_current_connection from .exceptions import UnpickleError, NoSuchJobError @@ -21,20 +21,20 @@ def unpickle(pickled_string): return obj -def cancel_job(job_id): +def cancel_job(job_id, connection=None): """Cancels the job with the given job ID, preventing execution. Discards any job info (i.e. it can't be requeued later). """ - Job(job_id).cancel() + Job(job_id, connection=connection).cancel() -def requeue_job(job_id): +def requeue_job(job_id, connection=None): """Requeues the job with the given job ID. The job ID should refer to a failed job (i.e. it should be on the failed queue). If no such (failed) job exists, a NoSuchJobError is raised. """ from .queue import get_failed_queue - fq = get_failed_queue() + fq = get_failed_queue(connection=connection) fq.requeue(job_id) @@ -48,7 +48,8 @@ class Job(object): """Creates a new Job instance for the given function, arguments, and keyword arguments. """ - job = Job() + connection = kwargs.pop('connection', None) + job = Job(connection=connection) job._func_name = '%s.%s' % (func.__module__, func.__name__) job._args = args job._kwargs = kwargs @@ -80,18 +81,22 @@ class Job(object): @classmethod def exists(cls, job_id): """Returns whether a job hash exists for the given job ID.""" + conn = get_current_connection() return conn.exists(cls.key_for(job_id)) @classmethod - def fetch(cls, id): + def fetch(cls, id, connection=None): """Fetches a persisted job from its corresponding Redis key and instantiates it. """ - job = Job(id) + job = Job(id, connection=connection) job.refresh() return job - def __init__(self, id=None): + def __init__(self, id=None, connection=None): + if connection is None: + connection = get_current_connection() + self.connection = connection self._id = id self.created_at = times.now() self._func_name = None @@ -156,7 +161,7 @@ class Job(object): seconds by default). """ if self._result is None: - rv = conn.hget(self.key, 'result') + rv = self.connection.hget(self.key, 'result') if rv is not None: # cache the result self._result = loads(rv) @@ -175,7 +180,7 @@ class Job(object): 'enqueued_at', 'ended_at', 'result', 'exc_info', 'timeout'] data, created_at, origin, description, \ enqueued_at, ended_at, result, \ - exc_info, timeout = conn.hmget(key, properties) + exc_info, timeout = self.connection.hmget(key, properties) if data is None: raise NoSuchJobError('No such job: %s' % (key,)) @@ -222,7 +227,7 @@ class Job(object): if self.timeout is not None: obj['timeout'] = self.timeout - conn.hmset(key, obj) + self.connection.hmset(key, obj) def cancel(self): """Cancels the given job, which will prevent the job from ever being @@ -237,7 +242,7 @@ class Job(object): def delete(self): """Deletes the job hash from Redis.""" - conn.delete(self.key) + self.connection.delete(self.key) # Job execution diff --git a/rq/proxy.py b/rq/proxy.py deleted file mode 100644 index f075767..0000000 --- a/rq/proxy.py +++ /dev/null @@ -1,28 +0,0 @@ -class NoRedisConnectionException(Exception): - pass - - -class RedisConnectionProxy(object): - def __init__(self): - self.stack = [] - - def _get_current_object(self): - try: - return self.stack[-1] - except IndexError: - msg = 'No Redis connection configured.' - raise NoRedisConnectionException(msg) - - def pop(self): - return self.stack.pop() - - def push(self, db): - self.stack.append(db) - - def __getattr__(self, name): - return getattr(self._get_current_object(), name) - - -conn = RedisConnectionProxy() - -__all__ = ['conn'] diff --git a/rq/queue.py b/rq/queue.py index 3ac4071..f359595 100644 --- a/rq/queue.py +++ b/rq/queue.py @@ -1,13 +1,13 @@ import times from functools import total_ordering -from .proxy import conn +from .connections import get_current_connection from .job import Job from .exceptions import NoSuchJobError, UnpickleError, InvalidJobOperationError -def get_failed_queue(): +def get_failed_queue(connection=None): """Returns a handle to the special failed queue.""" - return FailedQueue() + return FailedQueue(connection=connection) def compact(lst): @@ -19,14 +19,19 @@ class Queue(object): redis_queue_namespace_prefix = 'rq:queue:' @classmethod - def all(cls): + def all(cls, connection=None): """Returns an iterable of all Queues. """ prefix = cls.redis_queue_namespace_prefix - return map(cls.from_queue_key, conn.keys('%s*' % prefix)) + if connection is None: + connection = get_current_connection() + + def to_queue(queue_key): + return cls.from_queue_key(queue_key, connection=connection) + return map(to_queue, connection.keys('%s*' % prefix)) @classmethod - def from_queue_key(cls, queue_key): + def from_queue_key(cls, queue_key, connection=None): """Returns a Queue instance, based on the naming conventions for naming the internal Redis keys. Can be used to reverse-lookup Queues by their Redis keys. @@ -35,9 +40,12 @@ class Queue(object): if not queue_key.startswith(prefix): raise ValueError('Not a valid RQ queue key: %s' % (queue_key,)) name = queue_key[len(prefix):] - return Queue(name) + return Queue(name, connection=connection) - def __init__(self, name='default', default_timeout=None): + def __init__(self, name='default', default_timeout=None, connection=None): + if connection is None: + connection = get_current_connection() + self.connection = connection prefix = self.redis_queue_namespace_prefix self.name = name self._key = '%s%s' % (prefix, name) @@ -50,7 +58,7 @@ class Queue(object): def empty(self): """Removes all messages on the queue.""" - conn.delete(self.key) + self.connection.delete(self.key) def is_empty(self): """Returns whether the current queue is empty.""" @@ -59,7 +67,7 @@ class Queue(object): @property def job_ids(self): """Returns a list of all job IDS in the queue.""" - return conn.lrange(self.key, 0, -1) + return self.connection.lrange(self.key, 0, -1) @property def jobs(self): @@ -78,7 +86,7 @@ class Queue(object): @property def count(self): """Returns a count of all messages in the queue.""" - return conn.llen(self.key) + return self.connection.llen(self.key) def compact(self): """Removes all "dead" jobs from the queue by cycling through it, while @@ -86,18 +94,18 @@ class Queue(object): """ COMPACT_QUEUE = 'rq:queue:_compact' - conn.rename(self.key, COMPACT_QUEUE) + self.connection.rename(self.key, COMPACT_QUEUE) while True: - job_id = conn.lpop(COMPACT_QUEUE) + job_id = self.connection.lpop(COMPACT_QUEUE) if job_id is None: break if Job.exists(job_id): - conn.rpush(self.key, job_id) + self.connection.rpush(self.key, job_id) def push_job_id(self, job_id): # noqa """Pushes a job ID on the corresponding Redis queue.""" - conn.rpush(self.key, job_id) + self.connection.rpush(self.key, job_id) def enqueue(self, f, *args, **kwargs): """Creates a job to represent the delayed function call and enqueues @@ -115,7 +123,7 @@ class Queue(object): 'by workers.') timeout = kwargs.pop('timeout', None) - job = Job.create(f, *args, **kwargs) + job = Job.create(f, *args, connection=self.connection, **kwargs) return self.enqueue_job(job, timeout=timeout) def enqueue_job(self, job, timeout=None, set_meta_data=True): @@ -143,7 +151,7 @@ class Queue(object): def pop_job_id(self): """Pops a given job ID from this Redis queue.""" - return conn.lpop(self.key) + return self.connection.lpop(self.key) @classmethod def lpop(cls, queue_keys, blocking): @@ -155,6 +163,7 @@ class Queue(object): Until Redis receives a specific method for this, we'll have to wrap it this way. """ + conn = get_current_connection() if blocking: queue_key, job_id = conn.blpop(queue_keys) return queue_key, job_id @@ -174,7 +183,7 @@ class Queue(object): if job_id is None: return None try: - job = Job.fetch(job_id) + job = Job.fetch(job_id, connection=self.connection) except NoSuchJobError as e: # Silently pass on jobs that don't exist (anymore), # and continue by reinvoking itself recursively @@ -187,7 +196,7 @@ class Queue(object): return job @classmethod - def dequeue_any(cls, queues, blocking): + def dequeue_any(cls, queues, blocking, connection=None): """Class method returning the Job instance at the front of the given set of Queues, where the order of the queues is important. @@ -200,13 +209,13 @@ class Queue(object): if result is None: return None queue_key, job_id = result - queue = Queue.from_queue_key(queue_key) + queue = Queue.from_queue_key(queue_key, connection=connection) try: - job = Job.fetch(job_id) + job = Job.fetch(job_id, connection=connection) except NoSuchJobError: # Silently pass on jobs that don't exist (anymore), # and continue by reinvoking the same function recursively - return cls.dequeue_any(queues, blocking) + return cls.dequeue_any(queues, blocking, connection=connection) except UnpickleError as e: # Attach queue information on the exception for improved error # reporting @@ -240,8 +249,8 @@ class Queue(object): class FailedQueue(Queue): - def __init__(self): - super(FailedQueue, self).__init__('failed') + def __init__(self, connection=None): + super(FailedQueue, self).__init__('failed', connection=connection) def quarantine(self, job, exc_info): """Puts the given Job in quarantine (i.e. put it on the failed @@ -258,16 +267,16 @@ class FailedQueue(Queue): def requeue(self, job_id): """Requeues the job with the given job ID.""" try: - job = Job.fetch(job_id) + job = Job.fetch(job_id, connection=self.connection) except NoSuchJobError: # Silently ignore/remove this job and return (i.e. do nothing) - conn.lrem(self.key, job_id) + self.connection.lrem(self.key, job_id) return - # Delete it from the FailedQueue (raise an error if that failed) - if conn.lrem(self.key, job.id) == 0: + # Delete it from the failed queue (raise an error if that failed) + if self.connection.lrem(self.key, job.id) == 0: raise InvalidJobOperationError('Cannot requeue non-failed jobs.') job.exc_info = None - q = Queue(job.origin) + q = Queue(job.origin, connection=self.connection) q.enqueue_job(job) diff --git a/rq/worker.py b/rq/worker.py index 92da8b0..ce368f1 100644 --- a/rq/worker.py +++ b/rq/worker.py @@ -12,8 +12,8 @@ try: Logger = Logger # Does nothing except it shuts up pyflakes annoying error except ImportError: from logging import Logger -from .queue import Queue, FailedQueue -from .proxy import conn +from .queue import Queue, get_failed_queue +from .connections import get_current_connection from .utils import make_colorizer from .exceptions import NoQueueError, UnpickleError from .timeouts import death_pentalty_after @@ -53,6 +53,7 @@ class Worker(object): def all(cls): """Returns an iterable of all Workers. """ + conn = get_current_connection() reported_working = conn.smembers(cls.redis_workers_keys) return compact(map(cls.find_by_key, reported_working)) @@ -67,6 +68,7 @@ class Worker(object): if not worker_key.startswith(prefix): raise ValueError('Not a valid RQ worker key: %s' % (worker_key,)) + conn = get_current_connection() if not conn.exists(worker_key): return None @@ -79,7 +81,10 @@ class Worker(object): return worker - def __init__(self, queues, name=None, rv_ttl=500): # noqa + def __init__(self, queues, name=None, rv_ttl=500, connection=None): # noqa + if connection is None: + connection = get_current_connection() + self.connection = connection if isinstance(queues, Queue): queues = [queues] self._name = name @@ -91,7 +96,7 @@ class Worker(object): self._horse_pid = 0 self._stopped = False self.log = Logger('worker') - self.failed_queue = FailedQueue() + self.failed_queue = get_failed_queue(connection=self.connection) def validate_queues(self): # noqa @@ -158,14 +163,15 @@ class Worker(object): def register_birth(self): # noqa """Registers its own birth.""" self.log.debug('Registering birth of worker %s' % (self.name,)) - if conn.exists(self.key) and not conn.hexists(self.key, 'death'): + if self.connection.exists(self.key) and \ + not self.connection.hexists(self.key, 'death'): raise ValueError( 'There exists an active worker named \'%s\' ' 'already.' % (self.name,)) key = self.key now = time.time() queues = ','.join(self.queue_names()) - with conn.pipeline() as p: + with self.connection.pipeline() as p: p.delete(key) p.hset(key, 'birth', now) p.hset(key, 'queues', queues) @@ -175,7 +181,7 @@ class Worker(object): def register_death(self): """Registers its own death.""" self.log.debug('Registering death') - with conn.pipeline() as p: + with self.connection.pipeline() as p: # We cannot use self.state = 'dead' here, because that would # rollback the pipeline p.srem(self.redis_workers_keys, self.key) @@ -185,7 +191,7 @@ class Worker(object): def set_state(self, new_state): self._state = new_state - conn.hset(self.key, 'state', new_state) + self.connection.hset(self.key, 'state', new_state) def get_state(self): return self._state @@ -268,7 +274,8 @@ class Worker(object): green(', '.join(qnames))) wait_for_job = not burst try: - result = Queue.dequeue_any(self.queues, wait_for_job) + result = Queue.dequeue_any(self.queues, wait_for_job, \ + connection=self.connection) if result is None: break except UnpickleError as e: @@ -359,7 +366,7 @@ class Worker(object): self.log.info('Job OK, result = %s' % (yellow(unicode(rv)),)) if rv is not None: - p = conn.pipeline() + p = self.connection.pipeline() p.hset(job.key, 'result', dumps(rv)) p.expire(job.key, self.rv_ttl) p.execute() diff --git a/tests/__init__.py b/tests/__init__.py index 1c24bba..c46f6ef 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,7 +1,7 @@ import unittest from redis import Redis from logbook import NullHandler -from rq import conn +from rq import push_connection, pop_connection def find_empty_redis_database(): @@ -42,7 +42,7 @@ class RQTestCase(unittest.TestCase): def setUpClass(cls): # Set up connection to Redis testconn = find_empty_redis_database() - conn.push(testconn) + push_connection(testconn) # Store the connection (for sanity checking) cls.testconn = testconn @@ -53,17 +53,17 @@ class RQTestCase(unittest.TestCase): def setUp(self): # Flush beforewards (we like our hygiene) - conn.flushdb() + self.testconn.flushdb() def tearDown(self): # Flush afterwards - conn.flushdb() + self.testconn.flushdb() @classmethod def tearDownClass(cls): cls.log_handler.pop_thread() # Pop the connection to Redis - testconn = conn.pop() + testconn = pop_connection() assert testconn == cls.testconn, 'Wow, something really nasty ' \ 'happened to the Redis connection stack. Check your setup.' diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 0000000..e0d149d --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,36 @@ +from tests import RQTestCase, find_empty_redis_database +from tests.fixtures import do_nothing +from rq import Queue +from rq import Connection + + +def new_connection(): + return find_empty_redis_database() + + +class TestConnectionInheritance(RQTestCase): + def test_connection_detection(self): + """Automatic detection of the connection.""" + q = Queue() + self.assertEquals(q.connection, self.testconn) + + def test_connection_stacking(self): + """Connection stacking.""" + conn1 = new_connection() + conn2 = new_connection() + + with Connection(conn1): + q1 = Queue() + with Connection(conn2): + q2 = Queue() + self.assertNotEquals(q1.connection, q2.connection) + + def test_connection_pass_thru(self): + """Connection passed through from queues to jobs.""" + q1 = Queue() + with Connection(new_connection()): + q2 = Queue() + job1 = q1.enqueue(do_nothing) + job2 = q2.enqueue(do_nothing) + self.assertEquals(q1.connection, job1.connection) + self.assertEquals(q2.connection, job2.connection) diff --git a/tests/test_queue.py b/tests/test_queue.py index ea0e74e..a75abf5 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -188,7 +188,7 @@ class TestQueue(RQTestCase): # Dequeue simply ignores the missing job and returns None self.assertEquals(q.count, 1) - self.assertEquals(Queue.dequeue_any([Queue(), Queue('low')], False), + self.assertEquals(Queue.dequeue_any([Queue(), Queue('low')], False), # noqa None) self.assertEquals(q.count, 0) @@ -199,9 +199,9 @@ class TestFailedQueue(RQTestCase): job = Job.create(div_by_zero, 1, 2, 3) job.origin = 'fake' job.save() - get_failed_queue().quarantine(job, Exception('Some fake error')) + get_failed_queue().quarantine(job, Exception('Some fake error')) # noqa - self.assertItemsEqual(Queue.all(), [get_failed_queue()]) + self.assertItemsEqual(Queue.all(), [get_failed_queue()]) # noqa self.assertEquals(get_failed_queue().count, 1) get_failed_queue().requeue(job.id) diff --git a/tests/test_worker.py b/tests/test_worker.py index 02d2fdc..27f3b25 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -3,7 +3,7 @@ from tests import RQTestCase, slow from tests.fixtures import say_hello, div_by_zero, do_nothing, create_file, \ create_file_after_timeout from tests.helpers import strip_milliseconds -from rq import Queue, Worker +from rq import Queue, Worker, get_failed_queue from rq.job import Job @@ -28,7 +28,7 @@ class TestWorker(RQTestCase): def test_work_is_unreadable(self): """Unreadable jobs are put on the failed queue.""" q = Queue() - failed_q = Queue('failed') + failed_q = get_failed_queue() self.assertEquals(failed_q.count, 0) self.assertEquals(q.count, 0) @@ -58,7 +58,7 @@ class TestWorker(RQTestCase): def test_work_fails(self): """Failing jobs are put on the failed queue.""" q = Queue() - failed_q = Queue('failed') + failed_q = get_failed_queue() # Preconditions self.assertEquals(failed_q.count, 0)