diff --git a/monads/__init__.py b/monads/__init__.py index a1f7c9e..e5c7f92 100644 --- a/monads/__init__.py +++ b/monads/__init__.py @@ -2,6 +2,7 @@ from .functor import Functor from .applicative import Applicative from .monad import Monad from .list import List +from .set import Set from .maybe import Maybe, Just, Nothing from .result import Result, Ok, Err from .future import Future diff --git a/monads/set.py b/monads/set.py new file mode 100644 index 0000000..11c3e5a --- /dev/null +++ b/monads/set.py @@ -0,0 +1,99 @@ +from __future__ import annotations +from functools import reduce +from itertools import chain +from monads import functor, List +from typing import ( + Callable, + Iterable, + Iterator, + Set as _Set, + List as _List, + Optional, + TypeVar, + Union, + cast, +) + +from .monad import Monad +from .monoid import Monoidal +from .currying import CurriedBinary, uncurry + +T = TypeVar("T") +S = TypeVar("S") + + +class Set(Monad[T], Monoidal[set]): + @classmethod + def pure(cls, value: T) -> Set[T]: + t = set() + t.add(value) + return Set(t) + + def bind(self, function: Callable[[T], Set[S]]) -> Set[S]: + return reduce(Set.mappend, map(function, self.value), Set.mzero()) + + def map(self, function: Callable[[T], S]) -> Set[S]: + return Set(set(map(function, self.value))) + + def apply(self, functor: Set[Callable[[T], S]]) -> Set[S]: + + return Set( + set(chain.from_iterable([map(f, self.value) for f in functor.value])) + ) + + @classmethod + def mzero(cls) -> Set[T]: + return cls(set()) + + @classmethod + def sequence(cls, xs: Iterable[Set[T]]) -> Set[_List[T]]: + """Evaluate monadic actions in sequence, collecting results.""" + + def mcons(acc: Set[_Set[T]], x: Set[T]) -> Set[_Set[T]]: + return acc.bind(lambda acc_: x.map(lambda x_: acc_.union(set([x_])))) + + empty: Set[_Set[T]] = Set.pure(set()) + return Set(set(reduce(mcons, xs, empty))) # type: ignore + + def flatten(self) -> Set[T]: + def flat(acc: Set[T], element: T) -> Set[T]: + if element and isinstance(element, Iterable): + for k in element: + acc = acc.mappend(Set(set([k]))) + elif element: + acc = acc.mappend(Set(set([element]))) + return acc + + return Set(reduce(flat, self, Set.mzero())) # type: ignore + + def sort(self, key: Optional[str] = None, reverse: bool = False) -> Set[T]: + lst_copy = self.value.copy() + lst_copy.sort(key=key, reverse=reverse) # type: ignore + return Set(lst_copy) + + def fold( + self, func: Union[Callable[[S, T], S], CurriedBinary[S, T, S]], base_val: S + ) -> S: + if isinstance(func, CurriedBinary): + functor = uncurry(cast(CurriedBinary, func)) + else: + functor = func + return reduce(functor, self.value, base_val) # type: ignore + + __and__ = lambda other, self: Set.apply(self, other) # type: ignore + + def mappend(self, other: Set[T]) -> Set[T]: + return Set(self.value.union(other.value)) + + __add__ = mappend + __mul__ = __rmul__ = map + __rshift__ = bind + + def __sizeof__(self) -> int: + return self.value.__sizeof__() + + def __len__(self) -> int: + return len(set(self.value)) + + def __iter__(self) -> Iterator[T]: + return iter(self.value) diff --git a/tests/fixtures.py b/tests/fixtures.py index 79936cf..4caf50c 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,8 +1,8 @@ import pytest # type: ignore from typing import Type -from monads import Maybe, List, Result +from monads import Maybe, List, Result, Set -@pytest.fixture(scope="module", params=[Maybe, List, Result]) +@pytest.fixture(scope="module", params=[Maybe, List, Result, Set]) def monad(request) -> Type: return request.param