|
|
|
@ -18,12 +18,14 @@ def fix_return_type(func):
|
|
|
|
|
return _inner
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PATCHED_METHODS = ['_setex', '_lrem', '_zadd', '_pipeline', '_ttl']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def patch_connection(connection):
|
|
|
|
|
if not isinstance(connection, StrictRedis):
|
|
|
|
|
raise ValueError('A StrictRedis or Redis connection is required.')
|
|
|
|
|
|
|
|
|
|
# Don't patch already patches objects
|
|
|
|
|
PATCHED_METHODS = ['_setex', '_lrem', '_zadd', '_pipeline', '_ttl']
|
|
|
|
|
if all([hasattr(connection, attr) for attr in PATCHED_METHODS]):
|
|
|
|
|
return connection
|
|
|
|
|
|
|
|
|
@ -35,6 +37,7 @@ def patch_connection(connection):
|
|
|
|
|
connection._ttl = fix_return_type(partial(StrictRedis.ttl, connection))
|
|
|
|
|
if hasattr(connection, 'pttl'):
|
|
|
|
|
connection._pttl = fix_return_type(partial(StrictRedis.pttl, connection))
|
|
|
|
|
|
|
|
|
|
elif isinstance(connection, StrictRedis):
|
|
|
|
|
connection._setex = connection.setex
|
|
|
|
|
connection._lrem = connection.lrem
|
|
|
|
|