Added parse_connection function (#1884)

* Added parse_connection function

* feat: allow custom connection pool class (#1885)

* Added test for SSL

---------

Co-authored-by: Cyril Chapellier <tchapi@users.noreply.github.com>
main
Selwin Ong 2 years ago committed by GitHub
parent 95983cfcac
commit 77e926c424
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,8 +1,8 @@
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional from typing import Any, Optional, Tuple, Type
from redis import Redis from redis import Connection as RedisConnection, Redis, SSLConnection, UnixDomainSocketConnection
from .local import LocalStack from .local import LocalStack
@ -42,10 +42,9 @@ def Connection(connection: Optional['Redis'] = None): # noqa
yield yield
finally: finally:
popped = pop_connection() popped = pop_connection()
assert popped == connection, ( assert (
'Unexpected Redis connection was popped off the stack. ' popped == connection
'Check your Redis connection setup.' ), 'Unexpected Redis connection was popped off the stack. Check your Redis connection setup.'
)
def push_connection(redis: 'Redis'): def push_connection(redis: 'Redis'):
@ -118,8 +117,27 @@ def resolve_connection(connection: Optional['Redis'] = None) -> 'Redis':
return connection return connection
def parse_connection(connection: Redis) -> Tuple[Type[Redis], Type[RedisConnection], dict]:
connection_kwargs = connection.connection_pool.connection_kwargs.copy()
# Redis does not accept parser_class argument which is sometimes present
# on connection_pool kwargs, for example when hiredis is used
connection_kwargs.pop('parser_class', None)
connection_pool_class = connection.connection_pool.connection_class
if issubclass(connection_pool_class, SSLConnection):
connection_kwargs['ssl'] = True
if issubclass(connection_pool_class, UnixDomainSocketConnection):
# The connection keyword arguments are obtained from
# `UnixDomainSocketConnection`, which expects `path`, but passed to
# `redis.client.Redis`, which expects `unix_socket_path`, renaming
# the key is necessary.
# `path` is not left in the dictionary as that keyword argument is
# not expected by `redis.client.Redis` and would raise an exception.
connection_kwargs['unix_socket_path'] = connection_kwargs.pop('path')
return connection.__class__, connection_pool_class, connection_kwargs
_connection_stack = LocalStack() _connection_stack = LocalStack()
__all__ = ['Connection', 'get_current_connection', 'push_connection', 'pop_connection'] __all__ = ['Connection', 'get_current_connection', 'push_connection', 'pop_connection']

@ -7,8 +7,9 @@ from datetime import datetime
from enum import Enum from enum import Enum
from multiprocessing import Process from multiprocessing import Process
from redis import SSLConnection, UnixDomainSocketConnection from redis import ConnectionPool, Redis, SSLConnection, UnixDomainSocketConnection
from .connections import parse_connection
from .defaults import DEFAULT_LOGGING_DATE_FORMAT, DEFAULT_LOGGING_FORMAT, DEFAULT_SCHEDULER_FALLBACK_PERIOD from .defaults import DEFAULT_LOGGING_DATE_FORMAT, DEFAULT_LOGGING_FORMAT, DEFAULT_SCHEDULER_FALLBACK_PERIOD
from .job import Job from .job import Job
from .logutils import setup_loghandlers from .logutils import setup_loghandlers
@ -35,37 +36,24 @@ class RQScheduler:
Status = SchedulerStatus Status = SchedulerStatus
def __init__( def __init__(
self, self,
queues, queues,
connection, connection: Redis,
interval=1, interval=1,
logging_level=logging.INFO, logging_level=logging.INFO,
date_format=DEFAULT_LOGGING_DATE_FORMAT, date_format=DEFAULT_LOGGING_DATE_FORMAT,
log_format=DEFAULT_LOGGING_FORMAT, log_format=DEFAULT_LOGGING_FORMAT,
serializer=None, serializer=None,
): ):
self._queue_names = set(parse_names(queues)) self._queue_names = set(parse_names(queues))
self._acquired_locks = set() self._acquired_locks = set()
self._scheduled_job_registries = [] self._scheduled_job_registries = []
self.lock_acquisition_time = None self.lock_acquisition_time = None
# Copy the connection kwargs before mutating them in order to not change the arguments (
# used by the current connection pool to create new connections self._connection_class,
self._connection_kwargs = connection.connection_pool.connection_kwargs.copy() self._connection_pool_class,
# Redis does not accept parser_class argument which is sometimes present self._connection_kwargs,
# on connection_pool kwargs, for example when hiredis is used ) = parse_connection(connection)
self._connection_kwargs.pop('parser_class', None)
self._connection_class = connection.__class__ # client
connection_class = connection.connection_pool.connection_class
if issubclass(connection_class, SSLConnection):
self._connection_kwargs['ssl'] = True
if issubclass(connection_class, UnixDomainSocketConnection):
# The connection keyword arguments are obtained from
# `UnixDomainSocketConnection`, which expects `path`, but passed to
# `redis.client.Redis`, which expects `unix_socket_path`, renaming
# the key is necessary.
# `path` is not left in the dictionary as that keyword argument is
# not expected by `redis.client.Redis` and would raise an exception.
self._connection_kwargs['unix_socket_path'] = self._connection_kwargs.pop('path')
self.serializer = resolve_serializer(serializer) self.serializer = resolve_serializer(serializer)
self._connection = None self._connection = None
@ -85,7 +73,12 @@ class RQScheduler:
def connection(self): def connection(self):
if self._connection: if self._connection:
return self._connection return self._connection
self._connection = self._connection_class(**self._connection_kwargs) self._connection = self._connection_class(
connection_pool=ConnectionPool(
connection_class=self._connection_pool_class,
**self._connection_kwargs
)
)
return self._connection return self._connection
@property @property

@ -1,6 +1,7 @@
from redis import Redis from redis import ConnectionPool, Redis, UnixDomainSocketConnection
from rq import Connection, Queue from rq import Connection, Queue
from rq.connections import parse_connection
from tests import RQTestCase, find_empty_redis_database from tests import RQTestCase, find_empty_redis_database
from tests.fixtures import do_nothing from tests.fixtures import do_nothing
@ -35,3 +36,12 @@ class TestConnectionInheritance(RQTestCase):
job2 = q2.enqueue(do_nothing) job2 = q2.enqueue(do_nothing)
self.assertEqual(q1.connection, job1.connection) self.assertEqual(q1.connection, job1.connection)
self.assertEqual(q2.connection, job2.connection) self.assertEqual(q2.connection, job2.connection)
def test_parse_connection(self):
"""Test parsing `ssl` and UnixDomainSocketConnection"""
_, _, kwargs = parse_connection(Redis(ssl=True))
self.assertTrue(kwargs['ssl'])
path = '/tmp/redis.sock'
pool = ConnectionPool(connection_class=UnixDomainSocketConnection, path=path)
_, _, kwargs = parse_connection(Redis(connection_pool=pool))
self.assertTrue(kwargs['unix_socket_path'], path)

@ -1,4 +1,6 @@
import os import os
import redis
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from multiprocessing import Process from multiprocessing import Process
from unittest import mock from unittest import mock
@ -16,6 +18,17 @@ from tests import RQTestCase, find_empty_redis_database, ssl_test
from .fixtures import kill_worker, say_hello from .fixtures import kill_worker, say_hello
class CustomRedisConnection(redis.Connection):
"""Custom redis connection with a custom arg, used in test_custom_connection_pool"""
def __init__(self, *args, custom_arg=None, **kwargs):
self.custom_arg = custom_arg
super().__init__(*args, **kwargs)
def get_custom_arg(self):
return self.custom_arg
class TestScheduledJobRegistry(RQTestCase): class TestScheduledJobRegistry(RQTestCase):
def test_get_jobs_to_enqueue(self): def test_get_jobs_to_enqueue(self):
@ -431,3 +444,34 @@ class TestQueue(RQTestCase):
job = queue.enqueue_in(timedelta(seconds=30), say_hello, retry=Retry(3, [2])) job = queue.enqueue_in(timedelta(seconds=30), say_hello, retry=Retry(3, [2]))
self.assertEqual(job.retries_left, 3) self.assertEqual(job.retries_left, 3)
self.assertEqual(job.retry_intervals, [2]) self.assertEqual(job.retry_intervals, [2])
def test_custom_connection_pool(self):
"""Connection pool customizing. Ensure that we can properly set a
custom connection pool class and pass extra arguments"""
custom_conn = redis.Redis(
connection_pool=redis.ConnectionPool(
connection_class=CustomRedisConnection,
db=4,
custom_arg="foo",
)
)
queue = Queue(connection=custom_conn)
scheduler = RQScheduler([queue], connection=custom_conn)
scheduler_connection = scheduler.connection.connection_pool.get_connection('info')
self.assertEqual(scheduler_connection.__class__, CustomRedisConnection)
self.assertEqual(scheduler_connection.get_custom_arg(), "foo")
def test_no_custom_connection_pool(self):
"""Connection pool customizing must not interfere if we're using a standard
connection (non-pooled)"""
standard_conn = redis.Redis(db=5)
queue = Queue(connection=standard_conn)
scheduler = RQScheduler([queue], connection=standard_conn)
scheduler_connection = scheduler.connection.connection_pool.get_connection('info')
self.assertEqual(scheduler_connection.__class__, redis.Connection)

Loading…
Cancel
Save