mirror of https://github.com/peter4431/rq.git
Worker pool (#1874)
* First stab at implementating worker pool * Use process.is_alive() to check whether a process is still live * Handle shutdown signal * Check worker loop done * First working version of `WorkerPool`. * Added test for check_workers() * Added test for pool.start() * Better shutdown process * Comment out test_start() to see if it fixes CI * Make tests pass * Make CI pass * Comment out some tests * Comment out more tests * Re-enable a test * Re-enable another test * Uncomment check_workers test * Added run_worker test * Minor modification to dead worker detection * More test cases * Better process name for workers * Added back pool.stop_workers() when signal is received * Cleaned up cli.py * WIP on worker-pool command * Fix test * Test that worker pool ignores consecutive shutdown signals * Added test for worker-pool CLI command. * Added timeout to CI jobs * Fix worker pool test * Comment out test_scheduler.py * Fixed worker-pool in burst mode * Increase test coverage * Exclude tests directory from coverage.py * Improve test coverage * Renamed `Pool(num_workers=2) to `Pool(size=2)` * Revert "Renamed `Pool(num_workers=2) to `Pool(size=2)`" This reverts commit a1306f89ad0d8686c6bde447bff75e2f71f0733b. * Renamed Pool to WorkerPool * Added a new TestCase that doesn't use LocalStack * Added job_class, worker_class and serializer arguments to WorkerPool * Use parse_connection() in WorkerPool.__init__ * Added CLI arguments for worker-pool * Minor WorkerPool and test fixes * Fixed failing CLI test * Document WorkerPoolmain
parent
8a9daecaf2
commit
64cb1a27b9
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,250 @@
|
||||
import contextlib
|
||||
import errno
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
|
||||
from enum import Enum
|
||||
from multiprocessing import Process
|
||||
from typing import Dict, List, NamedTuple, Optional, Set, Type, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from redis import Redis
|
||||
from redis import SSLConnection, UnixDomainSocketConnection
|
||||
from rq.serializers import DefaultSerializer
|
||||
|
||||
from rq.timeouts import HorseMonitorTimeoutException, UnixSignalDeathPenalty
|
||||
|
||||
from .connections import parse_connection
|
||||
from .defaults import DEFAULT_LOGGING_DATE_FORMAT, DEFAULT_LOGGING_FORMAT
|
||||
from .job import Job
|
||||
from .logutils import setup_loghandlers
|
||||
from .queue import Queue
|
||||
from .utils import parse_names
|
||||
from .worker import BaseWorker, Worker
|
||||
|
||||
|
||||
class WorkerData(NamedTuple):
|
||||
name: str
|
||||
pid: int
|
||||
process: Process
|
||||
|
||||
|
||||
class WorkerPool:
|
||||
class Status(Enum):
|
||||
IDLE = 1
|
||||
STARTED = 2
|
||||
STOPPED = 3
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
queues: List[Union[str, Queue]],
|
||||
connection: Redis,
|
||||
num_workers: int = 1,
|
||||
worker_class: Type[BaseWorker] = Worker,
|
||||
serializer: Type[DefaultSerializer] = DefaultSerializer,
|
||||
job_class: Type[Job] = Job,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self.num_workers: int = num_workers
|
||||
self._workers: List[Worker] = []
|
||||
setup_loghandlers('INFO', DEFAULT_LOGGING_DATE_FORMAT, DEFAULT_LOGGING_FORMAT, name=__name__)
|
||||
self.log: logging.Logger = logging.getLogger(__name__)
|
||||
# self.log: logging.Logger = logger
|
||||
self._queue_names: List[str] = parse_names(queues)
|
||||
self.connection = connection
|
||||
self.name: str = uuid4().hex
|
||||
self._burst: bool = True
|
||||
self._sleep: int = 0
|
||||
self.status: self.Status = self.Status.IDLE # type: ignore
|
||||
self.worker_class: Type[BaseWorker] = worker_class
|
||||
self.serializer: Type[DefaultSerializer] = serializer
|
||||
self.job_class: Type[Job] = job_class
|
||||
|
||||
# A dictionary of WorkerData keyed by worker name
|
||||
self.worker_dict: Dict[str, WorkerData] = {}
|
||||
self._connection_class, _, self._connection_kwargs = parse_connection(connection)
|
||||
|
||||
@property
|
||||
def queues(self) -> List[Queue]:
|
||||
"""Returns a list of Queue objects"""
|
||||
return [Queue(name, connection=self.connection) for name in self._queue_names]
|
||||
|
||||
@property
|
||||
def number_of_active_workers(self) -> int:
|
||||
"""Returns a list of Queue objects"""
|
||||
return len(self.worker_dict)
|
||||
|
||||
def _install_signal_handlers(self):
|
||||
"""Installs signal handlers for handling SIGINT and SIGTERM
|
||||
gracefully.
|
||||
"""
|
||||
signal.signal(signal.SIGINT, self.request_stop)
|
||||
signal.signal(signal.SIGTERM, self.request_stop)
|
||||
|
||||
def request_stop(self, signum=None, frame=None):
|
||||
"""Toggle self._stop_requested that's checked on every loop"""
|
||||
self.log.info('Received SIGINT/SIGTERM, shutting down...')
|
||||
self.status = self.Status.STOPPED
|
||||
self.stop_workers()
|
||||
|
||||
def all_workers_have_stopped(self) -> bool:
|
||||
"""Returns True if all workers have stopped."""
|
||||
self.reap_workers()
|
||||
# `bool(self.worker_dict)` sometimes returns True even if the dict is empty
|
||||
return self.number_of_active_workers == 0
|
||||
|
||||
def reap_workers(self):
|
||||
"""Removes dead workers from worker_dict"""
|
||||
self.log.debug('Reaping dead workers')
|
||||
worker_datas = list(self.worker_dict.values())
|
||||
|
||||
for data in worker_datas:
|
||||
data.process.join(0.1)
|
||||
if data.process.is_alive():
|
||||
self.log.debug('Worker %s with pid %d is alive', data.name, data.pid)
|
||||
else:
|
||||
self.handle_dead_worker(data)
|
||||
continue
|
||||
|
||||
# I'm still not sure why this is sometimes needed, temporarily commenting
|
||||
# this out until I can figure it out.
|
||||
# with contextlib.suppress(HorseMonitorTimeoutException):
|
||||
# with UnixSignalDeathPenalty(1, HorseMonitorTimeoutException):
|
||||
# try:
|
||||
# # If wait4 returns, the process is dead
|
||||
# os.wait4(data.process.pid, 0) # type: ignore
|
||||
# self.handle_dead_worker(data)
|
||||
# except ChildProcessError:
|
||||
# # Process is dead
|
||||
# self.handle_dead_worker(data)
|
||||
# continue
|
||||
|
||||
def handle_dead_worker(self, worker_data: WorkerData):
|
||||
"""
|
||||
Handle a dead worker
|
||||
"""
|
||||
self.log.info('Worker %s with pid %d is dead', worker_data.name, worker_data.pid)
|
||||
with contextlib.suppress(KeyError):
|
||||
self.worker_dict.pop(worker_data.name)
|
||||
|
||||
def check_workers(self, respawn: bool = True) -> None:
|
||||
"""
|
||||
Check whether workers are still alive
|
||||
"""
|
||||
self.log.debug('Checking worker processes')
|
||||
self.reap_workers()
|
||||
# If we have less number of workers than num_workers,
|
||||
# respawn the difference
|
||||
if respawn and self.status != self.Status.STOPPED:
|
||||
delta = self.num_workers - len(self.worker_dict)
|
||||
if delta:
|
||||
for i in range(delta):
|
||||
self.start_worker(burst=self._burst, _sleep=self._sleep)
|
||||
|
||||
def start_worker(
|
||||
self,
|
||||
count: Optional[int] = None,
|
||||
burst: bool = True,
|
||||
_sleep: float = 0,
|
||||
logging_level: str = "INFO",
|
||||
):
|
||||
"""
|
||||
Starts a worker and adds the data to worker_datas.
|
||||
* sleep: waits for X seconds before creating worker, for testing purposes
|
||||
"""
|
||||
name = uuid4().hex
|
||||
process = Process(
|
||||
target=run_worker,
|
||||
args=(name, self._queue_names, self._connection_class, self._connection_kwargs),
|
||||
kwargs={
|
||||
'_sleep': _sleep,
|
||||
'burst': burst,
|
||||
'logging_level': logging_level,
|
||||
'worker_class': self.worker_class,
|
||||
'job_class': self.job_class,
|
||||
'serializer': self.serializer,
|
||||
},
|
||||
name=f'Worker {name} (WorkerPool {self.name})',
|
||||
)
|
||||
process.start()
|
||||
worker_data = WorkerData(name=name, pid=process.pid, process=process) # type: ignore
|
||||
self.worker_dict[name] = worker_data
|
||||
self.log.debug('Spawned worker: %s with PID %d', name, process.pid)
|
||||
|
||||
def start_workers(self, burst: bool = True, _sleep: float = 0, logging_level: str = "INFO"):
|
||||
"""
|
||||
Run the workers
|
||||
* sleep: waits for X seconds before creating worker, only for testing purposes
|
||||
"""
|
||||
self.log.debug(f'Spawning {self.num_workers} workers')
|
||||
for i in range(self.num_workers):
|
||||
self.start_worker(i + 1, burst=burst, _sleep=_sleep, logging_level=logging_level)
|
||||
|
||||
def stop_worker(self, worker_data: WorkerData, sig=signal.SIGINT):
|
||||
"""
|
||||
Send stop signal to worker and catch "No such process" error if the worker is already dead.
|
||||
"""
|
||||
try:
|
||||
os.kill(worker_data.pid, sig)
|
||||
self.log.info('Sent shutdown command to worker with %s', worker_data.pid)
|
||||
except OSError as e:
|
||||
if e.errno == errno.ESRCH:
|
||||
# "No such process" is fine with us
|
||||
self.log.debug('Horse already dead')
|
||||
else:
|
||||
raise
|
||||
|
||||
def stop_workers(self):
|
||||
"""Send SIGINT to all workers"""
|
||||
self.log.info('Sending stop signal to %s workers', len(self.worker_dict))
|
||||
worker_datas = list(self.worker_dict.values())
|
||||
for worker_data in worker_datas:
|
||||
self.stop_worker(worker_data)
|
||||
|
||||
def start(self, burst: bool = False, logging_level: str = "INFO"):
|
||||
self._burst = burst
|
||||
respawn = not burst # Don't respawn workers if burst mode is on
|
||||
setup_loghandlers(logging_level, DEFAULT_LOGGING_DATE_FORMAT, DEFAULT_LOGGING_FORMAT, name=__name__)
|
||||
self.log.info(f'Starting worker pool {self.name} with pid %d...', os.getpid())
|
||||
self.status = self.Status.IDLE
|
||||
self.start_workers(burst=self._burst, logging_level=logging_level)
|
||||
self._install_signal_handlers()
|
||||
while True:
|
||||
if self.status == self.Status.STOPPED:
|
||||
if self.all_workers_have_stopped():
|
||||
self.log.info('All workers stopped, exiting...')
|
||||
break
|
||||
else:
|
||||
self.log.info('Waiting for workers to shutdown...')
|
||||
time.sleep(1)
|
||||
continue
|
||||
else:
|
||||
self.check_workers(respawn=respawn)
|
||||
if burst and self.number_of_active_workers == 0:
|
||||
self.log.info('All workers stopped, exiting...')
|
||||
break
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def run_worker(
|
||||
worker_name: str,
|
||||
queue_names: List[str],
|
||||
connection_class,
|
||||
connection_kwargs: dict,
|
||||
worker_class: Type[BaseWorker] = Worker,
|
||||
serializer: Type[DefaultSerializer] = DefaultSerializer,
|
||||
job_class: Type[Job] = Job,
|
||||
burst: bool = True,
|
||||
logging_level: str = "INFO",
|
||||
_sleep: int = 0,
|
||||
):
|
||||
connection = connection_class(**connection_kwargs)
|
||||
queues = [Queue(name, connection=connection) for name in queue_names]
|
||||
worker = worker_class(queues, name=worker_name, connection=connection, serializer=serializer, job_class=job_class)
|
||||
worker.log.info("Starting worker started with PID %s", os.getpid())
|
||||
time.sleep(_sleep)
|
||||
worker.work(burst=burst, logging_level=logging_level)
|
@ -0,0 +1,138 @@
|
||||
import os
|
||||
import signal
|
||||
|
||||
from multiprocessing import Process
|
||||
from time import sleep
|
||||
from rq.job import JobStatus
|
||||
|
||||
from tests import TestCase
|
||||
from tests.fixtures import CustomJob, _send_shutdown_command, long_running_job, say_hello
|
||||
|
||||
from rq.queue import Queue
|
||||
from rq.serializers import JSONSerializer
|
||||
from rq.worker import SimpleWorker
|
||||
from rq.worker_pool import run_worker, WorkerPool
|
||||
|
||||
|
||||
def wait_and_send_shutdown_signal(pid, time_to_wait=0.0):
|
||||
sleep(time_to_wait)
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
|
||||
|
||||
class TestWorkerPool(TestCase):
|
||||
def test_queues(self):
|
||||
"""Test queue parsing"""
|
||||
pool = WorkerPool(['default', 'foo'], connection=self.connection)
|
||||
self.assertEqual(
|
||||
set(pool.queues), {Queue('default', connection=self.connection), Queue('foo', connection=self.connection)}
|
||||
)
|
||||
|
||||
# def test_spawn_workers(self):
|
||||
# """Test spawning workers"""
|
||||
# pool = WorkerPool(['default', 'foo'], connection=self.connection, num_workers=2)
|
||||
# pool.start_workers(burst=False)
|
||||
# self.assertEqual(len(pool.worker_dict.keys()), 2)
|
||||
# pool.stop_workers()
|
||||
|
||||
def test_check_workers(self):
|
||||
"""Test check_workers()"""
|
||||
pool = WorkerPool(['default'], connection=self.connection, num_workers=2)
|
||||
pool.start_workers(burst=False)
|
||||
|
||||
# There should be two workers
|
||||
pool.check_workers()
|
||||
self.assertEqual(len(pool.worker_dict.keys()), 2)
|
||||
|
||||
worker_data = list(pool.worker_dict.values())[0]
|
||||
_send_shutdown_command(worker_data.name, self.connection.connection_pool.connection_kwargs.copy(), delay=0)
|
||||
# 1 worker should be dead since we sent a shutdown command
|
||||
sleep(0.2)
|
||||
pool.check_workers(respawn=False)
|
||||
self.assertEqual(len(pool.worker_dict.keys()), 1)
|
||||
|
||||
# If we call `check_workers` with `respawn=True`, the worker should be respawned
|
||||
pool.check_workers(respawn=True)
|
||||
self.assertEqual(len(pool.worker_dict.keys()), 2)
|
||||
|
||||
pool.stop_workers()
|
||||
|
||||
def test_reap_workers(self):
|
||||
"""Dead workers are removed from worker_dict"""
|
||||
pool = WorkerPool(['default'], connection=self.connection, num_workers=2)
|
||||
pool.start_workers(burst=False)
|
||||
|
||||
# There should be two workers
|
||||
pool.reap_workers()
|
||||
self.assertEqual(len(pool.worker_dict.keys()), 2)
|
||||
|
||||
worker_data = list(pool.worker_dict.values())[0]
|
||||
_send_shutdown_command(worker_data.name, self.connection.connection_pool.connection_kwargs.copy(), delay=0)
|
||||
# 1 worker should be dead since we sent a shutdown command
|
||||
sleep(0.2)
|
||||
pool.reap_workers()
|
||||
self.assertEqual(len(pool.worker_dict.keys()), 1)
|
||||
pool.stop_workers()
|
||||
|
||||
def test_start(self):
|
||||
"""Test start()"""
|
||||
pool = WorkerPool(['default'], connection=self.connection, num_workers=2)
|
||||
|
||||
p = Process(target=wait_and_send_shutdown_signal, args=(os.getpid(), 0.5))
|
||||
p.start()
|
||||
pool.start()
|
||||
self.assertEqual(pool.status, pool.Status.STOPPED)
|
||||
self.assertTrue(pool.all_workers_have_stopped())
|
||||
# We need this line so the test doesn't hang
|
||||
pool.stop_workers()
|
||||
|
||||
def test_pool_ignores_consecutive_shutdown_signals(self):
|
||||
"""If two shutdown signals are sent within one second, only the first one is processed"""
|
||||
# Send two shutdown signals within one second while the worker is
|
||||
# working on a long running job. The job should still complete (not killed)
|
||||
pool = WorkerPool(['foo'], connection=self.connection, num_workers=2)
|
||||
|
||||
process_1 = Process(target=wait_and_send_shutdown_signal, args=(os.getpid(), 0.5))
|
||||
process_1.start()
|
||||
process_2 = Process(target=wait_and_send_shutdown_signal, args=(os.getpid(), 0.5))
|
||||
process_2.start()
|
||||
|
||||
queue = Queue('foo', connection=self.connection)
|
||||
job = queue.enqueue(long_running_job, 1)
|
||||
pool.start(burst=True)
|
||||
|
||||
self.assertEqual(job.get_status(refresh=True), JobStatus.FINISHED)
|
||||
# We need this line so the test doesn't hang
|
||||
pool.stop_workers()
|
||||
|
||||
def test_run_worker(self):
|
||||
"""Ensure run_worker() properly spawns a Worker"""
|
||||
queue = Queue('foo', connection=self.connection)
|
||||
queue.enqueue(say_hello)
|
||||
run_worker(
|
||||
'test-worker', ['foo'], self.connection.__class__, self.connection.connection_pool.connection_kwargs.copy()
|
||||
)
|
||||
# Worker should have processed the job
|
||||
self.assertEqual(len(queue), 0)
|
||||
|
||||
def test_worker_pool_arguments(self):
|
||||
"""Ensure arguments are properly used to create the right workers"""
|
||||
queue = Queue('foo', connection=self.connection)
|
||||
job = queue.enqueue(say_hello)
|
||||
pool = WorkerPool([queue], connection=self.connection, num_workers=2, worker_class=SimpleWorker)
|
||||
pool.start(burst=True)
|
||||
# Worker should have processed the job
|
||||
self.assertEqual(job.get_status(refresh=True), JobStatus.FINISHED)
|
||||
|
||||
queue = Queue('json', connection=self.connection, serializer=JSONSerializer)
|
||||
job = queue.enqueue(say_hello, 'Hello')
|
||||
pool = WorkerPool(
|
||||
[queue], connection=self.connection, num_workers=2, worker_class=SimpleWorker, serializer=JSONSerializer
|
||||
)
|
||||
pool.start(burst=True)
|
||||
# Worker should have processed the job
|
||||
self.assertEqual(job.get_status(refresh=True), JobStatus.FINISHED)
|
||||
|
||||
pool = WorkerPool([queue], connection=self.connection, num_workers=2, job_class=CustomJob)
|
||||
pool.start(burst=True)
|
||||
# Worker should have processed the job
|
||||
self.assertEqual(job.get_status(refresh=True), JobStatus.FINISHED)
|
Loading…
Reference in New Issue