[Hotfix] Fix SSL connection for scheduler (#1894)

* fix: ssl

* fix: reinstate a test for parse_connection
main
Cyril Chapellier 2 years ago committed by GitHub
parent 07fef85dd2
commit a228b4838c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -118,23 +118,10 @@ def resolve_connection(connection: Optional['Redis'] = None) -> 'Redis':
def parse_connection(connection: Redis) -> Tuple[Type[Redis], Type[RedisConnection], dict]: def parse_connection(connection: Redis) -> Tuple[Type[Redis], Type[RedisConnection], dict]:
connection_kwargs = connection.connection_pool.connection_kwargs.copy() connection_pool_kwargs = connection.connection_pool.connection_kwargs.copy()
# Redis does not accept parser_class argument which is sometimes present
# on connection_pool kwargs, for example when hiredis is used
connection_kwargs.pop('parser_class', None)
connection_pool_class = connection.connection_pool.connection_class connection_pool_class = connection.connection_pool.connection_class
if issubclass(connection_pool_class, SSLConnection):
connection_kwargs['ssl'] = True return connection.__class__, connection_pool_class, connection_pool_kwargs
if issubclass(connection_pool_class, UnixDomainSocketConnection):
# The connection keyword arguments are obtained from
# `UnixDomainSocketConnection`, which expects `path`, but passed to
# `redis.client.Redis`, which expects `unix_socket_path`, renaming
# the key is necessary.
# `path` is not left in the dictionary as that keyword argument is
# not expected by `redis.client.Redis` and would raise an exception.
connection_kwargs['unix_socket_path'] = connection_kwargs.pop('path')
return connection.__class__, connection_pool_class, connection_kwargs
_connection_stack = LocalStack() _connection_stack = LocalStack()

@ -50,7 +50,7 @@ class RQScheduler:
self._acquired_locks: Set[str] = set() self._acquired_locks: Set[str] = set()
self._scheduled_job_registries: List[ScheduledJobRegistry] = [] self._scheduled_job_registries: List[ScheduledJobRegistry] = []
self.lock_acquisition_time = None self.lock_acquisition_time = None
self._connection_class, self._pool_class, self._connection_kwargs = parse_connection(connection) self._connection_class, self._pool_class, self._pool_kwargs = parse_connection(connection)
self.serializer = resolve_serializer(serializer) self.serializer = resolve_serializer(serializer)
self._connection = None self._connection = None
@ -71,7 +71,7 @@ class RQScheduler:
if self._connection: if self._connection:
return self._connection return self._connection
self._connection = self._connection_class( self._connection = self._connection_class(
connection_pool=ConnectionPool(connection_class=self._pool_class, **self._connection_kwargs) connection_pool=ConnectionPool(connection_class=self._pool_class, **self._pool_kwargs)
) )
return self._connection return self._connection

@ -11,7 +11,7 @@ from typing import Dict, List, NamedTuple, Optional, Set, Type, Union
from uuid import uuid4 from uuid import uuid4
from redis import Redis from redis import Redis
from redis import SSLConnection, UnixDomainSocketConnection from redis import ConnectionPool
from rq.serializers import DefaultSerializer from rq.serializers import DefaultSerializer
from rq.timeouts import HorseMonitorTimeoutException, UnixSignalDeathPenalty from rq.timeouts import HorseMonitorTimeoutException, UnixSignalDeathPenalty
@ -65,7 +65,7 @@ class WorkerPool:
# A dictionary of WorkerData keyed by worker name # A dictionary of WorkerData keyed by worker name
self.worker_dict: Dict[str, WorkerData] = {} self.worker_dict: Dict[str, WorkerData] = {}
self._connection_class, _, self._connection_kwargs = parse_connection(connection) self._connection_class, self._pool_class, self._pool_kwargs = parse_connection(connection)
@property @property
def queues(self) -> List[Queue]: def queues(self) -> List[Queue]:
@ -158,7 +158,7 @@ class WorkerPool:
name = uuid4().hex name = uuid4().hex
process = Process( process = Process(
target=run_worker, target=run_worker,
args=(name, self._queue_names, self._connection_class, self._connection_kwargs), args=(name, self._queue_names, self._connection_class, self._pool_class, self._pool_kwargs),
kwargs={ kwargs={
'_sleep': _sleep, '_sleep': _sleep,
'burst': burst, 'burst': burst,
@ -234,7 +234,8 @@ def run_worker(
worker_name: str, worker_name: str,
queue_names: List[str], queue_names: List[str],
connection_class, connection_class,
connection_kwargs: dict, connection_pool_class,
connection_pool_kwargs: dict,
worker_class: Type[BaseWorker] = Worker, worker_class: Type[BaseWorker] = Worker,
serializer: Type[DefaultSerializer] = DefaultSerializer, serializer: Type[DefaultSerializer] = DefaultSerializer,
job_class: Type[Job] = Job, job_class: Type[Job] = Job,
@ -242,7 +243,9 @@ def run_worker(
logging_level: str = "INFO", logging_level: str = "INFO",
_sleep: int = 0, _sleep: int = 0,
): ):
connection = connection_class(**connection_kwargs) connection = connection_class(
connection_pool=ConnectionPool(connection_class=connection_pool_class, **connection_pool_kwargs)
)
queues = [Queue(name, connection=connection) for name in queue_names] queues = [Queue(name, connection=connection) for name in queue_names]
worker = worker_class(queues, name=worker_name, connection=connection, serializer=serializer, job_class=job_class) worker = worker_class(queues, name=worker_name, connection=connection, serializer=serializer, job_class=job_class)
worker.log.info("Starting worker started with PID %s", os.getpid()) worker.log.info("Starting worker started with PID %s", os.getpid())

@ -1,4 +1,4 @@
from redis import ConnectionPool, Redis, UnixDomainSocketConnection from redis import ConnectionPool, Redis, SSLConnection, UnixDomainSocketConnection
from rq import Connection, Queue from rq import Connection, Queue
from rq.connections import parse_connection from rq.connections import parse_connection
@ -38,10 +38,14 @@ class TestConnectionInheritance(RQTestCase):
self.assertEqual(q2.connection, job2.connection) self.assertEqual(q2.connection, job2.connection)
def test_parse_connection(self): def test_parse_connection(self):
"""Test parsing `ssl` and UnixDomainSocketConnection""" """Test parsing the connection"""
_, _, kwargs = parse_connection(Redis(ssl=True)) conn_class, pool_class, pool_kwargs = parse_connection(Redis(ssl=True))
self.assertTrue(kwargs['ssl']) self.assertEqual(conn_class, Redis)
self.assertEqual(pool_class, SSLConnection)
path = '/tmp/redis.sock' path = '/tmp/redis.sock'
pool = ConnectionPool(connection_class=UnixDomainSocketConnection, path=path) pool = ConnectionPool(connection_class=UnixDomainSocketConnection, path=path)
_, _, kwargs = parse_connection(Redis(connection_pool=pool)) conn_class, pool_class, pool_kwargs = parse_connection(Redis(connection_pool=pool))
self.assertTrue(kwargs['unix_socket_path'], path) self.assertEqual(conn_class, Redis)
self.assertEqual(pool_class, UnixDomainSocketConnection)
self.assertEqual(pool_kwargs, {"path": path})

@ -8,6 +8,7 @@ from rq.job import JobStatus
from tests import TestCase from tests import TestCase
from tests.fixtures import CustomJob, _send_shutdown_command, long_running_job, say_hello from tests.fixtures import CustomJob, _send_shutdown_command, long_running_job, say_hello
from rq.connections import parse_connection
from rq.queue import Queue from rq.queue import Queue
from rq.serializers import JSONSerializer from rq.serializers import JSONSerializer
from rq.worker import SimpleWorker from rq.worker import SimpleWorker
@ -108,8 +109,10 @@ class TestWorkerPool(TestCase):
"""Ensure run_worker() properly spawns a Worker""" """Ensure run_worker() properly spawns a Worker"""
queue = Queue('foo', connection=self.connection) queue = Queue('foo', connection=self.connection)
queue.enqueue(say_hello) queue.enqueue(say_hello)
connection_class, pool_class, pool_kwargs = parse_connection(self.connection)
run_worker( run_worker(
'test-worker', ['foo'], self.connection.__class__, self.connection.connection_pool.connection_kwargs.copy() 'test-worker', ['foo'], connection_class, pool_class, pool_kwargs
) )
# Worker should have processed the job # Worker should have processed the job
self.assertEqual(len(queue), 0) self.assertEqual(len(queue), 0)

Loading…
Cancel
Save