You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

323 lines
9.7 KiB
Python

# -*- coding: utf-8 -*-
"""
Miscellaneous helper functions.
The formatter for ANSI colored console output is heavily based on Pygments
terminal colorizing code, originally by Georg Brandl.
"""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import calendar
import datetime
import importlib
import logging
import numbers
import sys
from collections.abc import Iterable
from redis.exceptions import ResponseError
from .compat import as_text, is_python_version, string_types
from .exceptions import TimeoutFormatError
logger = logging.getLogger(__name__)
class _Colorizer:
def __init__(self):
esc = "\x1b["
self.codes = {}
self.codes[""] = ""
self.codes["reset"] = esc + "39;49;00m"
self.codes["bold"] = esc + "01m"
self.codes["faint"] = esc + "02m"
self.codes["standout"] = esc + "03m"
self.codes["underline"] = esc + "04m"
self.codes["blink"] = esc + "05m"
self.codes["overline"] = esc + "06m"
dark_colors = ["black", "darkred", "darkgreen", "brown", "darkblue",
"purple", "teal", "lightgray"]
light_colors = ["darkgray", "red", "green", "yellow", "blue",
"fuchsia", "turquoise", "white"]
x = 30
for d, l in zip(dark_colors, light_colors):
self.codes[d] = esc + "%im" % x
self.codes[l] = esc + "%i;01m" % x
x += 1
del d, l, x
self.codes["darkteal"] = self.codes["turquoise"]
self.codes["darkyellow"] = self.codes["brown"]
self.codes["fuscia"] = self.codes["fuchsia"]
self.codes["white"] = self.codes["bold"]
if hasattr(sys.stdout, "isatty"):
self.notty = not sys.stdout.isatty()
else:
self.notty = True
def reset_color(self):
return self.codes["reset"]
def colorize(self, color_key, text):
if self.notty:
return text
else:
return self.codes[color_key] + text + self.codes["reset"]
colorizer = _Colorizer()
def make_colorizer(color):
"""Creates a function that colorizes text with the given color.
For example:
green = make_colorizer('darkgreen')
red = make_colorizer('red')
Then, you can use:
print "It's either " + green('OK') + ' or ' + red('Oops')
"""
def inner(text):
return colorizer.colorize(color, text)
return inner
class ColorizingStreamHandler(logging.StreamHandler):
levels = {
logging.WARNING: make_colorizer('darkyellow'),
logging.ERROR: make_colorizer('darkred'),
logging.CRITICAL: make_colorizer('darkred'),
}
def __init__(self, exclude=None, *args, **kwargs):
self.exclude = exclude
super(ColorizingStreamHandler, self).__init__(*args, **kwargs)
@property
def is_tty(self):
isatty = getattr(self.stream, 'isatty', None)
return isatty and isatty()
def format(self, record):
message = logging.StreamHandler.format(self, record)
if self.is_tty:
colorize = self.levels.get(record.levelno, lambda x: x)
# Don't colorize any traceback
parts = message.split('\n', 1)
parts[0] = " ".join([parts[0].split(" ", 1)[0], colorize(parts[0].split(" ", 1)[1])])
message = '\n'.join(parts)
return message
def import_attribute(name):
"""Return an attribute from a dotted path name (e.g. "path.to.func")."""
name_bits = name.split('.')
module_name_bits, attribute_bits = name_bits[:-1], [name_bits[-1]]
module = None
# When the attribute we look for is a staticmethod, module name in its
# dotted path is not the last-before-end word
# E.g.: package_a.package_b.module_a.ClassA.my_static_method
# Thus we remove the bits from the end of the name until we can import it
#
# Sometimes the failure during importing is due to a genuine coding error in the imported module
# In this case, the exception is logged as a warning for ease of debugging.
# The above logic will apply anyways regardless of the cause of the import error.
while len(module_name_bits):
try:
module_name = '.'.join(module_name_bits)
module = importlib.import_module(module_name)
break
except ImportError:
logger.warning("Import error for '%s'" % module_name, exc_info=True)
attribute_bits.insert(0, module_name_bits.pop())
if module is None:
raise ValueError('Invalid attribute name: %s' % name)
attribute_name = '.'.join(attribute_bits)
if hasattr(module, attribute_name):
return getattr(module, attribute_name)
# staticmethods
attribute_name = attribute_bits.pop()
attribute_owner_name = '.'.join(attribute_bits)
attribute_owner = getattr(module, attribute_owner_name)
if not hasattr(attribute_owner, attribute_name):
raise ValueError('Invalid attribute name: %s' % name)
return getattr(attribute_owner, attribute_name)
def utcnow():
return datetime.datetime.utcnow()
_TIMESTAMP_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ'
def utcformat(dt):
return dt.strftime(as_text(_TIMESTAMP_FORMAT))
def utcparse(string):
try:
return datetime.datetime.strptime(string, _TIMESTAMP_FORMAT)
except ValueError:
# This catches any jobs remain with old datetime format
return datetime.datetime.strptime(string, '%Y-%m-%dT%H:%M:%SZ')
def first(iterable, default=None, key=None):
"""
Return first element of `iterable` that evaluates true, else return None
(or an optional default value).
>>> first([0, False, None, [], (), 42])
42
>>> first([0, False, None, [], ()]) is None
True
>>> first([0, False, None, [], ()], default='ohai')
'ohai'
>>> import re
>>> m = first(re.match(regex, 'abc') for regex in ['b.*', 'a(.*)'])
>>> m.group(1)
'bc'
The optional `key` argument specifies a one-argument predicate function
like that used for `filter()`. The `key` argument, if supplied, must be
in keyword form. For example:
>>> first([1, 1, 3, 4, 5], key=lambda x: x % 2 == 0)
4
"""
if key is None:
for el in iterable:
if el:
return el
else:
for el in iterable:
if key(el):
return el
return default
def is_nonstring_iterable(obj):
"""Returns whether the obj is an iterable, but not a string"""
return isinstance(obj, Iterable) and not isinstance(obj, string_types)
def ensure_list(obj):
"""
When passed an iterable of objects, does nothing, otherwise, it returns
a list with just that object in it.
"""
return obj if is_nonstring_iterable(obj) else [obj]
def current_timestamp():
"""Returns current UTC timestamp"""
return calendar.timegm(datetime.datetime.utcnow().utctimetuple())
def backend_class(holder, default_name, override=None):
"""Get a backend class using its default attribute name or an override"""
if override is None:
return getattr(holder, default_name)
elif isinstance(override, string_types):
return import_attribute(override)
else:
return override
def str_to_date(date_str):
if not date_str:
return
else:
return utcparse(date_str.decode())
def parse_timeout(timeout):
"""Transfer all kinds of timeout format to an integer representing seconds"""
if not isinstance(timeout, numbers.Integral) and timeout is not None:
try:
timeout = int(timeout)
except ValueError:
digit, unit = timeout[:-1], (timeout[-1:]).lower()
unit_second = {'d': 86400, 'h': 3600, 'm': 60, 's': 1}
try:
timeout = int(digit) * unit_second[unit]
except (ValueError, KeyError):
raise TimeoutFormatError('Timeout must be an integer or a string representing an integer, or '
'a string with format: digits + unit, unit can be "d", "h", "m", "s", '
'such as "1h", "23m".')
return timeout
def get_version(connection):
"""
Returns tuple of Redis server version.
This function also correctly handles 4 digit redis server versions.
"""
try:
return tuple(int(i) for i in connection.info("server")["redis_version"].split('.')[:3])
except ResponseError: # fakeredis doesn't implement Redis' INFO command
return (5, 0, 9)
def ceildiv(a, b):
"""Ceiling division. Returns the ceiling of the quotient of a division operation"""
return -(-a // b)
def split_list(a_list, segment_size):
"""
Splits a list into multiple smaller lists having size `segment_size`
"""
for i in range(0, len(a_list), segment_size):
yield a_list[i:i + segment_size]
def truncate_long_string(data, max_length=None):
"""Truncate arguments with representation longer than max_length"""
if max_length is None:
return data
return (data[:max_length] + '...') if len(data) > max_length else data
def get_call_string(func_name, args, kwargs, max_length=None):
"""Returns a string representation of the call, formatted as a regular
Python function invocation statement. If max_length is not None, truncate
arguments with representation longer than max_length.
"""
if func_name is None:
return None
arg_list = [as_text(truncate_long_string(repr(arg), max_length)) for arg in args]
kwargs = ['{0}={1}'.format(k, as_text(truncate_long_string(repr(v), max_length))) for k, v in kwargs.items()]
arg_list += sorted(kwargs)
args = ', '.join(arg_list)
return '{0}({1})'.format(func_name, args)