[Hotfix] Fix SSL connection for scheduler (#1894)

* fix: ssl

* fix: reinstate a test for parse_connection
main
Cyril Chapellier 2 years ago committed by GitHub
parent 07fef85dd2
commit a228b4838c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -118,23 +118,10 @@ def resolve_connection(connection: Optional['Redis'] = None) -> 'Redis':
def parse_connection(connection: Redis) -> Tuple[Type[Redis], Type[RedisConnection], dict]:
connection_kwargs = connection.connection_pool.connection_kwargs.copy()
# Redis does not accept parser_class argument which is sometimes present
# on connection_pool kwargs, for example when hiredis is used
connection_kwargs.pop('parser_class', None)
connection_pool_kwargs = connection.connection_pool.connection_kwargs.copy()
connection_pool_class = connection.connection_pool.connection_class
if issubclass(connection_pool_class, SSLConnection):
connection_kwargs['ssl'] = True
if issubclass(connection_pool_class, UnixDomainSocketConnection):
# The connection keyword arguments are obtained from
# `UnixDomainSocketConnection`, which expects `path`, but passed to
# `redis.client.Redis`, which expects `unix_socket_path`, renaming
# the key is necessary.
# `path` is not left in the dictionary as that keyword argument is
# not expected by `redis.client.Redis` and would raise an exception.
connection_kwargs['unix_socket_path'] = connection_kwargs.pop('path')
return connection.__class__, connection_pool_class, connection_kwargs
return connection.__class__, connection_pool_class, connection_pool_kwargs
_connection_stack = LocalStack()

@ -50,7 +50,7 @@ class RQScheduler:
self._acquired_locks: Set[str] = set()
self._scheduled_job_registries: List[ScheduledJobRegistry] = []
self.lock_acquisition_time = None
self._connection_class, self._pool_class, self._connection_kwargs = parse_connection(connection)
self._connection_class, self._pool_class, self._pool_kwargs = parse_connection(connection)
self.serializer = resolve_serializer(serializer)
self._connection = None
@ -71,7 +71,7 @@ class RQScheduler:
if self._connection:
return self._connection
self._connection = self._connection_class(
connection_pool=ConnectionPool(connection_class=self._pool_class, **self._connection_kwargs)
connection_pool=ConnectionPool(connection_class=self._pool_class, **self._pool_kwargs)
)
return self._connection

@ -11,7 +11,7 @@ from typing import Dict, List, NamedTuple, Optional, Set, Type, Union
from uuid import uuid4
from redis import Redis
from redis import SSLConnection, UnixDomainSocketConnection
from redis import ConnectionPool
from rq.serializers import DefaultSerializer
from rq.timeouts import HorseMonitorTimeoutException, UnixSignalDeathPenalty
@ -65,7 +65,7 @@ class WorkerPool:
# A dictionary of WorkerData keyed by worker name
self.worker_dict: Dict[str, WorkerData] = {}
self._connection_class, _, self._connection_kwargs = parse_connection(connection)
self._connection_class, self._pool_class, self._pool_kwargs = parse_connection(connection)
@property
def queues(self) -> List[Queue]:
@ -158,7 +158,7 @@ class WorkerPool:
name = uuid4().hex
process = Process(
target=run_worker,
args=(name, self._queue_names, self._connection_class, self._connection_kwargs),
args=(name, self._queue_names, self._connection_class, self._pool_class, self._pool_kwargs),
kwargs={
'_sleep': _sleep,
'burst': burst,
@ -234,7 +234,8 @@ def run_worker(
worker_name: str,
queue_names: List[str],
connection_class,
connection_kwargs: dict,
connection_pool_class,
connection_pool_kwargs: dict,
worker_class: Type[BaseWorker] = Worker,
serializer: Type[DefaultSerializer] = DefaultSerializer,
job_class: Type[Job] = Job,
@ -242,7 +243,9 @@ def run_worker(
logging_level: str = "INFO",
_sleep: int = 0,
):
connection = connection_class(**connection_kwargs)
connection = connection_class(
connection_pool=ConnectionPool(connection_class=connection_pool_class, **connection_pool_kwargs)
)
queues = [Queue(name, connection=connection) for name in queue_names]
worker = worker_class(queues, name=worker_name, connection=connection, serializer=serializer, job_class=job_class)
worker.log.info("Starting worker started with PID %s", os.getpid())

@ -1,4 +1,4 @@
from redis import ConnectionPool, Redis, UnixDomainSocketConnection
from redis import ConnectionPool, Redis, SSLConnection, UnixDomainSocketConnection
from rq import Connection, Queue
from rq.connections import parse_connection
@ -38,10 +38,14 @@ class TestConnectionInheritance(RQTestCase):
self.assertEqual(q2.connection, job2.connection)
def test_parse_connection(self):
"""Test parsing `ssl` and UnixDomainSocketConnection"""
_, _, kwargs = parse_connection(Redis(ssl=True))
self.assertTrue(kwargs['ssl'])
"""Test parsing the connection"""
conn_class, pool_class, pool_kwargs = parse_connection(Redis(ssl=True))
self.assertEqual(conn_class, Redis)
self.assertEqual(pool_class, SSLConnection)
path = '/tmp/redis.sock'
pool = ConnectionPool(connection_class=UnixDomainSocketConnection, path=path)
_, _, kwargs = parse_connection(Redis(connection_pool=pool))
self.assertTrue(kwargs['unix_socket_path'], path)
conn_class, pool_class, pool_kwargs = parse_connection(Redis(connection_pool=pool))
self.assertEqual(conn_class, Redis)
self.assertEqual(pool_class, UnixDomainSocketConnection)
self.assertEqual(pool_kwargs, {"path": path})

@ -8,6 +8,7 @@ from rq.job import JobStatus
from tests import TestCase
from tests.fixtures import CustomJob, _send_shutdown_command, long_running_job, say_hello
from rq.connections import parse_connection
from rq.queue import Queue
from rq.serializers import JSONSerializer
from rq.worker import SimpleWorker
@ -108,8 +109,10 @@ class TestWorkerPool(TestCase):
"""Ensure run_worker() properly spawns a Worker"""
queue = Queue('foo', connection=self.connection)
queue.enqueue(say_hello)
connection_class, pool_class, pool_kwargs = parse_connection(self.connection)
run_worker(
'test-worker', ['foo'], self.connection.__class__, self.connection.connection_pool.connection_kwargs.copy()
'test-worker', ['foo'], connection_class, pool_class, pool_kwargs
)
# Worker should have processed the job
self.assertEqual(len(queue), 0)

Loading…
Cancel
Save