diff --git a/rq/job.py b/rq/job.py index 6075132..126a863 100644 --- a/rq/job.py +++ b/rq/job.py @@ -8,6 +8,7 @@ import pickle import warnings import zlib +import asyncio from collections.abc import Iterable from distutils.version import StrictVersion from functools import partial @@ -720,7 +721,12 @@ class Job(object): pipeline.hmset(self.key, mapping) def _execute(self): - return self.func(*self.args, **self.kwargs) + result = self.func(*self.args, **self.kwargs) + if asyncio.iscoroutine(result): + loop = asyncio.get_event_loop() + coro_result = loop.run_until_complete(result) + return coro_result + return result def get_ttl(self, default_ttl=None): """Returns ttl for a job that determines how long a job will be diff --git a/tests/fixtures.py b/tests/fixtures.py index b2d4af1..82e98bc 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -32,6 +32,11 @@ def say_hello(name=None): return 'Hi there, %s!' % (name,) +async def say_hello_async(name=None): + """A async job with a single argument and a return value.""" + return say_hello(name) + + def say_hello_unicode(name=None): """A job with a single argument and a return value.""" return text_type(say_hello(name)) # noqa diff --git a/tests/test_job.py b/tests/test_job.py index 8b921d7..fa7eff1 100644 --- a/tests/test_job.py +++ b/tests/test_job.py @@ -748,6 +748,21 @@ class TestJob(RQTestCase): self.assertRaises(TypeError, queue.enqueue, fixtures.say_hello, job_id=1234) + def test_create_job_with_async(self): + """test creating jobs with async function""" + queue = Queue(connection=self.testconn) + + async_job = queue.enqueue(fixtures.say_hello_async, job_id="async_job") + sync_job = queue.enqueue(fixtures.say_hello, job_id="sync_job") + + self.assertEqual(async_job.id, "async_job") + self.assertEqual(sync_job.id, "sync_job") + + async_task_result = async_job.perform() + sync_task_result = sync_job.perform() + + self.assertEqual(sync_task_result, async_task_result) + def test_get_call_string_unicode(self): """test call string with unicode keyword arguments""" queue = Queue(connection=self.testconn)