mirror of
https://github.com/correl/typesafe-monads.git
synced 2024-11-21 19:18:42 +00:00
Add Reader
This commit is contained in:
parent
a4a46a46c7
commit
b798eb74ba
3 changed files with 136 additions and 0 deletions
|
@ -4,3 +4,4 @@ from .monad import Monad
|
|||
from .list import List
|
||||
from .maybe import Maybe, Just, Nothing
|
||||
from .result import Result, Ok, Err
|
||||
from .reader import Reader
|
||||
|
|
60
monads/reader.py
Normal file
60
monads/reader.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
from __future__ import annotations
|
||||
from functools import reduce, update_wrapper
|
||||
from typing import Any, Callable, Generic, TypeVar
|
||||
|
||||
from .monad import Monad
|
||||
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
Env = TypeVar("Env")
|
||||
F = Callable[[Env], T]
|
||||
|
||||
|
||||
class Reader(Monad[T], Generic[Env, T]):
|
||||
def __init__(self, function: F) -> None:
|
||||
update_wrapper(self, function)
|
||||
self.function = function
|
||||
|
||||
def __call__(self, environment: Env) -> T:
|
||||
return self.function(environment)
|
||||
|
||||
@classmethod
|
||||
def pure(cls, value: T) -> Reader[Env, T]:
|
||||
f: F = lambda x: value
|
||||
return cls(f)
|
||||
|
||||
def map(self, function: Callable[[T], S]) -> Reader[Env, S]:
|
||||
f: Callable[[Env], S] = lambda x: function(self.function(x))
|
||||
return Reader(f)
|
||||
|
||||
def apply(self, r: Reader[Env, Callable[[T], S]]) -> Reader[Env, S]:
|
||||
f: Callable[[Env], S] = lambda x: r.function(x)(self(x))
|
||||
return Reader(f)
|
||||
|
||||
def bind(self, function: Callable[[T], Reader[Env, S]]) -> Reader[Env, S]:
|
||||
f: Callable[[Env], S] = lambda x: function(self.function(x))(x)
|
||||
return Reader(f)
|
||||
|
||||
def __eq__(self, other: object): # pragma: no cover
|
||||
return isinstance(other, Reader) and self.function == other.function
|
||||
|
||||
def __repr__(self): # pragma: no cover
|
||||
module = self.function.__module__
|
||||
name = self.function.__name__
|
||||
return f"<Reader {module}.{name}>"
|
||||
|
||||
|
||||
class Curried(Reader[Env, T]):
|
||||
def __call__(self, *args):
|
||||
return reduce(lambda f, x: f(x), args, self.function)
|
||||
|
||||
|
||||
def curry(f: Callable):
|
||||
def wrapped(args, remaining):
|
||||
if remaining == 0:
|
||||
return f(*args)
|
||||
else:
|
||||
curried = lambda x: wrapped(args + [x], remaining - 1)
|
||||
return Curried(update_wrapper(curried, f))
|
||||
|
||||
return wrapped([], f.__code__.co_argcount)
|
75
tests/test_reader.py
Normal file
75
tests/test_reader.py
Normal file
|
@ -0,0 +1,75 @@
|
|||
import pytest # type: ignore
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
from monads import Functor, Applicative, Reader
|
||||
from monads.reader import curry
|
||||
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
|
||||
|
||||
def test_functor_identity() -> None:
|
||||
m: Reader = Reader.pure(3)
|
||||
identity: Callable[[T], T] = lambda x: x
|
||||
assert m(0) == m.map(identity)(0)
|
||||
|
||||
|
||||
def test_functor_associativity() -> None:
|
||||
f: Callable[[int], int] = lambda x: x + 1
|
||||
g: Callable[[int], str] = lambda x: str(x)
|
||||
m: Reader[int, int] = Reader.pure(3)
|
||||
assert m.map(lambda x: g(f(x)))(0) == m.map(f).map(g)(0)
|
||||
|
||||
|
||||
def test_applicative_fmap_using_ap() -> None:
|
||||
f: Callable[[int], int] = lambda x: x + 1
|
||||
m: Reader[int, int] = Reader.pure(3)
|
||||
assert m.map(f)(0) == m.apply(Reader.pure(f))(0)
|
||||
|
||||
|
||||
def test_monad_bind() -> None:
|
||||
expected: Reader[int, int] = Reader.pure(2)
|
||||
m: Reader[int, int] = Reader.pure(1)
|
||||
assert expected(0) == m.bind(lambda x: Reader.pure(x + 1))(0)
|
||||
|
||||
|
||||
def test_monad_left_identity() -> None:
|
||||
n: int = 3
|
||||
|
||||
def f(n: int) -> Reader[int, int]:
|
||||
return Reader.pure(n * 3)
|
||||
|
||||
m: Reader[int, int] = Reader.pure(n)
|
||||
assert m.bind(f)(0) == f(n)(0)
|
||||
|
||||
|
||||
def test_monad_right_identity() -> None:
|
||||
m: Reader[int, int] = Reader.pure(3)
|
||||
assert m(0) == m.bind(lambda x: Reader.pure(x))(0)
|
||||
|
||||
|
||||
def test_monad_associativity() -> None:
|
||||
m: Reader[int, int] = Reader.pure(3)
|
||||
|
||||
def f(n: int) -> Reader[int, int]:
|
||||
return Reader.pure(n * 3)
|
||||
|
||||
def g(n: int) -> Reader[int, int]:
|
||||
return Reader.pure(n + 5)
|
||||
|
||||
assert m.bind(f).bind(g)(0) == m.bind(lambda x: f(x).bind(g))(0)
|
||||
|
||||
|
||||
def test_curry() -> None:
|
||||
def add3(a, b, c):
|
||||
return a + b + c
|
||||
|
||||
assert add3(1, 2, 3) == curry(add3)(1)(2)(3)
|
||||
|
||||
|
||||
def test_call_curried_function_with_multiple_arguments() -> None:
|
||||
@curry
|
||||
def add3(a, b, c):
|
||||
return a + b + c
|
||||
|
||||
assert add3(1, 2)(3) == add3(1, 2, 3)
|
Loading…
Reference in a new issue