From a228b4838c2ac84c7fca31a1908800edee4b24ed Mon Sep 17 00:00:00 2001 From: Cyril Chapellier Date: Fri, 5 May 2023 09:25:20 +0200 Subject: [PATCH] [Hotfix] Fix SSL connection for scheduler (#1894) * fix: ssl * fix: reinstate a test for parse_connection --- rq/connections.py | 19 +++---------------- rq/scheduler.py | 4 ++-- rq/worker_pool.py | 13 ++++++++----- tests/test_connection.py | 16 ++++++++++------ tests/test_worker_pool.py | 5 ++++- 5 files changed, 27 insertions(+), 30 deletions(-) diff --git a/rq/connections.py b/rq/connections.py index dfb590a..36d771d 100644 --- a/rq/connections.py +++ b/rq/connections.py @@ -118,23 +118,10 @@ def resolve_connection(connection: Optional['Redis'] = None) -> 'Redis': def parse_connection(connection: Redis) -> Tuple[Type[Redis], Type[RedisConnection], dict]: - connection_kwargs = connection.connection_pool.connection_kwargs.copy() - # Redis does not accept parser_class argument which is sometimes present - # on connection_pool kwargs, for example when hiredis is used - connection_kwargs.pop('parser_class', None) + connection_pool_kwargs = connection.connection_pool.connection_kwargs.copy() connection_pool_class = connection.connection_pool.connection_class - if issubclass(connection_pool_class, SSLConnection): - connection_kwargs['ssl'] = True - if issubclass(connection_pool_class, UnixDomainSocketConnection): - # The connection keyword arguments are obtained from - # `UnixDomainSocketConnection`, which expects `path`, but passed to - # `redis.client.Redis`, which expects `unix_socket_path`, renaming - # the key is necessary. - # `path` is not left in the dictionary as that keyword argument is - # not expected by `redis.client.Redis` and would raise an exception. - connection_kwargs['unix_socket_path'] = connection_kwargs.pop('path') - - return connection.__class__, connection_pool_class, connection_kwargs + + return connection.__class__, connection_pool_class, connection_pool_kwargs _connection_stack = LocalStack() diff --git a/rq/scheduler.py b/rq/scheduler.py index 069181d..a64b400 100644 --- a/rq/scheduler.py +++ b/rq/scheduler.py @@ -50,7 +50,7 @@ class RQScheduler: self._acquired_locks: Set[str] = set() self._scheduled_job_registries: List[ScheduledJobRegistry] = [] self.lock_acquisition_time = None - self._connection_class, self._pool_class, self._connection_kwargs = parse_connection(connection) + self._connection_class, self._pool_class, self._pool_kwargs = parse_connection(connection) self.serializer = resolve_serializer(serializer) self._connection = None @@ -71,7 +71,7 @@ class RQScheduler: if self._connection: return self._connection self._connection = self._connection_class( - connection_pool=ConnectionPool(connection_class=self._pool_class, **self._connection_kwargs) + connection_pool=ConnectionPool(connection_class=self._pool_class, **self._pool_kwargs) ) return self._connection diff --git a/rq/worker_pool.py b/rq/worker_pool.py index 4bd21bb..005c3b9 100644 --- a/rq/worker_pool.py +++ b/rq/worker_pool.py @@ -11,7 +11,7 @@ from typing import Dict, List, NamedTuple, Optional, Set, Type, Union from uuid import uuid4 from redis import Redis -from redis import SSLConnection, UnixDomainSocketConnection +from redis import ConnectionPool from rq.serializers import DefaultSerializer from rq.timeouts import HorseMonitorTimeoutException, UnixSignalDeathPenalty @@ -65,7 +65,7 @@ class WorkerPool: # A dictionary of WorkerData keyed by worker name self.worker_dict: Dict[str, WorkerData] = {} - self._connection_class, _, self._connection_kwargs = parse_connection(connection) + self._connection_class, self._pool_class, self._pool_kwargs = parse_connection(connection) @property def queues(self) -> List[Queue]: @@ -158,7 +158,7 @@ class WorkerPool: name = uuid4().hex process = Process( target=run_worker, - args=(name, self._queue_names, self._connection_class, self._connection_kwargs), + args=(name, self._queue_names, self._connection_class, self._pool_class, self._pool_kwargs), kwargs={ '_sleep': _sleep, 'burst': burst, @@ -234,7 +234,8 @@ def run_worker( worker_name: str, queue_names: List[str], connection_class, - connection_kwargs: dict, + connection_pool_class, + connection_pool_kwargs: dict, worker_class: Type[BaseWorker] = Worker, serializer: Type[DefaultSerializer] = DefaultSerializer, job_class: Type[Job] = Job, @@ -242,7 +243,9 @@ def run_worker( logging_level: str = "INFO", _sleep: int = 0, ): - connection = connection_class(**connection_kwargs) + connection = connection_class( + connection_pool=ConnectionPool(connection_class=connection_pool_class, **connection_pool_kwargs) + ) queues = [Queue(name, connection=connection) for name in queue_names] worker = worker_class(queues, name=worker_name, connection=connection, serializer=serializer, job_class=job_class) worker.log.info("Starting worker started with PID %s", os.getpid()) diff --git a/tests/test_connection.py b/tests/test_connection.py index 0b64d2b..5ac76d6 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,4 +1,4 @@ -from redis import ConnectionPool, Redis, UnixDomainSocketConnection +from redis import ConnectionPool, Redis, SSLConnection, UnixDomainSocketConnection from rq import Connection, Queue from rq.connections import parse_connection @@ -38,10 +38,14 @@ class TestConnectionInheritance(RQTestCase): self.assertEqual(q2.connection, job2.connection) def test_parse_connection(self): - """Test parsing `ssl` and UnixDomainSocketConnection""" - _, _, kwargs = parse_connection(Redis(ssl=True)) - self.assertTrue(kwargs['ssl']) + """Test parsing the connection""" + conn_class, pool_class, pool_kwargs = parse_connection(Redis(ssl=True)) + self.assertEqual(conn_class, Redis) + self.assertEqual(pool_class, SSLConnection) + path = '/tmp/redis.sock' pool = ConnectionPool(connection_class=UnixDomainSocketConnection, path=path) - _, _, kwargs = parse_connection(Redis(connection_pool=pool)) - self.assertTrue(kwargs['unix_socket_path'], path) + conn_class, pool_class, pool_kwargs = parse_connection(Redis(connection_pool=pool)) + self.assertEqual(conn_class, Redis) + self.assertEqual(pool_class, UnixDomainSocketConnection) + self.assertEqual(pool_kwargs, {"path": path}) diff --git a/tests/test_worker_pool.py b/tests/test_worker_pool.py index c836309..ab2e677 100644 --- a/tests/test_worker_pool.py +++ b/tests/test_worker_pool.py @@ -8,6 +8,7 @@ from rq.job import JobStatus from tests import TestCase from tests.fixtures import CustomJob, _send_shutdown_command, long_running_job, say_hello +from rq.connections import parse_connection from rq.queue import Queue from rq.serializers import JSONSerializer from rq.worker import SimpleWorker @@ -108,8 +109,10 @@ class TestWorkerPool(TestCase): """Ensure run_worker() properly spawns a Worker""" queue = Queue('foo', connection=self.connection) queue.enqueue(say_hello) + + connection_class, pool_class, pool_kwargs = parse_connection(self.connection) run_worker( - 'test-worker', ['foo'], self.connection.__class__, self.connection.connection_pool.connection_kwargs.copy() + 'test-worker', ['foo'], connection_class, pool_class, pool_kwargs ) # Worker should have processed the job self.assertEqual(len(queue), 0)