mirror of
https://github.com/correl/typesafe-monads.git
synced 2024-11-24 11:09:58 +00:00
Move currying to its own module and give it types
This commit is contained in:
parent
71db6f311e
commit
116c3cc1ac
5 changed files with 192 additions and 33 deletions
176
monads/currying.py
Normal file
176
monads/currying.py
Normal file
|
@ -0,0 +1,176 @@
|
|||
from functools import reduce, update_wrapper
|
||||
from typing import Callable, Generic, NewType, TypeVar, overload
|
||||
|
||||
from .reader import Reader
|
||||
|
||||
A = TypeVar("A")
|
||||
B = TypeVar("B")
|
||||
C = TypeVar("C")
|
||||
D = TypeVar("D")
|
||||
E = TypeVar("E")
|
||||
|
||||
Result = TypeVar("Result")
|
||||
|
||||
|
||||
class CurriedUnary(Reader[A, Result]):
|
||||
...
|
||||
|
||||
|
||||
class CurriedBinary(Reader[A, CurriedUnary[B, Result]]):
|
||||
@overload
|
||||
def __call__(self, environment: A) -> CurriedUnary[B, Result]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __call__(self, environment: A, b: B) -> Result:
|
||||
...
|
||||
|
||||
def __call__(self, *args):
|
||||
return reduce(lambda f, x: f(x), args, self.function)
|
||||
|
||||
|
||||
class CurriedTernary(Reader[A, CurriedBinary[B, C, Result]]):
|
||||
@overload
|
||||
def __call__(self, environment: A) -> CurriedBinary[B, C, Result]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __call__(self, environment: A, b: B) -> CurriedUnary[C, Result]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __call__(self, environment: A, b: B, c: C) -> Result:
|
||||
...
|
||||
|
||||
def __call__(self, *args):
|
||||
return reduce(lambda f, x: f(x), args, self.function)
|
||||
|
||||
|
||||
class CurriedQuaternary(Reader[A, CurriedTernary[B, C, D, Result]]):
|
||||
@overload
|
||||
def __call__(self, environment: A) -> CurriedTernary[B, C, D, Result]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __call__(self, environment: A, b: B) -> CurriedBinary[C, D, Result]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __call__(self, environment: A, b: B, c: C) -> CurriedUnary[D, Result]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __call__(self, environment: A, b: B, c: C, d: D) -> Result:
|
||||
...
|
||||
|
||||
def __call__(self, *args):
|
||||
return reduce(lambda f, x: f(x), args, self.function)
|
||||
|
||||
|
||||
class CurriedQuinary(Reader[A, CurriedQuaternary[B, C, D, E, Result]]):
|
||||
@overload
|
||||
def __call__(self, environment: A) -> CurriedQuaternary[B, C, D, E, Result]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __call__(self, environment: A, b: B) -> CurriedTernary[C, D, E, Result]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __call__(self, environment: A, b: B, c: C) -> CurriedBinary[D, E, Result]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __call__(self, environment: A, b: B, c: C, d: D) -> CurriedUnary[E, Result]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __call__(self, environment: A, b: B, c: C, d: D, e: E) -> Result:
|
||||
...
|
||||
|
||||
def __call__(self, *args):
|
||||
return reduce(lambda f, x: f(x), args, self.function)
|
||||
|
||||
|
||||
@overload
|
||||
def curry(f: Callable[[A], Result]) -> CurriedUnary[A, Result]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def curry(f: Callable[[A, B], Result]) -> CurriedBinary[A, B, Result]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def curry(f: Callable[[A, B, C], Result]) -> CurriedTernary[A, B, C, Result]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def curry(f: Callable[[A, B, C, D], Result]) -> CurriedQuaternary[A, B, C, D, Result]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def curry(
|
||||
f: Callable[[A, B, C, D, E], Result]
|
||||
) -> CurriedQuinary[A, B, C, D, E, Result]:
|
||||
...
|
||||
|
||||
|
||||
def curry(f):
|
||||
def wrapped(args, remaining):
|
||||
if remaining < 1:
|
||||
raise ValueError("Function must take one or more positional arguments")
|
||||
elif remaining == 1:
|
||||
curried = lambda x: f(*(args + [x]))
|
||||
return CurriedUnary(update_wrapper(curried, f))
|
||||
else:
|
||||
curried = lambda x: wrapped(args + [x], remaining - 1)
|
||||
if remaining == 2:
|
||||
return CurriedBinary(update_wrapper(curried, f))
|
||||
elif remaining == 3:
|
||||
return CurriedTernary(update_wrapper(curried, f))
|
||||
elif remaining == 4:
|
||||
return CurriedQuaternary(update_wrapper(curried, f))
|
||||
elif remaining == 5:
|
||||
return CurriedQuinary(update_wrapper(curried, f))
|
||||
else:
|
||||
raise ValueError("Cannot curry a function with more than 5 arguments")
|
||||
|
||||
return wrapped([], f.__code__.co_argcount)
|
||||
|
||||
|
||||
@overload
|
||||
def uncurry(f: CurriedUnary[A, Result]) -> Callable[[A], Result]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def uncurry(f: CurriedBinary[A, B, Result]) -> Callable[[A, B], Result]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def uncurry(f: CurriedTernary[A, B, C, Result]) -> Callable[[A, B, C], Result]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def uncurry(f: CurriedQuaternary[A, B, C, D, Result]) -> Callable[[A, B, C, D], Result]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def uncurry(
|
||||
f: CurriedQuinary[A, B, C, D, E, Result]
|
||||
) -> Callable[[A, B, C, D, E], Result]:
|
||||
...
|
||||
|
||||
|
||||
def uncurry(f):
|
||||
def wrapped(*args):
|
||||
return reduce(lambda _f, x: _f(x), args, f)
|
||||
|
||||
return update_wrapper(wrapped, f)
|
|
@ -56,19 +56,3 @@ class Reader(Monad[T], Generic[Env, T]):
|
|||
__mul__ = __rmul__ = map
|
||||
__rshift__ = bind
|
||||
__and__ = lambda other, self: Reader.apply(self, other)
|
||||
|
||||
|
||||
class Curried(Reader[Env, T]):
|
||||
def __call__(self, *args):
|
||||
return reduce(lambda f, x: f(x), args, self.function)
|
||||
|
||||
|
||||
def curry(f: Callable):
|
||||
def wrapped(args, remaining):
|
||||
if remaining == 0:
|
||||
return f(*args)
|
||||
else:
|
||||
curried = lambda x: wrapped(args + [x], remaining - 1)
|
||||
return Curried(update_wrapper(curried, f))
|
||||
|
||||
return wrapped([], f.__code__.co_argcount)
|
||||
|
|
16
tests/test_currying.py
Normal file
16
tests/test_currying.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
from monads.currying import curry
|
||||
|
||||
|
||||
def test_curry() -> None:
|
||||
def add3(a: int, b: int, c: int) -> int:
|
||||
return a + b + c
|
||||
|
||||
assert add3(1, 2, 3) == curry(add3)(1)(2)(3)
|
||||
|
||||
|
||||
def test_call_curried_function_with_multiple_arguments() -> None:
|
||||
@curry
|
||||
def add5(a: int, b: int, c: int, d: int, e: int) -> int:
|
||||
return a + b + c + d + e
|
||||
|
||||
assert add5(1)(2)(3)(4)(5) == add5(1, 2, 3, 4, 5)
|
|
@ -2,7 +2,6 @@ import pytest # type: ignore
|
|||
from typing import Any, Callable, List, TypeVar
|
||||
|
||||
from monads import Functor, Applicative, Future
|
||||
from monads.reader import curry
|
||||
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
|
|
|
@ -2,7 +2,6 @@ import pytest # type: ignore
|
|||
from typing import Any, Callable, List, TypeVar
|
||||
|
||||
from monads import Functor, Applicative, Reader
|
||||
from monads.reader import curry
|
||||
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
|
@ -87,18 +86,3 @@ def test_monad_associativity() -> None:
|
|||
return Reader.pure(n + 5)
|
||||
|
||||
assert m.bind(f).bind(g)(0) == m.bind(lambda x: f(x).bind(g))(0)
|
||||
|
||||
|
||||
def test_curry() -> None:
|
||||
def add3(a, b, c):
|
||||
return a + b + c
|
||||
|
||||
assert add3(1, 2, 3) == curry(add3)(1)(2)(3)
|
||||
|
||||
|
||||
def test_call_curried_function_with_multiple_arguments() -> None:
|
||||
@curry
|
||||
def add3(a, b, c):
|
||||
return a + b + c
|
||||
|
||||
assert add3(1, 2)(3) == add3(1, 2, 3)
|
||||
|
|
Loading…
Reference in a new issue