Test that each monad respects monad laws

This commit is contained in:
Correl Roush 2018-10-11 12:15:52 -04:00
parent de53c59732
commit d38f617ba4
5 changed files with 54 additions and 1 deletions

View file

@ -1,2 +1,4 @@
from .functor import Functor
from .monad import Monad
from .maybe import Maybe, Just, Nothing
from .result import Result, Ok, Err

View file

@ -43,6 +43,9 @@ class Just(Maybe[T]):
def __init__(self, value: T) -> None:
self.value = value
def __eq__(self, other: object):
return isinstance(other, Just) and self.value == other.value
def __repr__(self) -> str:
return f"<Just {self.value}>"
@ -51,6 +54,9 @@ class Nothing(Maybe[T]):
def __init__(self) -> None:
...
def __eq__(self, other: object):
return isinstance(other, Nothing)
def __repr__(self) -> str:
return "<Nothing>"

View file

@ -55,6 +55,9 @@ class Ok(Result[T, E]):
def __init__(self, value: T) -> None:
self.value = value
def __eq__(self, other: object):
return isinstance(other, Ok) and self.value == other.value
def __repr__(self) -> str:
return f"<Ok {self.value}>"
@ -63,6 +66,9 @@ class Err(Result[T, E]):
def __init__(self, err: E) -> None:
self.err = err
def __eq__(self, other: object):
return isinstance(other, Err) and self.err == other.err
def __repr__(self) -> str:
return f"<Err {self.err}>"

View file

@ -1,4 +1,4 @@
from monads.maybe import Just, Nothing, maybe
from monads.maybe import Maybe, Just, Nothing, maybe
def test_maybe_none():

39
tests/test_monads.py Normal file
View file

@ -0,0 +1,39 @@
import pytest # type: ignore
from typing import List, Type
from monads import Monad, Maybe, Result
@pytest.fixture(scope="module", params=[Maybe, Result])
def monad(request) -> Type:
return request.param
def test_bind(monad) -> None:
expected: Monad[int] = monad.unit(2)
assert expected == monad.unit(1).bind(lambda x: monad.unit(x + 1))
def test_left_identity(monad) -> None:
n: int = 3
def f(n: int) -> Monad[int]:
return monad.unit(n * 3)
assert monad.unit(n).bind(f) == f(n)
def test_right_identity(monad) -> None:
m: Monad[int] = monad.unit(3)
assert m == m.bind(lambda x: monad.unit(x))
def test_associativity(monad) -> None:
m: Monad[int] = monad.unit(3)
def f(n: int) -> Monad[int]:
return monad.unit(n * 3)
def g(n: int) -> Monad[int]:
return monad.unit(n + 5)
assert m.bind(f).bind(g) == m.bind(lambda x: f(x).bind(g))