from __future__ import annotations from functools import reduce from itertools import chain from monads import functor, List from typing import ( Any, Callable, Iterable, Iterator, Set as _Set, List as _List, Optional, TypeVar, Union, cast, ) from .monad import Monad from .monoid import Monoidal from .currying import CurriedBinary, uncurry T = TypeVar("T") S = TypeVar("S") class Set(Monad[T], Monoidal[set]): @classmethod def pure(cls, value: T) -> Set[T]: def unpack(k: T) -> set: s: set = set() if isinstance(k, Iterable): for v in k: s.union(unpack(v)) else: s.add(k) return s return Set(unpack(value)) def bind(self, function: Callable[[T], Set[S]]) -> Set[S]: return reduce(Set.mappend, map(function, self.value), Set.mzero()) def map(self, function: Callable[[T], S]) -> Set[S]: return Set(set(map(function, self.value))) def apply(self, functor: Set[Callable[[T], S]]) -> Set[S]: return Set( set(chain.from_iterable([map(f, self.value) for f in functor.value])) ) @classmethod def mzero(cls) -> Set[T]: return cls(set()) @classmethod def sequence(cls, xs: Iterable[Set[T]]) -> Set[_List[T]]: """Evaluate monadic actions in sequence, collecting results.""" def mcons(acc: Set[_Set[T]], x: Set[T]) -> Set[_Set[T]]: return acc.bind(lambda acc_: x.map(lambda x_: acc_.union(set([x_])))) empty: Set[_Set[T]] = Set.pure(set()) return Set(set(reduce(mcons, xs, empty))) # type: ignore def flatten(self) -> Set[T]: def flat(acc: Set[T], element: T) -> Set[T]: if element and isinstance(element, Iterable): for k in element: acc = acc.mappend(Set(set([k]))) elif element: acc = acc.mappend(Set(set([element]))) return acc return Set(reduce(flat, self, Set.mzero())) # type: ignore def sort(self, key: Optional[str] = None, reverse: bool = False) -> Set[T]: lst_copy = self.value.copy() lst_copy.sort(key=key, reverse=reverse) # type: ignore return Set(lst_copy) def fold( self, func: Union[Callable[[S, T], S], CurriedBinary[S, T, S]], base_val: S ) -> S: if isinstance(func, CurriedBinary): functor = uncurry(cast(CurriedBinary, func)) else: functor = func return reduce(functor, self.value, base_val) # type: ignore __and__ = lambda other, self: Set.apply(self, other) # type: ignore def mappend(self, other: Set[T]) -> Set[T]: return Set(self.value.union(other.value)) __add__ = mappend __mul__ = __rmul__ = map __rshift__ = bind def __sizeof__(self) -> int: return self.value.__sizeof__() def __len__(self) -> int: return len(set(self.value)) def __iter__(self) -> Iterator[T]: return iter(self.value)