From 6fc9454675caa64ff81dc4801adb19bff1b8d5af Mon Sep 17 00:00:00 2001 From: Cyrille Lavigne Date: Sat, 12 Jun 2021 00:45:03 -0400 Subject: [PATCH] Handle deserializing failures gracefully (#1428) * adds unit test for a deserialization error This tests that deserialization exceptions are properly logged, and fails in the manner described in #1422 . * Catch deserializing errors in Worker.handle_exception() This fixes #1422 , and makes tests/test_worker.py::TestWorker::test_deserializing_failure_is_handled pass. * made unit test less specific This is required to get the test to pass under other serializers / other python versions. * Added generic DeserializationError * switched ValueError to DeserializationError in a test The changed test is creating an invalid job, which now raises DeserializationError when data is accessed, as opposed to ValueError. --- rq/exceptions.py | 4 ++++ rq/job.py | 8 ++++++-- rq/worker.py | 27 +++++++++++++++++++-------- tests/test_job.py | 10 +++++----- tests/test_worker.py | 32 ++++++++++++++++++++++++++++++++ 5 files changed, 66 insertions(+), 15 deletions(-) diff --git a/rq/exceptions.py b/rq/exceptions.py index ba97f7c..34e45b1 100644 --- a/rq/exceptions.py +++ b/rq/exceptions.py @@ -7,6 +7,10 @@ class NoSuchJobError(Exception): pass +class DeserializationError(Exception): + pass + + class InvalidJobDependency(Exception): pass diff --git a/rq/job.py b/rq/job.py index d13de49..fc74578 100644 --- a/rq/job.py +++ b/rq/job.py @@ -18,7 +18,7 @@ from uuid import uuid4 from rq.compat import as_text, decode_redis_hash, string_types from .connections import resolve_connection -from .exceptions import NoSuchJobError +from .exceptions import DeserializationError, NoSuchJobError from .local import LocalStack from .serializers import resolve_serializer from .utils import (get_version, import_attribute, parse_timeout, str_to_date, @@ -221,7 +221,11 @@ class Job: return import_attribute(self.func_name) def _deserialize_data(self): - self._func_name, 
self._instance, self._args, self._kwargs = self.serializer.loads(self.data) + try: + self._func_name, self._instance, self._args, self._kwargs = self.serializer.loads(self.data) + except Exception as e: + # catch anything because serializers are generic + raise DeserializationError() from e @property def data(self): diff --git a/rq/worker.py b/rq/worker.py index e22a391..1ed913b 100644 --- a/rq/worker.py +++ b/rq/worker.py @@ -34,7 +34,7 @@ from .connections import get_current_connection, push_connection, pop_connection from .defaults import (DEFAULT_RESULT_TTL, DEFAULT_WORKER_TTL, DEFAULT_JOB_MONITORING_INTERVAL, DEFAULT_LOGGING_FORMAT, DEFAULT_LOGGING_DATE_FORMAT) -from .exceptions import DequeueTimeout, ShutDownImminentException +from .exceptions import DeserializationError, DequeueTimeout, ShutDownImminentException from .job import Job, JobStatus from .logutils import setup_loghandlers from .queue import Queue @@ -1057,13 +1057,24 @@ class Worker: def handle_exception(self, job, *exc_info): """Walks the exception handler stack to delegate exception handling.""" exc_string = ''.join(traceback.format_exception(*exc_info)) - self.log.error(exc_string, extra={ - 'func': job.func_name, - 'arguments': job.args, - 'kwargs': job.kwargs, - 'queue': job.origin, - 'job_id': job.id, - }) + + # If the job cannot be deserialized, it will raise when func_name or + # the other properties are accessed, which will stop exceptions from + # being properly logged, so we guard against it here. 
+ try: + extra = { + 'func': job.func_name, + 'arguments': job.args, + 'kwargs': job.kwargs, + } + except DeserializationError: + extra = {} + + # the properties below should be safe to read, however + extra.update({'queue': job.origin, 'job_id': job.id}) + + # log the traceback together with whatever job metadata could be read + self.log.error(exc_string, exc_info=True, extra=extra) for handler in self._exc_handlers: self.log.debug('Invoking exception handler %s', handler) diff --git a/tests/test_job.py b/tests/test_job.py index 0d43d2b..482434d 100644 --- a/tests/test_job.py +++ b/tests/test_job.py @@ -9,7 +9,7 @@ from datetime import datetime, timedelta from redis import WatchError from rq.compat import as_text -from rq.exceptions import NoSuchJobError +from rq.exceptions import DeserializationError, NoSuchJobError from rq.job import Job, JobStatus, cancel_job, get_current_job from rq.queue import Queue from rq.registry import (DeferredJobRegistry, FailedJobRegistry, @@ -53,13 +53,13 @@ class TestJob(RQTestCase): self.assertIsNone(job.result) self.assertIsNone(job.exc_info) - with self.assertRaises(ValueError): + with self.assertRaises(DeserializationError): job.func - with self.assertRaises(ValueError): + with self.assertRaises(DeserializationError): job.instance - with self.assertRaises(ValueError): + with self.assertRaises(DeserializationError): job.args - with self.assertRaises(ValueError): + with self.assertRaises(DeserializationError): job.kwargs def test_create_param_errors(self): diff --git a/tests/test_worker.py b/tests/test_worker.py index 05d6fef..1289acd 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -261,6 +261,38 @@ class TestWorker(RQTestCase): failed_job_registry = FailedJobRegistry(queue=q) self.assertTrue(job in failed_job_registry) + @mock.patch('rq.worker.logger.error') + def test_deserializing_failure_is_handled(self, mock_logger_error): + """ + Test that exceptions are properly handled for a job that fails to + deserialize. 
+ """ + q = Queue() + self.assertEqual(q.count, 0) + + # as in test_work_is_unreadable(), we create a fake bad job + job = Job.create(func=div_by_zero, args=(3,), origin=q.name) + job.save() + + # stripping the function name from the payload corrupts it, so deserialization fails + job_data = job.data + invalid_data = job_data.replace(b'div_by_zero', b'') + assert job_data != invalid_data + self.testconn.hset(job.key, 'data', zlib.compress(invalid_data)) + + # We use the low-level internal function to enqueue any data (bypassing + # validity checks) + q.push_job_id(job.id) + self.assertEqual(q.count, 1) + + # Now we try to run the job... + w = Worker([q]) + job, queue = w.dequeue_job_and_maintain_ttl(10) + w.perform_job(job, queue) + + # An exception should be logged here at ERROR level + self.assertIn("Traceback", mock_logger_error.call_args[0][0]) + def test_heartbeat(self): """Heartbeat saves last_heartbeat""" q = Queue()