|
Description:
This recipe provides clean, declarative validation of function/method parameters and return values. It uses the function annotations available in Python 3 ("Python 3000") to specify the expected signature.
Source: Text Source
# Names exported by ``from <module> import *``: the decorator, the checker
# factories/predicates, and the exception types.
__all__ = [
    "typecheck",
    "optional",
    "with_attr",
    "by_regex",
    "nothing",
    "callable",
    "list_of",
    "dict_of",
    "TypeCheckError",
    "TypeCheckSpecificationError",
    "InputParameterError",
    "ReturnValueError",
]
from inspect import getfullargspec, isclass
from functools import update_wrapper, reduce
from re import compile
def callable(x):
    """Return True if *x* exposes a ``__call__`` attribute.

    This module-level name shadows the builtin of the same name inside this
    module and is exported through ``__all__`` so it can be used as a
    predicate in annotations.
    """
    return hasattr(x, "__call__")


def nothing(x):
    """Return True only when *x* is ``None``; usable as a predicate."""
    return x is None
class TypeCheckError(Exception):
    """Common base for errors raised when a checked value is rejected."""


class TypeCheckSpecificationError(Exception):
    """Raised when a type-check specification itself is not usable."""


class InputParameterError(TypeCheckError):
    """Raised when an argument fails its check."""


class ReturnValueError(TypeCheckError):
    """Raised when a return value fails its check."""
class Checker(object):
    """Base class and registry for value checkers.

    Concrete checkers subclass this and implement ``check(value)``.
    ``register`` associates a predicate with a factory; ``create`` turns an
    arbitrary specification object into a checker by asking the registered
    predicates in registration order.
    """

    class NoValue:
        """Type of the sentinel that stands for "no value was supplied"."""

        def __str__(self):
            return "<no value>"

        # Keep the sentinel readable in containers and debuggers as well.
        __repr__ = __str__

    # Single shared sentinel instance; compared by identity.
    no_value = NoValue()

    # (predicate, factory) pairs shared by the whole hierarchy, consulted in
    # registration order by create().
    _registered = []

    @classmethod
    def register(cls, predicate, factory):
        """Register *factory* for specifications accepted by *predicate*.

        Args:
            predicate: callable taking a specification, returning truthy
                when *factory* can handle it.
            factory: callable taking the specification and returning the
                checker for it.
        """
        cls._registered.append((predicate, factory))

    @classmethod
    def create(cls, value):
        """Return a checker for the specification *value*.

        An object that is already an instance of *cls* is returned
        unchanged. Otherwise the first registered predicate that accepts
        *value* decides which factory builds the result.

        Returns:
            The checker, or ``None`` when no registered predicate matches.
        """
        if isinstance(value, cls):
            return value
        for predicate, factory in cls._registered:
            if predicate(value):
                return factory(value)
        return None

    def check(self, value):
        """Return True if *value* is acceptable; subclasses must override."""
        raise NotImplementedError(
            "{0} does not implement check()".format(type(self).__name__))
class TypeChecker(Checker):
    """Checker that accepts instances of one class."""

    def __init__(self, cls):
        # The class is stored untouched; check() hands it to isinstance().
        self._cls = cls

    def check(self, value):
        """Return True when *value* is an instance of the stored class."""
        expected = self._cls
        return isinstance(value, expected)


# Any class used as a specification is turned into a TypeChecker.
Checker.register(isclass, TypeChecker)
def iterable(x):
    """Return True if *x* exposes an ``__iter__`` attribute."""
    return hasattr(x, "__iter__")


def _iterable_spec(x):
    """Return True if *x* may serve as an element-wise specification.

    ``str`` and ``bytes`` are excluded: iterating a one-character string
    yields that same string again, so treating it as a specification would
    recurse without end inside ``IterableChecker.__init__``.
    """
    return iterable(x) and not isinstance(x, (str, bytes))


class IterableChecker(Checker):
    """Checker built from a container of specifications, e.g. ``(int, str)``.

    A value passes when it is an instance of the container's type, has
    exactly as many elements as the specification, and each element passes
    the checker at the same position.
    """

    def __init__(self, cont):
        """Build one checker per element of *cont*.

        Raises:
            TypeCheckSpecificationError: if an element of *cont* cannot be
                turned into a checker.
        """
        self._cls = type(cont)
        element_checkers = []
        for spec in cont:
            checker = Checker.create(spec)
            if checker is None:
                # Fail at specification time instead of storing None and
                # crashing with AttributeError on the first check() call.
                raise TypeCheckSpecificationError(
                    "invalid typecheck for element {0!r}".format(spec))
            element_checkers.append(checker)
        self._chks = tuple(element_checkers)

    def check(self, value):
        """Return True if *value* matches the container type and elements."""
        # Test the type before iterating so that a one-shot iterator of the
        # wrong type is not consumed by a check that is going to fail anyway.
        if not isinstance(value, self._cls) or not iterable(value):
            return False
        vals = tuple(value)
        if len(vals) != len(self._chks):
            return False
        return all(chk.check(val) for chk, val in zip(self._chks, vals))


Checker.register(_iterable_spec, IterableChecker)
class CallableChecker(Checker):
    """Checker that delegates the decision to a predicate function."""

    def __init__(self, func):
        self._func = func

    def check(self, value):
        """Call the predicate with *value* and return its result as a bool."""
        verdict = self._func(value)
        return bool(verdict)


# Registered after the isclass and iterable handlers in this file, so a class
# (which also has __call__) is claimed by the earlier isclass registration.
Checker.register(callable, CallableChecker)
class OptionalChecker(Checker):
    """Checker that also accepts a missing value or ``None``.

    A value passes when it is the ``Checker.no_value`` sentinel, is ``None``,
    or passes the wrapped checker.
    """

    def __init__(self, check):
        """Wrap the specification *check*.

        Raises:
            TypeCheckSpecificationError: if *check* cannot be turned into a
                checker.
        """
        inner = Checker.create(check)
        if inner is None:
            # Report a bad specification immediately rather than failing
            # with AttributeError on None when a value is first checked.
            raise TypeCheckSpecificationError(
                "invalid typecheck for optional(): {0!r}".format(check))
        self._check = inner

    def check(self, value):
        """Return True for no value, ``None``, or a value the inner checker accepts."""
        if value is Checker.no_value or value is None:
            return True
        return self._check.check(value)


optional = OptionalChecker
class WithAttrChecker(Checker):
    """Checker that requires a set of attribute names to be present."""

    def __init__(self, *attrs):
        self._attrs = attrs

    def check(self, value):
        """Return True when *value* has every required attribute."""
        return all(hasattr(value, name) for name in self._attrs)


with_attr = WithAttrChecker
class ByRegexChecker(Checker):
    """Checker that accepts strings matching a regular expression."""

    def __init__(self, regex):
        # Compiled once here; check() reuses the compiled pattern.
        self._regex = compile(regex)

    def check(self, value):
        """Return True when *value* is a ``str`` the pattern matches from its start."""
        if not isinstance(value, str):
            return False
        found = self._regex.match(value)
        return found is not None


by_regex = ByRegexChecker
class ListOfChecker(Checker):
    """Checker that accepts a ``list`` whose every element passes one check."""

    def __init__(self, check):
        """Wrap the element specification *check*.

        Raises:
            TypeCheckSpecificationError: if *check* cannot be turned into a
                checker.
        """
        element_checker = Checker.create(check)
        if element_checker is None:
            # Report a bad specification immediately rather than failing
            # with AttributeError on None when a list is first checked.
            raise TypeCheckSpecificationError(
                "invalid typecheck for list_of(): {0!r}".format(check))
        self._check = element_checker

    def check(self, value):
        """Return True if *value* is a list and all its elements pass."""
        if not isinstance(value, list):
            return False
        # all() stops at the first failing element; an empty list passes.
        return all(self._check.check(item) for item in value)


list_of = ListOfChecker
class DictOfChecker(Checker):
    """Checker that accepts a ``dict`` whose keys and values each pass a check."""

    def __init__(self, key_check, value_check):
        """Wrap the key and value specifications.

        Raises:
            TypeCheckSpecificationError: if either specification cannot be
                turned into a checker.
        """
        key_checker = Checker.create(key_check)
        if key_checker is None:
            raise TypeCheckSpecificationError(
                "invalid key typecheck for dict_of(): {0!r}".format(key_check))
        value_checker = Checker.create(value_check)
        if value_checker is None:
            raise TypeCheckSpecificationError(
                "invalid value typecheck for dict_of(): {0!r}".format(value_check))
        self._key_check = key_checker
        self._value_check = value_checker

    def check(self, value):
        """Return True if *value* is a dict and every key/value pair passes."""
        if not isinstance(value, dict):
            return False
        # The key is tested first; its value is only tested if the key passes.
        return all(
            self._key_check.check(key) and self._value_check.check(item)
            for key, item in value.items()
        )


dict_of = DictOfChecker
def typecheck(method):
    """Decorator enforcing the annotations of *method* at call time.

    Every annotation is converted into a checker with ``Checker.create``.
    Unannotated parameters are not checked. Default values of annotated
    parameters are validated once, when the decorator is applied.

    Args:
        method: the function whose annotations describe the expected
            arguments and return value.

    Returns:
        A wrapper (carrying the metadata of *method*) that validates the
        arguments, calls *method*, and validates the result.

    Raises:
        TypeCheckSpecificationError: while decorating, if an annotation
            cannot be turned into a checker or a default value fails its
            own check.
        InputParameterError: from the wrapper, if an argument is rejected.
        ReturnValueError: from the wrapper, if the result is rejected.
    """
    argspec = getfullargspec(method)
    defaults = argspec.defaults or ()
    non_default_arg_count = len(argspec.args) - len(defaults)
    method_name = method.__name__
    kwarg_defaults = argspec.kwdefaults or {}

    arg_checkers = [None] * len(argspec.args)
    kwarg_checkers = {}
    vararg_checker = None   # applied to each surplus positional argument
    varkw_checker = None    # applied to each surplus keyword argument value
    return_checker = None

    for n, v in argspec.annotations.items():
        checker = Checker.create(v)
        if checker is None:
            raise TypeCheckSpecificationError("invalid typecheck for {0}".format(n))

        if n == "return":
            return_checker = checker
        elif n in argspec.kwonlyargs:
            if n in kwarg_defaults and \
               not checker.check(kwarg_defaults[n]):
                raise TypeCheckSpecificationError("the default value for {0} is incompatible "
                                                  "with its typecheck".format(n))
            kwarg_checkers[n] = checker
        elif n == argspec.varargs:
            # Without this branch args.index() below raises ValueError.
            vararg_checker = checker
        elif n == argspec.varkw:
            varkw_checker = checker
        else:
            i = argspec.args.index(n)
            if i >= non_default_arg_count and \
               not checker.check(defaults[i - non_default_arg_count]):
                raise TypeCheckSpecificationError("the default value for {0} is incompatible "
                                                  "with its typecheck".format(n))
            arg_checkers[i] = (n, checker)

    positional_count = len(argspec.args)
    named_params = set(argspec.args) | set(argspec.kwonlyargs)

    def _shown(value):
        # An empty rendering would make the message look truncated.
        return "''" if str(value) == "" else value

    def _reject(arg_name, value):
        raise InputParameterError("{0}() has got an incompatible value "
                                  "for {1}: {2}".format(method_name, arg_name,
                                                        _shown(value)))

    def typecheck_invocation_proxy(*args, **kwargs):
        # Named positional parameters may arrive positionally or by keyword.
        for i, check in enumerate(arg_checkers):
            if check is None:
                continue
            arg_name, checker = check
            if i < len(args):
                arg = args[i]
            elif arg_name in kwargs:
                arg = kwargs[arg_name]
            else:
                # Omitted: either the default applies (validated while
                # decorating) or the call itself raises TypeError below.
                continue
            if not checker.check(arg):
                _reject(arg_name, arg)

        if vararg_checker is not None:
            for arg in args[positional_count:]:
                if not vararg_checker.check(arg):
                    _reject(argspec.varargs, arg)

        for arg_name, checker in kwarg_checkers.items():
            if arg_name in kwargs:
                kwarg = kwargs[arg_name]
            elif arg_name in kwarg_defaults:
                # The default was validated while decorating; checking the
                # no_value sentinel here would wrongly reject the call.
                continue
            else:
                kwarg = Checker.no_value
            if not checker.check(kwarg):
                _reject(arg_name, kwarg)

        if varkw_checker is not None:
            for name, kwarg in kwargs.items():
                if name not in named_params and not varkw_checker.check(kwarg):
                    _reject(name, kwarg)

        result = method(*args, **kwargs)

        if return_checker is not None and not return_checker.check(result):
            raise ReturnValueError("{0}() has returned an incompatible "
                                   "value: {1}".format(method_name, _shown(result)))
        return result

    update_wrapper(typecheck_invocation_proxy, method)
    return typecheck_invocation_proxy
Discussion:
With Python 3000's function annotations, method signature type checking is much cleaner than with Python 2.x (see http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/426123).
First, it puts the signature-related piece of syntax exactly where it belongs - near the parameter. Second, you don't have to add checking to all the parameters, but just to those which require checking. Third, it plays nicely with default values.
The typical usage for this decorator would be something like this:
@typecheck
def foo(i: int) -> bool:
    return i > 0
@typecheck
def foo(*, s: by_regex("^[0-9]+$")) -> int:
    return int(s)
@typecheck
def foo(*, k: optional((int, int)) = (1, 2)):
return k[0] ** k[1]
@typecheck
def foo(f, *, k: optional(list_of(callable)) = [lambda x: x]):
return reduce(lambda r, e: r(e), k, f)
@typecheck
def foo(k: str, d: dict_of(str, str)) -> str:
return d[k]
This recipe is extensible with callable predicates, such as:
is_even = lambda x: x % 2 == 0
@typecheck
def multiply_by_2(i: int) -> is_even:
    return i * 2
|