2018-10-12 21:02:56 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
from functools import reduce, update_wrapper
|
|
|
|
from typing import Any, Callable, Generic, TypeVar
|
|
|
|
|
|
|
|
from .monad import Monad
|
|
|
|
|
|
|
|
T = TypeVar("T")
|
|
|
|
S = TypeVar("S")
|
|
|
|
Env = TypeVar("Env")
|
|
|
|
F = Callable[[Env], T]
|
|
|
|
|
|
|
|
|
|
|
|
class Reader(Monad[T], Generic[Env, T]):
    """Wrap a function of an environment so it can be composed monadically.

    A ``Reader`` holds a callable taking one argument (the environment).
    Calling the reader runs that callable. ``map``, ``apply`` and ``bind``
    build new readers that pass the same environment to every step.
    """

    def __init__(self, function: F) -> None:
        """Store ``function`` and mirror its metadata on this instance.

        :param function: callable invoked with the environment.
        """
        # update_wrapper copies __module__, __name__, __doc__ and friends
        # (silently skipping any the callable lacks) onto the reader.
        update_wrapper(self, function)
        self.function = function

    def __call__(self, environment: Env) -> T:
        """Run the wrapped function against ``environment``."""
        return self.function(environment)

    @classmethod
    def pure(cls, value: T) -> Reader[Env, T]:
        """Build a reader that ignores its environment and yields ``value``.

        Constructed through ``cls`` so subclasses get an instance of
        themselves.
        """
        f: F = lambda x: value
        return cls(f)

    def map(self, function: Callable[[T], S]) -> Reader[Env, S]:
        """Return a reader that applies ``function`` to this reader's result."""
        f: Callable[[Env], S] = lambda x: function(self.function(x))
        return Reader(f)

    def apply(self, r: Reader[Env, Callable[[T], S]]) -> Reader[Env, S]:
        """Apply the function produced by ``r`` to this reader's result.

        Both readers are run with the same environment.
        """
        f: Callable[[Env], S] = lambda x: r.function(x)(self(x))
        return Reader(f)

    def bind(self, function: Callable[[T], Reader[Env, S]]) -> Reader[Env, S]:
        """Chain a reader-producing ``function`` after this reader.

        ``function`` receives this reader's result; the reader it returns is
        then run with the same environment.
        """
        f: Callable[[Env], S] = lambda x: function(self.function(x))(x)
        return Reader(f)

    def __eq__(self, other: object) -> bool:  # pragma: no cover
        """Readers are equal when they wrap equal functions."""
        return isinstance(other, Reader) and self.function == other.function

    def __repr__(self) -> str:  # pragma: no cover
        """Return ``<Reader module.name>`` for the wrapped callable."""
        wrapped = self.function
        module = wrapped.__module__
        # Callables such as functools.partial objects or callable instances
        # have no __name__; fall back to their type name so repr() never
        # raises for a reader that was constructed successfully.
        name = getattr(wrapped, "__name__", None) or type(wrapped).__name__
        return f"<Reader {module}.{name}>"

    # Operator sugar: ``reader * f`` and ``f * reader`` both map,
    # ``reader >> f`` binds.
    __mul__ = __rmul__ = map
    __rshift__ = bind
|
|
|
|
|
2018-10-12 21:02:56 +00:00
|
|
|
|
|
|
|
class Curried(Reader[Env, T]):
    """Reader that may be called with several positional arguments at once.

    The arguments are fed to the wrapped function one per call, each call
    being made on the result of the previous one.
    """

    def __call__(self, *args):
        """Apply the wrapped function to ``args`` one argument at a time.

        With no arguments the wrapped function itself is returned.
        """
        result = self.function
        for argument in args:
            result = result(argument)
        return result
|
|
|
|
|
|
|
|
|
|
|
|
def curry(f: Callable):
    """Return a curried form of ``f``.

    Each positional parameter of ``f`` becomes one single-argument call;
    once all of them have been supplied ``f`` is invoked. Intermediate
    stages are ``Curried`` objects carrying ``f``'s metadata. A callable
    with no positional parameters is invoked immediately and its result
    returned.

    :param f: callable to curry. Plain functions, bound methods,
        ``functools.partial`` objects and other introspectable callables
        are supported.
    """
    # Local import: keeps this block self-contained without touching the
    # module's import section.
    import inspect

    try:
        # follow_wrapped=False so the count reflects what ``f`` itself
        # accepts, not what a function it decorates accepts. Unlike
        # ``__code__.co_argcount`` this excludes the already-bound ``self``
        # of a bound method and works for callables without ``__code__``.
        parameters = inspect.signature(f, follow_wrapped=False).parameters
    except (TypeError, ValueError):
        # No introspectable signature: fall back to the code object.
        arity = f.__code__.co_argcount
    else:
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        arity = sum(
            1 for parameter in parameters.values()
            if parameter.kind in positional_kinds
        )

    def wrapped(args, remaining):
        if remaining == 0:
            return f(*args)
        # ``args + [x]`` builds a new list, so a partially applied stage
        # can be reused with different arguments.
        curried = lambda x: wrapped(args + [x], remaining - 1)
        return Curried(update_wrapper(curried, f))

    return wrapped([], arity)
|