Added parse_connection function (#1884)

* Added parse_connection function

* feat: allow custom connection pool class (#1885)

* Added test for SSL

---------

Co-authored-by: Cyril Chapellier <tchapi@users.noreply.github.com>
main
Selwin Ong 2 years ago committed by GitHub
parent 95983cfcac
commit 77e926c424
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,8 +1,8 @@
import warnings
from contextlib import contextmanager
from typing import Optional
from typing import Any, Optional, Tuple, Type
from redis import Redis
from redis import Connection as RedisConnection, Redis, SSLConnection, UnixDomainSocketConnection
from .local import LocalStack
@ -42,10 +42,9 @@ def Connection(connection: Optional['Redis'] = None): # noqa
yield
finally:
popped = pop_connection()
assert popped == connection, (
'Unexpected Redis connection was popped off the stack. '
'Check your Redis connection setup.'
)
assert (
popped == connection
), 'Unexpected Redis connection was popped off the stack. Check your Redis connection setup.'
def push_connection(redis: 'Redis'):
@ -118,8 +117,27 @@ def resolve_connection(connection: Optional['Redis'] = None) -> 'Redis':
return connection
def parse_connection(connection: Redis) -> Tuple[Type[Redis], Type[RedisConnection], dict]:
connection_kwargs = connection.connection_pool.connection_kwargs.copy()
# Redis does not accept parser_class argument which is sometimes present
# on connection_pool kwargs, for example when hiredis is used
connection_kwargs.pop('parser_class', None)
connection_pool_class = connection.connection_pool.connection_class
if issubclass(connection_pool_class, SSLConnection):
connection_kwargs['ssl'] = True
if issubclass(connection_pool_class, UnixDomainSocketConnection):
# The connection keyword arguments are obtained from
# `UnixDomainSocketConnection`, which expects `path`, but passed to
# `redis.client.Redis`, which expects `unix_socket_path`, renaming
# the key is necessary.
# `path` is not left in the dictionary as that keyword argument is
# not expected by `redis.client.Redis` and would raise an exception.
connection_kwargs['unix_socket_path'] = connection_kwargs.pop('path')
return connection.__class__, connection_pool_class, connection_kwargs
_connection_stack = LocalStack()
__all__ = ['Connection', 'get_current_connection', 'push_connection', 'pop_connection']

@ -7,8 +7,9 @@ from datetime import datetime
from enum import Enum
from multiprocessing import Process
from redis import SSLConnection, UnixDomainSocketConnection
from redis import ConnectionPool, Redis, SSLConnection, UnixDomainSocketConnection
from .connections import parse_connection
from .defaults import DEFAULT_LOGGING_DATE_FORMAT, DEFAULT_LOGGING_FORMAT, DEFAULT_SCHEDULER_FALLBACK_PERIOD
from .job import Job
from .logutils import setup_loghandlers
@ -37,7 +38,7 @@ class RQScheduler:
def __init__(
self,
queues,
connection,
connection: Redis,
interval=1,
logging_level=logging.INFO,
date_format=DEFAULT_LOGGING_DATE_FORMAT,
@ -48,24 +49,11 @@ class RQScheduler:
self._acquired_locks = set()
self._scheduled_job_registries = []
self.lock_acquisition_time = None
# Copy the connection kwargs before mutating them in order to not change the arguments
# used by the current connection pool to create new connections
self._connection_kwargs = connection.connection_pool.connection_kwargs.copy()
# Redis does not accept parser_class argument which is sometimes present
# on connection_pool kwargs, for example when hiredis is used
self._connection_kwargs.pop('parser_class', None)
self._connection_class = connection.__class__ # client
connection_class = connection.connection_pool.connection_class
if issubclass(connection_class, SSLConnection):
self._connection_kwargs['ssl'] = True
if issubclass(connection_class, UnixDomainSocketConnection):
# The connection keyword arguments are obtained from
# `UnixDomainSocketConnection`, which expects `path`, but passed to
# `redis.client.Redis`, which expects `unix_socket_path`, renaming
# the key is necessary.
# `path` is not left in the dictionary as that keyword argument is
# not expected by `redis.client.Redis` and would raise an exception.
self._connection_kwargs['unix_socket_path'] = self._connection_kwargs.pop('path')
(
self._connection_class,
self._connection_pool_class,
self._connection_kwargs,
) = parse_connection(connection)
self.serializer = resolve_serializer(serializer)
self._connection = None
@ -85,7 +73,12 @@ class RQScheduler:
def connection(self):
if self._connection:
return self._connection
self._connection = self._connection_class(**self._connection_kwargs)
self._connection = self._connection_class(
connection_pool=ConnectionPool(
connection_class=self._connection_pool_class,
**self._connection_kwargs
)
)
return self._connection
@property

@ -1,6 +1,7 @@
from redis import Redis
from redis import ConnectionPool, Redis, UnixDomainSocketConnection
from rq import Connection, Queue
from rq.connections import parse_connection
from tests import RQTestCase, find_empty_redis_database
from tests.fixtures import do_nothing
@ -35,3 +36,12 @@ class TestConnectionInheritance(RQTestCase):
job2 = q2.enqueue(do_nothing)
self.assertEqual(q1.connection, job1.connection)
self.assertEqual(q2.connection, job2.connection)
def test_parse_connection(self):
"""Test parsing `ssl` and UnixDomainSocketConnection"""
_, _, kwargs = parse_connection(Redis(ssl=True))
self.assertTrue(kwargs['ssl'])
path = '/tmp/redis.sock'
pool = ConnectionPool(connection_class=UnixDomainSocketConnection, path=path)
_, _, kwargs = parse_connection(Redis(connection_pool=pool))
self.assertTrue(kwargs['unix_socket_path'], path)

@ -1,4 +1,6 @@
import os
import redis
from datetime import datetime, timedelta, timezone
from multiprocessing import Process
from unittest import mock
@ -16,6 +18,17 @@ from tests import RQTestCase, find_empty_redis_database, ssl_test
from .fixtures import kill_worker, say_hello
class CustomRedisConnection(redis.Connection):
"""Custom redis connection with a custom arg, used in test_custom_connection_pool"""
def __init__(self, *args, custom_arg=None, **kwargs):
self.custom_arg = custom_arg
super().__init__(*args, **kwargs)
def get_custom_arg(self):
return self.custom_arg
class TestScheduledJobRegistry(RQTestCase):
def test_get_jobs_to_enqueue(self):
@ -431,3 +444,34 @@ class TestQueue(RQTestCase):
job = queue.enqueue_in(timedelta(seconds=30), say_hello, retry=Retry(3, [2]))
self.assertEqual(job.retries_left, 3)
self.assertEqual(job.retry_intervals, [2])
def test_custom_connection_pool(self):
"""Connection pool customizing. Ensure that we can properly set a
custom connection pool class and pass extra arguments"""
custom_conn = redis.Redis(
connection_pool=redis.ConnectionPool(
connection_class=CustomRedisConnection,
db=4,
custom_arg="foo",
)
)
queue = Queue(connection=custom_conn)
scheduler = RQScheduler([queue], connection=custom_conn)
scheduler_connection = scheduler.connection.connection_pool.get_connection('info')
self.assertEqual(scheduler_connection.__class__, CustomRedisConnection)
self.assertEqual(scheduler_connection.get_custom_arg(), "foo")
def test_no_custom_connection_pool(self):
"""Connection pool customizing must not interfere if we're using a standard
connection (non-pooled)"""
standard_conn = redis.Redis(db=5)
queue = Queue(connection=standard_conn)
scheduler = RQScheduler([queue], connection=standard_conn)
scheduler_connection = scheduler.connection.connection_pool.get_connection('info')
self.assertEqual(scheduler_connection.__class__, redis.Connection)

Loading…
Cancel
Save