mirror of https://github.com/peter4431/rq.git
Worker pool (#1874)
* First stab at implementating worker pool * Use process.is_alive() to check whether a process is still live * Handle shutdown signal * Check worker loop done * First working version of `WorkerPool`. * Added test for check_workers() * Added test for pool.start() * Better shutdown process * Comment out test_start() to see if it fixes CI * Make tests pass * Make CI pass * Comment out some tests * Comment out more tests * Re-enable a test * Re-enable another test * Uncomment check_workers test * Added run_worker test * Minor modification to dead worker detection * More test cases * Better process name for workers * Added back pool.stop_workers() when signal is received * Cleaned up cli.py * WIP on worker-pool command * Fix test * Test that worker pool ignores consecutive shutdown signals * Added test for worker-pool CLI command. * Added timeout to CI jobs * Fix worker pool test * Comment out test_scheduler.py * Fixed worker-pool in burst mode * Increase test coverage * Exclude tests directory from coverage.py * Improve test coverage * Renamed `Pool(num_workers=2) to `Pool(size=2)` * Revert "Renamed `Pool(num_workers=2) to `Pool(size=2)`" This reverts commit a1306f89ad0d8686c6bde447bff75e2f71f0733b. * Renamed Pool to WorkerPool * Added a new TestCase that doesn't use LocalStack * Added job_class, worker_class and serializer arguments to WorkerPool * Use parse_connection() in WorkerPool.__init__ * Added CLI arguments for worker-pool * Minor WorkerPool and test fixes * Fixed failing CLI test * Document WorkerPoolmain
parent
8a9daecaf2
commit
64cb1a27b9
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,250 @@
|
|||||||
|
import contextlib
|
||||||
|
import errno
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import time
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from multiprocessing import Process
|
||||||
|
from typing import Dict, List, NamedTuple, Optional, Set, Type, Union
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from redis import Redis
|
||||||
|
from redis import SSLConnection, UnixDomainSocketConnection
|
||||||
|
from rq.serializers import DefaultSerializer
|
||||||
|
|
||||||
|
from rq.timeouts import HorseMonitorTimeoutException, UnixSignalDeathPenalty
|
||||||
|
|
||||||
|
from .connections import parse_connection
|
||||||
|
from .defaults import DEFAULT_LOGGING_DATE_FORMAT, DEFAULT_LOGGING_FORMAT
|
||||||
|
from .job import Job
|
||||||
|
from .logutils import setup_loghandlers
|
||||||
|
from .queue import Queue
|
||||||
|
from .utils import parse_names
|
||||||
|
from .worker import BaseWorker, Worker
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerData(NamedTuple):
|
||||||
|
name: str
|
||||||
|
pid: int
|
||||||
|
process: Process
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerPool:
|
||||||
|
class Status(Enum):
|
||||||
|
IDLE = 1
|
||||||
|
STARTED = 2
|
||||||
|
STOPPED = 3
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
queues: List[Union[str, Queue]],
|
||||||
|
connection: Redis,
|
||||||
|
num_workers: int = 1,
|
||||||
|
worker_class: Type[BaseWorker] = Worker,
|
||||||
|
serializer: Type[DefaultSerializer] = DefaultSerializer,
|
||||||
|
job_class: Type[Job] = Job,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.num_workers: int = num_workers
|
||||||
|
self._workers: List[Worker] = []
|
||||||
|
setup_loghandlers('INFO', DEFAULT_LOGGING_DATE_FORMAT, DEFAULT_LOGGING_FORMAT, name=__name__)
|
||||||
|
self.log: logging.Logger = logging.getLogger(__name__)
|
||||||
|
# self.log: logging.Logger = logger
|
||||||
|
self._queue_names: List[str] = parse_names(queues)
|
||||||
|
self.connection = connection
|
||||||
|
self.name: str = uuid4().hex
|
||||||
|
self._burst: bool = True
|
||||||
|
self._sleep: int = 0
|
||||||
|
self.status: self.Status = self.Status.IDLE # type: ignore
|
||||||
|
self.worker_class: Type[BaseWorker] = worker_class
|
||||||
|
self.serializer: Type[DefaultSerializer] = serializer
|
||||||
|
self.job_class: Type[Job] = job_class
|
||||||
|
|
||||||
|
# A dictionary of WorkerData keyed by worker name
|
||||||
|
self.worker_dict: Dict[str, WorkerData] = {}
|
||||||
|
self._connection_class, _, self._connection_kwargs = parse_connection(connection)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def queues(self) -> List[Queue]:
|
||||||
|
"""Returns a list of Queue objects"""
|
||||||
|
return [Queue(name, connection=self.connection) for name in self._queue_names]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def number_of_active_workers(self) -> int:
|
||||||
|
"""Returns a list of Queue objects"""
|
||||||
|
return len(self.worker_dict)
|
||||||
|
|
||||||
|
def _install_signal_handlers(self):
|
||||||
|
"""Installs signal handlers for handling SIGINT and SIGTERM
|
||||||
|
gracefully.
|
||||||
|
"""
|
||||||
|
signal.signal(signal.SIGINT, self.request_stop)
|
||||||
|
signal.signal(signal.SIGTERM, self.request_stop)
|
||||||
|
|
||||||
|
def request_stop(self, signum=None, frame=None):
|
||||||
|
"""Toggle self._stop_requested that's checked on every loop"""
|
||||||
|
self.log.info('Received SIGINT/SIGTERM, shutting down...')
|
||||||
|
self.status = self.Status.STOPPED
|
||||||
|
self.stop_workers()
|
||||||
|
|
||||||
|
def all_workers_have_stopped(self) -> bool:
|
||||||
|
"""Returns True if all workers have stopped."""
|
||||||
|
self.reap_workers()
|
||||||
|
# `bool(self.worker_dict)` sometimes returns True even if the dict is empty
|
||||||
|
return self.number_of_active_workers == 0
|
||||||
|
|
||||||
|
def reap_workers(self):
|
||||||
|
"""Removes dead workers from worker_dict"""
|
||||||
|
self.log.debug('Reaping dead workers')
|
||||||
|
worker_datas = list(self.worker_dict.values())
|
||||||
|
|
||||||
|
for data in worker_datas:
|
||||||
|
data.process.join(0.1)
|
||||||
|
if data.process.is_alive():
|
||||||
|
self.log.debug('Worker %s with pid %d is alive', data.name, data.pid)
|
||||||
|
else:
|
||||||
|
self.handle_dead_worker(data)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# I'm still not sure why this is sometimes needed, temporarily commenting
|
||||||
|
# this out until I can figure it out.
|
||||||
|
# with contextlib.suppress(HorseMonitorTimeoutException):
|
||||||
|
# with UnixSignalDeathPenalty(1, HorseMonitorTimeoutException):
|
||||||
|
# try:
|
||||||
|
# # If wait4 returns, the process is dead
|
||||||
|
# os.wait4(data.process.pid, 0) # type: ignore
|
||||||
|
# self.handle_dead_worker(data)
|
||||||
|
# except ChildProcessError:
|
||||||
|
# # Process is dead
|
||||||
|
# self.handle_dead_worker(data)
|
||||||
|
# continue
|
||||||
|
|
||||||
|
def handle_dead_worker(self, worker_data: WorkerData):
|
||||||
|
"""
|
||||||
|
Handle a dead worker
|
||||||
|
"""
|
||||||
|
self.log.info('Worker %s with pid %d is dead', worker_data.name, worker_data.pid)
|
||||||
|
with contextlib.suppress(KeyError):
|
||||||
|
self.worker_dict.pop(worker_data.name)
|
||||||
|
|
||||||
|
def check_workers(self, respawn: bool = True) -> None:
|
||||||
|
"""
|
||||||
|
Check whether workers are still alive
|
||||||
|
"""
|
||||||
|
self.log.debug('Checking worker processes')
|
||||||
|
self.reap_workers()
|
||||||
|
# If we have less number of workers than num_workers,
|
||||||
|
# respawn the difference
|
||||||
|
if respawn and self.status != self.Status.STOPPED:
|
||||||
|
delta = self.num_workers - len(self.worker_dict)
|
||||||
|
if delta:
|
||||||
|
for i in range(delta):
|
||||||
|
self.start_worker(burst=self._burst, _sleep=self._sleep)
|
||||||
|
|
||||||
|
def start_worker(
|
||||||
|
self,
|
||||||
|
count: Optional[int] = None,
|
||||||
|
burst: bool = True,
|
||||||
|
_sleep: float = 0,
|
||||||
|
logging_level: str = "INFO",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Starts a worker and adds the data to worker_datas.
|
||||||
|
* sleep: waits for X seconds before creating worker, for testing purposes
|
||||||
|
"""
|
||||||
|
name = uuid4().hex
|
||||||
|
process = Process(
|
||||||
|
target=run_worker,
|
||||||
|
args=(name, self._queue_names, self._connection_class, self._connection_kwargs),
|
||||||
|
kwargs={
|
||||||
|
'_sleep': _sleep,
|
||||||
|
'burst': burst,
|
||||||
|
'logging_level': logging_level,
|
||||||
|
'worker_class': self.worker_class,
|
||||||
|
'job_class': self.job_class,
|
||||||
|
'serializer': self.serializer,
|
||||||
|
},
|
||||||
|
name=f'Worker {name} (WorkerPool {self.name})',
|
||||||
|
)
|
||||||
|
process.start()
|
||||||
|
worker_data = WorkerData(name=name, pid=process.pid, process=process) # type: ignore
|
||||||
|
self.worker_dict[name] = worker_data
|
||||||
|
self.log.debug('Spawned worker: %s with PID %d', name, process.pid)
|
||||||
|
|
||||||
|
def start_workers(self, burst: bool = True, _sleep: float = 0, logging_level: str = "INFO"):
|
||||||
|
"""
|
||||||
|
Run the workers
|
||||||
|
* sleep: waits for X seconds before creating worker, only for testing purposes
|
||||||
|
"""
|
||||||
|
self.log.debug(f'Spawning {self.num_workers} workers')
|
||||||
|
for i in range(self.num_workers):
|
||||||
|
self.start_worker(i + 1, burst=burst, _sleep=_sleep, logging_level=logging_level)
|
||||||
|
|
||||||
|
def stop_worker(self, worker_data: WorkerData, sig=signal.SIGINT):
|
||||||
|
"""
|
||||||
|
Send stop signal to worker and catch "No such process" error if the worker is already dead.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
os.kill(worker_data.pid, sig)
|
||||||
|
self.log.info('Sent shutdown command to worker with %s', worker_data.pid)
|
||||||
|
except OSError as e:
|
||||||
|
if e.errno == errno.ESRCH:
|
||||||
|
# "No such process" is fine with us
|
||||||
|
self.log.debug('Horse already dead')
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
def stop_workers(self):
|
||||||
|
"""Send SIGINT to all workers"""
|
||||||
|
self.log.info('Sending stop signal to %s workers', len(self.worker_dict))
|
||||||
|
worker_datas = list(self.worker_dict.values())
|
||||||
|
for worker_data in worker_datas:
|
||||||
|
self.stop_worker(worker_data)
|
||||||
|
|
||||||
|
def start(self, burst: bool = False, logging_level: str = "INFO"):
|
||||||
|
self._burst = burst
|
||||||
|
respawn = not burst # Don't respawn workers if burst mode is on
|
||||||
|
setup_loghandlers(logging_level, DEFAULT_LOGGING_DATE_FORMAT, DEFAULT_LOGGING_FORMAT, name=__name__)
|
||||||
|
self.log.info(f'Starting worker pool {self.name} with pid %d...', os.getpid())
|
||||||
|
self.status = self.Status.IDLE
|
||||||
|
self.start_workers(burst=self._burst, logging_level=logging_level)
|
||||||
|
self._install_signal_handlers()
|
||||||
|
while True:
|
||||||
|
if self.status == self.Status.STOPPED:
|
||||||
|
if self.all_workers_have_stopped():
|
||||||
|
self.log.info('All workers stopped, exiting...')
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self.log.info('Waiting for workers to shutdown...')
|
||||||
|
time.sleep(1)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
self.check_workers(respawn=respawn)
|
||||||
|
if burst and self.number_of_active_workers == 0:
|
||||||
|
self.log.info('All workers stopped, exiting...')
|
||||||
|
break
|
||||||
|
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
|
def run_worker(
|
||||||
|
worker_name: str,
|
||||||
|
queue_names: List[str],
|
||||||
|
connection_class,
|
||||||
|
connection_kwargs: dict,
|
||||||
|
worker_class: Type[BaseWorker] = Worker,
|
||||||
|
serializer: Type[DefaultSerializer] = DefaultSerializer,
|
||||||
|
job_class: Type[Job] = Job,
|
||||||
|
burst: bool = True,
|
||||||
|
logging_level: str = "INFO",
|
||||||
|
_sleep: int = 0,
|
||||||
|
):
|
||||||
|
connection = connection_class(**connection_kwargs)
|
||||||
|
queues = [Queue(name, connection=connection) for name in queue_names]
|
||||||
|
worker = worker_class(queues, name=worker_name, connection=connection, serializer=serializer, job_class=job_class)
|
||||||
|
worker.log.info("Starting worker started with PID %s", os.getpid())
|
||||||
|
time.sleep(_sleep)
|
||||||
|
worker.work(burst=burst, logging_level=logging_level)
|
@ -0,0 +1,138 @@
|
|||||||
|
import os
|
||||||
|
import signal
|
||||||
|
|
||||||
|
from multiprocessing import Process
|
||||||
|
from time import sleep
|
||||||
|
from rq.job import JobStatus
|
||||||
|
|
||||||
|
from tests import TestCase
|
||||||
|
from tests.fixtures import CustomJob, _send_shutdown_command, long_running_job, say_hello
|
||||||
|
|
||||||
|
from rq.queue import Queue
|
||||||
|
from rq.serializers import JSONSerializer
|
||||||
|
from rq.worker import SimpleWorker
|
||||||
|
from rq.worker_pool import run_worker, WorkerPool
|
||||||
|
|
||||||
|
|
||||||
|
def wait_and_send_shutdown_signal(pid, time_to_wait=0.0):
|
||||||
|
sleep(time_to_wait)
|
||||||
|
os.kill(pid, signal.SIGTERM)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkerPool(TestCase):
|
||||||
|
def test_queues(self):
|
||||||
|
"""Test queue parsing"""
|
||||||
|
pool = WorkerPool(['default', 'foo'], connection=self.connection)
|
||||||
|
self.assertEqual(
|
||||||
|
set(pool.queues), {Queue('default', connection=self.connection), Queue('foo', connection=self.connection)}
|
||||||
|
)
|
||||||
|
|
||||||
|
# def test_spawn_workers(self):
|
||||||
|
# """Test spawning workers"""
|
||||||
|
# pool = WorkerPool(['default', 'foo'], connection=self.connection, num_workers=2)
|
||||||
|
# pool.start_workers(burst=False)
|
||||||
|
# self.assertEqual(len(pool.worker_dict.keys()), 2)
|
||||||
|
# pool.stop_workers()
|
||||||
|
|
||||||
|
def test_check_workers(self):
|
||||||
|
"""Test check_workers()"""
|
||||||
|
pool = WorkerPool(['default'], connection=self.connection, num_workers=2)
|
||||||
|
pool.start_workers(burst=False)
|
||||||
|
|
||||||
|
# There should be two workers
|
||||||
|
pool.check_workers()
|
||||||
|
self.assertEqual(len(pool.worker_dict.keys()), 2)
|
||||||
|
|
||||||
|
worker_data = list(pool.worker_dict.values())[0]
|
||||||
|
_send_shutdown_command(worker_data.name, self.connection.connection_pool.connection_kwargs.copy(), delay=0)
|
||||||
|
# 1 worker should be dead since we sent a shutdown command
|
||||||
|
sleep(0.2)
|
||||||
|
pool.check_workers(respawn=False)
|
||||||
|
self.assertEqual(len(pool.worker_dict.keys()), 1)
|
||||||
|
|
||||||
|
# If we call `check_workers` with `respawn=True`, the worker should be respawned
|
||||||
|
pool.check_workers(respawn=True)
|
||||||
|
self.assertEqual(len(pool.worker_dict.keys()), 2)
|
||||||
|
|
||||||
|
pool.stop_workers()
|
||||||
|
|
||||||
|
def test_reap_workers(self):
|
||||||
|
"""Dead workers are removed from worker_dict"""
|
||||||
|
pool = WorkerPool(['default'], connection=self.connection, num_workers=2)
|
||||||
|
pool.start_workers(burst=False)
|
||||||
|
|
||||||
|
# There should be two workers
|
||||||
|
pool.reap_workers()
|
||||||
|
self.assertEqual(len(pool.worker_dict.keys()), 2)
|
||||||
|
|
||||||
|
worker_data = list(pool.worker_dict.values())[0]
|
||||||
|
_send_shutdown_command(worker_data.name, self.connection.connection_pool.connection_kwargs.copy(), delay=0)
|
||||||
|
# 1 worker should be dead since we sent a shutdown command
|
||||||
|
sleep(0.2)
|
||||||
|
pool.reap_workers()
|
||||||
|
self.assertEqual(len(pool.worker_dict.keys()), 1)
|
||||||
|
pool.stop_workers()
|
||||||
|
|
||||||
|
def test_start(self):
|
||||||
|
"""Test start()"""
|
||||||
|
pool = WorkerPool(['default'], connection=self.connection, num_workers=2)
|
||||||
|
|
||||||
|
p = Process(target=wait_and_send_shutdown_signal, args=(os.getpid(), 0.5))
|
||||||
|
p.start()
|
||||||
|
pool.start()
|
||||||
|
self.assertEqual(pool.status, pool.Status.STOPPED)
|
||||||
|
self.assertTrue(pool.all_workers_have_stopped())
|
||||||
|
# We need this line so the test doesn't hang
|
||||||
|
pool.stop_workers()
|
||||||
|
|
||||||
|
def test_pool_ignores_consecutive_shutdown_signals(self):
|
||||||
|
"""If two shutdown signals are sent within one second, only the first one is processed"""
|
||||||
|
# Send two shutdown signals within one second while the worker is
|
||||||
|
# working on a long running job. The job should still complete (not killed)
|
||||||
|
pool = WorkerPool(['foo'], connection=self.connection, num_workers=2)
|
||||||
|
|
||||||
|
process_1 = Process(target=wait_and_send_shutdown_signal, args=(os.getpid(), 0.5))
|
||||||
|
process_1.start()
|
||||||
|
process_2 = Process(target=wait_and_send_shutdown_signal, args=(os.getpid(), 0.5))
|
||||||
|
process_2.start()
|
||||||
|
|
||||||
|
queue = Queue('foo', connection=self.connection)
|
||||||
|
job = queue.enqueue(long_running_job, 1)
|
||||||
|
pool.start(burst=True)
|
||||||
|
|
||||||
|
self.assertEqual(job.get_status(refresh=True), JobStatus.FINISHED)
|
||||||
|
# We need this line so the test doesn't hang
|
||||||
|
pool.stop_workers()
|
||||||
|
|
||||||
|
def test_run_worker(self):
|
||||||
|
"""Ensure run_worker() properly spawns a Worker"""
|
||||||
|
queue = Queue('foo', connection=self.connection)
|
||||||
|
queue.enqueue(say_hello)
|
||||||
|
run_worker(
|
||||||
|
'test-worker', ['foo'], self.connection.__class__, self.connection.connection_pool.connection_kwargs.copy()
|
||||||
|
)
|
||||||
|
# Worker should have processed the job
|
||||||
|
self.assertEqual(len(queue), 0)
|
||||||
|
|
||||||
|
def test_worker_pool_arguments(self):
|
||||||
|
"""Ensure arguments are properly used to create the right workers"""
|
||||||
|
queue = Queue('foo', connection=self.connection)
|
||||||
|
job = queue.enqueue(say_hello)
|
||||||
|
pool = WorkerPool([queue], connection=self.connection, num_workers=2, worker_class=SimpleWorker)
|
||||||
|
pool.start(burst=True)
|
||||||
|
# Worker should have processed the job
|
||||||
|
self.assertEqual(job.get_status(refresh=True), JobStatus.FINISHED)
|
||||||
|
|
||||||
|
queue = Queue('json', connection=self.connection, serializer=JSONSerializer)
|
||||||
|
job = queue.enqueue(say_hello, 'Hello')
|
||||||
|
pool = WorkerPool(
|
||||||
|
[queue], connection=self.connection, num_workers=2, worker_class=SimpleWorker, serializer=JSONSerializer
|
||||||
|
)
|
||||||
|
pool.start(burst=True)
|
||||||
|
# Worker should have processed the job
|
||||||
|
self.assertEqual(job.get_status(refresh=True), JobStatus.FINISHED)
|
||||||
|
|
||||||
|
pool = WorkerPool([queue], connection=self.connection, num_workers=2, job_class=CustomJob)
|
||||||
|
pool.start(burst=True)
|
||||||
|
# Worker should have processed the job
|
||||||
|
self.assertEqual(job.get_status(refresh=True), JobStatus.FINISHED)
|
Loading…
Reference in New Issue