From 08cb311c554bdfd3ccaa846b6e3fbd1971b32310 Mon Sep 17 00:00:00 2001 From: Cyril Chapellier Date: Mon, 1 May 2023 06:20:40 +0200 Subject: [PATCH] [Results] Allow unserializable return values (#1888) * fix: allow unserializable return values * fix: review comments --- rq/defaults.py | 6 ++++++ rq/job.py | 4 ++-- rq/results.py | 7 ++++++- tests/test_results.py | 18 ++++++++++++++++++ 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/rq/defaults.py b/rq/defaults.py index 2a3d57a..3744c12 100644 --- a/rq/defaults.py +++ b/rq/defaults.py @@ -93,4 +93,10 @@ https://docs.python.org/3/library/logging.html#logrecord-attributes DEFAULT_DEATH_PENALTY_CLASS = 'rq.timeouts.UnixSignalDeathPenalty' """ The path for the default Death Penalty class to use. Defaults to the `UnixSignalDeathPenalty` class within the `rq.timeouts` module +""" + + +UNSERIALIZABLE_RETURN_VALUE_PAYLOAD = 'Unserializable return value' +""" The value that we store in the job's _result property or in the Result's return_value +in case the return value of the actual job is not serializable """ \ No newline at end of file diff --git a/rq/job.py b/rq/job.py index a0574e6..61bc9aa 100644 --- a/rq/job.py +++ b/rq/job.py @@ -11,7 +11,7 @@ from redis import WatchError from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, Type from uuid import uuid4 -from .defaults import CALLBACK_TIMEOUT +from .defaults import CALLBACK_TIMEOUT, UNSERIALIZABLE_RETURN_VALUE_PAYLOAD from .timeouts import JobTimeoutException, BaseDeathPenalty if TYPE_CHECKING: @@ -887,7 +887,7 @@ class Job: try: self._result = self.serializer.loads(result) except Exception: - self._result = "Unserializable return value" + self._result = UNSERIALIZABLE_RETURN_VALUE_PAYLOAD self.timeout = parse_timeout(obj.get('timeout')) if obj.get('timeout') else None self.result_ttl = int(obj.get('result_ttl')) if obj.get('result_ttl') else None self.failure_ttl = int(obj.get('failure_ttl')) if obj.get('failure_ttl') else None diff --git a/rq/results.py b/rq/results.py index 55ee971..fdbb763 100644 --- a/rq/results.py +++ b/rq/results.py @@ -6,6 +6,7 @@ from datetime import datetime, timezone from enum import Enum from redis import Redis +from .defaults import UNSERIALIZABLE_RETURN_VALUE_PAYLOAD from .utils import decode_redis_hash from .job import Job from .serializers import resolve_serializer @@ -181,7 +182,11 @@ class Result: if self.exc_string is not None: data['exc_string'] = b64encode(zlib.compress(self.exc_string.encode())).decode() - serialized = self.serializer.dumps(self.return_value) + try: + serialized = self.serializer.dumps(self.return_value) + except: # noqa + serialized = self.serializer.dumps(UNSERIALIZABLE_RETURN_VALUE_PAYLOAD) + if self.return_value is not None: data['return_value'] = b64encode(serialized).decode() diff --git a/tests/test_results.py b/tests/test_results.py index 9bc1b9e..4286cec 100644 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -1,4 +1,5 @@ import unittest +import tempfile from datetime import timedelta from unittest.mock import patch, PropertyMock @@ -7,6 +8,7 @@ from redis import Redis from tests import RQTestCase +from rq.defaults import UNSERIALIZABLE_RETURN_VALUE_PAYLOAD from rq.job import Job from rq.queue import Queue from rq.registry import StartedJobRegistry @@ -236,3 +238,19 @@ class TestScheduledJobRegistry(RQTestCase): Result.create(job, Result.Type.SUCCESSFUL, ttl=0, return_value=1) self.assertIsNone(job.return_value()) + + def test_job_return_value_unserializable(self): + """Test job.return_value when it is not serializable""" + queue = Queue(connection=self.connection, result_ttl=0) + job = queue.enqueue(say_hello) + + # Returns None when there's no result + self.assertIsNone(job.return_value()) + + # tempfile.NamedTemporaryFile() is not picklable + Result.create(job, Result.Type.SUCCESSFUL, ttl=10, return_value=tempfile.NamedTemporaryFile()) + self.assertEqual(job.return_value(), UNSERIALIZABLE_RETURN_VALUE_PAYLOAD) + self.assertEqual(Result.count(job), 1) + + Result.create(job, Result.Type.SUCCESSFUL, ttl=10, return_value=1) + self.assertEqual(Result.count(job), 2)