diff --git a/monads/functor.py b/monads/functor.py index 6634698..de76fa2 100644 --- a/monads/functor.py +++ b/monads/functor.py @@ -8,3 +8,5 @@ S = TypeVar("S") class Functor(Generic[T]): def map(self, function: Callable[[T], S]) -> Functor[S]: # pragma: no cover raise NotImplementedError + + __mul__ = __rmul__ = map diff --git a/monads/list.py b/monads/list.py index b929a41..c23523d 100644 --- a/monads/list.py +++ b/monads/list.py @@ -34,3 +34,5 @@ class List(Monad[T], Monoidal[list]): return List(self.value + other.value) __add__ = mappend + __mul__ = __rmul__ = map + __rshift__ = bind diff --git a/monads/monad.py b/monads/monad.py index adcd7ea..c2db2cf 100644 --- a/monads/monad.py +++ b/monads/monad.py @@ -15,3 +15,5 @@ class Monad(Applicative[T]): # https://github.com/python/mypy/issues/1317 def bind(self, function: Callable[[T], Any]) -> Monad[S]: # pragma: no cover raise NotImplementedError + + __rshift__ = bind diff --git a/monads/reader.py b/monads/reader.py index 86ee73a..2375bd2 100644 --- a/monads/reader.py +++ b/monads/reader.py @@ -43,6 +43,9 @@ class Reader(Monad[T], Generic[Env, T]): name = self.function.__name__ return f"" + __mul__ = __rmul__ = map + __rshift__ = bind + class Curried(Reader[Env, T]): def __call__(self, *args): diff --git a/tests/test_functors.py b/tests/test_functors.py index 88a074f..b82d453 100644 --- a/tests/test_functors.py +++ b/tests/test_functors.py @@ -19,3 +19,15 @@ def test_associativity(monad) -> None: g: Callable[[int], str] = lambda x: str(x) m: Functor = monad.pure(3) assert m.map(lambda x: g(f(x))) == m.map(f).map(g) + + +def test_map_mul_operator(monad) -> None: + m: Functor = monad.pure(3) + identity: Callable[[T], T] = lambda x: x + assert m.map(identity) == m * identity + + +def test_map_rmul_operator(monad) -> None: + m: Functor = monad.pure(3) + identity: Callable[[T], T] = lambda x: x + assert m.map(identity) == identity * m diff --git a/tests/test_monads.py b/tests/test_monads.py index 16036f6..2bea71f 100644 --- a/tests/test_monads.py +++ b/tests/test_monads.py @@ -1,3 +1,5 @@ +from typing import Callable + import pytest # type: ignore from monads import Monad @@ -9,6 +11,12 @@ def test_bind(monad) -> None: assert expected == monad.pure(1).bind(lambda x: monad.pure(x + 1)) +def test_bind_rshift_operator(monad) -> None: + m: Monad[int] = monad.pure(2) + f: Callable[[int], Monad[int]] = lambda x: monad.pure(x + 1) + assert m.bind(f) == m >> f + + def test_left_identity(monad) -> None: n: int = 3 diff --git a/tests/test_monoids.py b/tests/test_monoids.py index 9d217d4..927a11c 100644 --- a/tests/test_monoids.py +++ b/tests/test_monoids.py @@ -28,6 +28,12 @@ def construct(constructor: Constructor, value: Any) -> Monoid: return cls(builder(value)) +def test_mappend_add_operator(constructor: Constructor) -> None: + a: Monoid = construct(constructor, 1) + b: Monoid = construct(constructor, 2) + assert a.mappend(b) == a + b + + def test_associative(constructor: Constructor) -> None: a: Monoid = construct(constructor, 1) b: Monoid = construct(constructor, 2) diff --git a/tests/test_reader.py b/tests/test_reader.py index b9459b2..cc0a5b1 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -21,6 +21,18 @@ def test_functor_associativity() -> None: assert m.map(lambda x: g(f(x)))(0) == m.map(f).map(g)(0) +def test_functor_map_mul_operator() -> None: + m: Reader = Reader.pure(3) + identity: Callable[[T], T] = lambda x: x + assert m.map(identity)(0) == (m * identity)(0) + + +def test_functor_map_rmul_operator() -> None: + m: Reader = Reader.pure(3) + identity: Callable[[T], T] = lambda x: x + assert m.map(identity)(0) == (identity * m)(0) + + def test_applicative_fmap_using_ap() -> None: f: Callable[[int], int] = lambda x: x + 1 m: Reader[int, int] = Reader.pure(3) @@ -33,6 +45,12 @@ def test_monad_bind() -> None: assert expected(0) == m.bind(lambda x: Reader.pure(x + 1))(0) +def test_monad_bind_rshift_operator() -> None: + m: Reader[int, int] = Reader.pure(2) + f: Callable[[int], Reader[int, int]] = lambda x: Reader.pure(x + 1) + assert m.bind(f)(0) == (m >> f)(0) + + def test_monad_left_identity() -> None: n: int = 3