mirror of
https://github.com/correl/typesafe-monads.git
synced 2025-04-06 01:04:24 -09:00
107 lines
3 KiB
Python
107 lines
3 KiB
Python
from __future__ import annotations
|
|
from functools import reduce
|
|
from itertools import chain
|
|
from monads import functor, List
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Iterable,
|
|
Iterator,
|
|
Set as _Set,
|
|
List as _List,
|
|
Optional,
|
|
TypeVar,
|
|
Union,
|
|
cast,
|
|
)
|
|
|
|
from .monad import Monad
|
|
from .monoid import Monoidal
|
|
from .currying import CurriedBinary, uncurry
|
|
|
|
# Element type held by a Set (e.g. ``Set[T]``).
T = TypeVar("T")
# Element type of the Set produced by map / bind / apply.
S = TypeVar("S")
|
|
|
|
|
|
class Set(Monad[T], Monoidal[set]):
    """A monad backed by a built-in ``set`` stored in ``self.value``.

    Python sets can only hold hashable objects, so every value produced by
    ``map``, ``bind``, ``apply`` or ``sequence`` must be hashable.
    """

    @classmethod
    def pure(cls, value: T) -> Set[T]:
        """Lift ``value`` into a Set.

        Iterables other than ``str``/``bytes`` are flattened recursively, so
        ``Set.pure([1, [2, 3]])`` holds ``{1, 2, 3}``. Strings, bytes and
        non-iterables become a single element.
        """

        def unpack(k: Any) -> set:
            # A one-character string iterates to itself, so recursing into
            # strings would never terminate: treat str/bytes as atoms.
            if isinstance(k, (str, bytes)) or not isinstance(k, Iterable):
                return {k}
            s: set = set()
            for v in k:
                # set.union() returns a new set; update() mutates s in place.
                s.update(unpack(v))
            return s

        return Set(unpack(value))

    def bind(self, function: Callable[[T], Set[S]]) -> Set[S]:
        """Apply ``function`` to every element and union the resulting Sets."""
        return reduce(Set.mappend, map(function, self.value), Set.mzero())

    def map(self, function: Callable[[T], S]) -> Set[S]:
        """Return a new Set holding ``function`` applied to each element."""
        return Set(set(map(function, self.value)))

    def apply(self, functor: Set[Callable[[T], S]]) -> Set[S]:
        """Apply every function in ``functor`` to every element of this Set."""
        return Set({f(v) for f in functor.value for v in self.value})

    @classmethod
    def mzero(cls) -> Set[T]:
        """Return the empty Set (identity element for ``mappend``)."""
        return cls(set())

    @classmethod
    def sequence(cls, xs: Iterable[Set[T]]) -> Set[tuple[T, ...]]:
        """Evaluate monadic actions in sequence, collecting results.

        The result holds every way of choosing one element from each Set in
        ``xs``. Each choice is a tuple ordered like ``xs``; tuples are used
        rather than lists or sets because set elements must be hashable.
        ``Set.sequence([])`` is ``{()}``. An empty Set anywhere in ``xs``
        makes the result empty.
        """

        def mcons(acc: Set[tuple], x: Set[T]) -> Set[tuple]:
            return acc.bind(lambda acc_: x.map(lambda x_: acc_ + (x_,)))

        # Unit of sequencing: exactly one (empty) combination.
        empty: Set[tuple] = Set({()})
        return reduce(mcons, xs, empty)

    def flatten(self) -> Set[T]:
        """Flatten one level of nesting.

        Iterable elements contribute their items, so strings contribute
        their characters and empty iterables contribute nothing. Other
        elements are kept as-is, except ``None``, which is dropped.
        """
        flat: set = set()
        for element in self:
            if isinstance(element, Iterable):
                flat.update(element)
            elif element is not None:
                # Explicit None check so falsy values such as 0 and False
                # are kept rather than silently discarded.
                flat.add(element)
        # Wrap the plain set (not a Set) so ``.value`` stays a built-in set.
        return Set(flat)

    def sort(
        self, key: Optional[Callable[[T], Any]] = None, reverse: bool = False
    ) -> Set[T]:
        """Return a copy of this Set.

        A built-in ``set`` is unordered and has no ``.sort()`` method, so no
        ordering can be stored in the result. ``key`` and ``reverse`` are
        accepted for call compatibility but have no effect.
        """
        return Set(set(self.value))

    def fold(
        self, func: Union[Callable[[S, T], S], CurriedBinary[S, T, S]], base_val: S
    ) -> S:
        """Reduce the elements with ``func``, starting from ``base_val``.

        ``func`` may be a plain two-argument callable or a ``CurriedBinary``,
        which is uncurried first. Elements are visited in set iteration order,
        so ``func`` should not depend on that order.
        """
        if isinstance(func, CurriedBinary):
            functor = uncurry(cast(CurriedBinary, func))
        else:
            functor = func
        return reduce(functor, self.value, base_val)  # type: ignore

    def __and__(self, other: Set[Any]) -> Set[Any]:
        # ``fs & xs`` applies the functions held by the LEFT operand ``fs``
        # to the elements of ``xs``, i.e. ``xs.apply(fs)``.
        return Set.apply(other, self)  # type: ignore

    def mappend(self, other: Set[T]) -> Set[T]:
        """Return the union of this Set and ``other``."""
        return Set(self.value.union(other.value))

    __add__ = mappend
    __mul__ = __rmul__ = map
    __rshift__ = bind

    def __sizeof__(self) -> int:
        return self.value.__sizeof__()

    def __len__(self) -> int:
        # set() guards against a non-set iterable having been passed to the
        # constructor; for a set it is just a copy.
        return len(set(self.value))

    def __iter__(self) -> Iterator[T]:
        return iter(self.value)
|