mirror of
https://github.com/correl/typesafe-monads.git
synced 2024-12-01 03:00:14 +00:00
61 lines
1.8 KiB
Python
61 lines
1.8 KiB
Python
|
from __future__ import annotations
|
||
|
from functools import reduce, update_wrapper
|
||
|
from typing import Any, Callable, Generic, TypeVar
|
||
|
|
||
|
from .monad import Monad
|
||
|
|
||
|
# Type of the value a reader yields when it is run.
T = TypeVar("T")
# Type of the value yielded by a reader derived through map/apply/bind.
S = TypeVar("S")
# Type of the environment a reader is run with.
Env = TypeVar("Env")
# Shape of the callable a reader wraps: environment in, value out.
F = Callable[[Env], T]
|
||
|
|
||
|
|
||
|
class Reader(Monad[T], Generic[Env, T]):
    """Monad wrapping a one-argument function of an environment.

    A ``Reader`` is itself callable: calling it with an environment calls
    the wrapped function with that environment and returns the result.
    ``map``, ``apply`` and ``bind`` build new readers that hand the same
    environment to every reader taking part in the computation.
    """

    def __init__(self, function: F) -> None:
        """Wrap ``function``, a callable taking the environment."""
        # Copy the callable's metadata (``__name__``, ``__doc__``, ...) onto
        # this instance; attributes the callable lacks are simply skipped.
        update_wrapper(self, function)
        self.function = function

    def __call__(self, environment: Env) -> T:
        """Run the wrapped function against ``environment``."""
        return self.function(environment)

    @classmethod
    def pure(cls, value: T) -> Reader[Env, T]:
        """Return a reader that ignores its environment and yields ``value``."""
        f: F = lambda x: value
        return cls(f)

    def map(self, function: Callable[[T], S]) -> Reader[Env, S]:
        """Return a reader yielding ``function`` applied to this reader's value."""
        f: Callable[[Env], S] = lambda x: function(self.function(x))
        return Reader(f)

    def apply(self, r: Reader[Env, Callable[[T], S]]) -> Reader[Env, S]:
        """Return a reader applying the function yielded by ``r`` to this value.

        Both ``r`` and this reader are run with the same environment.
        """
        f: Callable[[Env], S] = lambda x: r.function(x)(self(x))
        return Reader(f)

    def bind(self, function: Callable[[T], Reader[Env, S]]) -> Reader[Env, S]:
        """Chain ``function``, which maps this reader's value to a new reader.

        The reader returned by ``function`` is run with the same environment
        that produced this reader's value.
        """
        f: Callable[[Env], S] = lambda x: function(self.function(x))(x)
        return Reader(f)

    def __eq__(self, other: object) -> bool:  # pragma: no cover
        """Readers are equal when they wrap equal callables."""
        return isinstance(other, Reader) and self.function == other.function

    def __hash__(self) -> int:  # pragma: no cover
        """Hash by the wrapped callable, consistent with ``__eq__``.

        Defining ``__eq__`` alone would leave instances unhashable.
        """
        return hash(self.function)

    def __repr__(self) -> str:  # pragma: no cover
        """Return ``<Reader module.name>`` for the wrapped callable."""
        try:
            module = self.function.__module__
            name = self.function.__name__
        except AttributeError:
            # Callables such as functools.partial objects or callable
            # instances have no ``__name__``; show the callable itself
            # rather than failing inside repr().
            return f"<Reader {self.function!r}>"
        return f"<Reader {module}.{name}>"
|
||
|
|
||
|
|
||
|
class Curried(Reader[Env, T]):
    """Reader that can be called with several positional arguments at once.

    The arguments are fed in one at a time: the wrapped function receives
    the first, whatever it returns is called with the second, and so on.
    Called with no arguments, the wrapped function itself is returned.
    """

    def __call__(self, *args):
        """Apply ``args`` successively, starting from the wrapped function."""
        result = self.function
        for argument in args:
            result = result(argument)
        return result
|
||
|
|
||
|
|
||
|
def curry(f: Callable):
    """Return a curried form of ``f`` that takes its arguments one at a time.

    The number of arguments to collect is the number of positional
    parameters of ``f`` (positional-only and positional-or-keyword,
    including those with defaults). Each partial application is wrapped in
    a ``Curried`` carrying the metadata of ``f``; once every argument has
    been supplied, ``f`` is called with them. If ``f`` takes no positional
    parameters it is called immediately and its result is returned.

    Raises:
        AttributeError: if no signature can be determined for ``f`` and it
            has no ``__code__`` attribute.
    """
    # Function-scope import so this block does not depend on the
    # module-level import list.
    import inspect

    try:
        # inspect.signature leaves out the already-bound ``self`` of bound
        # methods and also handles partials and other callable objects.
        # ``follow_wrapped=False`` counts the parameters of ``f`` itself,
        # matching what ``f.__code__.co_argcount`` reports for functions.
        parameters = inspect.signature(f, follow_wrapped=False).parameters
    except (TypeError, ValueError):
        # No introspectable signature: fall back to the code object.
        arity = f.__code__.co_argcount
    else:
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        arity = sum(
            1 for parameter in parameters.values()
            if parameter.kind in positional_kinds
        )

    def wrapped(args, remaining):
        if remaining == 0:
            return f(*args)
        # ``args + [x]`` builds a fresh list, so a partial application can
        # be reused without sharing collected arguments between calls.
        curried = lambda x: wrapped(args + [x], remaining - 1)
        return Curried(update_wrapper(curried, f))

    return wrapped([], arity)
|