Move currying to its own module and give it types

This commit is contained in:
Correl Roush 2019-01-06 03:50:04 -05:00 committed by Correl Roush
parent 71db6f311e
commit 116c3cc1ac
5 changed files with 192 additions and 33 deletions

176
monads/currying.py Normal file
View file

@ -0,0 +1,176 @@
from functools import reduce, update_wrapper
from typing import Callable, Generic, NewType, TypeVar, overload
from .reader import Reader
A = TypeVar("A")
B = TypeVar("B")
C = TypeVar("C")
D = TypeVar("D")
E = TypeVar("E")
Result = TypeVar("Result")
class CurriedUnary(Reader[A, Result]):
...
class CurriedBinary(Reader[A, CurriedUnary[B, Result]]):
@overload
def __call__(self, environment: A) -> CurriedUnary[B, Result]:
...
@overload
def __call__(self, environment: A, b: B) -> Result:
...
def __call__(self, *args):
return reduce(lambda f, x: f(x), args, self.function)
class CurriedTernary(Reader[A, CurriedBinary[B, C, Result]]):
@overload
def __call__(self, environment: A) -> CurriedBinary[B, C, Result]:
...
@overload
def __call__(self, environment: A, b: B) -> CurriedUnary[C, Result]:
...
@overload
def __call__(self, environment: A, b: B, c: C) -> Result:
...
def __call__(self, *args):
return reduce(lambda f, x: f(x), args, self.function)
class CurriedQuaternary(Reader[A, CurriedTernary[B, C, D, Result]]):
@overload
def __call__(self, environment: A) -> CurriedTernary[B, C, D, Result]:
...
@overload
def __call__(self, environment: A, b: B) -> CurriedBinary[C, D, Result]:
...
@overload
def __call__(self, environment: A, b: B, c: C) -> CurriedUnary[D, Result]:
...
@overload
def __call__(self, environment: A, b: B, c: C, d: D) -> Result:
...
def __call__(self, *args):
return reduce(lambda f, x: f(x), args, self.function)
class CurriedQuinary(Reader[A, CurriedQuaternary[B, C, D, E, Result]]):
@overload
def __call__(self, environment: A) -> CurriedQuaternary[B, C, D, E, Result]:
...
@overload
def __call__(self, environment: A, b: B) -> CurriedTernary[C, D, E, Result]:
...
@overload
def __call__(self, environment: A, b: B, c: C) -> CurriedBinary[D, E, Result]:
...
@overload
def __call__(self, environment: A, b: B, c: C, d: D) -> CurriedUnary[E, Result]:
...
@overload
def __call__(self, environment: A, b: B, c: C, d: D, e: E) -> Result:
...
def __call__(self, *args):
return reduce(lambda f, x: f(x), args, self.function)
@overload
def curry(f: Callable[[A], Result]) -> CurriedUnary[A, Result]:
...
@overload
def curry(f: Callable[[A, B], Result]) -> CurriedBinary[A, B, Result]:
...
@overload
def curry(f: Callable[[A, B, C], Result]) -> CurriedTernary[A, B, C, Result]:
...
@overload
def curry(f: Callable[[A, B, C, D], Result]) -> CurriedQuaternary[A, B, C, D, Result]:
...
@overload
def curry(
f: Callable[[A, B, C, D, E], Result]
) -> CurriedQuinary[A, B, C, D, E, Result]:
...
def curry(f):
def wrapped(args, remaining):
if remaining < 1:
raise ValueError("Function must take one or more positional arguments")
elif remaining == 1:
curried = lambda x: f(*(args + [x]))
return CurriedUnary(update_wrapper(curried, f))
else:
curried = lambda x: wrapped(args + [x], remaining - 1)
if remaining == 2:
return CurriedBinary(update_wrapper(curried, f))
elif remaining == 3:
return CurriedTernary(update_wrapper(curried, f))
elif remaining == 4:
return CurriedQuaternary(update_wrapper(curried, f))
elif remaining == 5:
return CurriedQuinary(update_wrapper(curried, f))
else:
raise ValueError("Cannot curry a function with more than 5 arguments")
return wrapped([], f.__code__.co_argcount)
@overload
def uncurry(f: CurriedUnary[A, Result]) -> Callable[[A], Result]:
...
@overload
def uncurry(f: CurriedBinary[A, B, Result]) -> Callable[[A, B], Result]:
...
@overload
def uncurry(f: CurriedTernary[A, B, C, Result]) -> Callable[[A, B, C], Result]:
...
@overload
def uncurry(f: CurriedQuaternary[A, B, C, D, Result]) -> Callable[[A, B, C, D], Result]:
...
@overload
def uncurry(
f: CurriedQuinary[A, B, C, D, E, Result]
) -> Callable[[A, B, C, D, E], Result]:
...
def uncurry(f):
def wrapped(*args):
return reduce(lambda _f, x: _f(x), args, f)
return update_wrapper(wrapped, f)

View file

@ -56,19 +56,3 @@ class Reader(Monad[T], Generic[Env, T]):
__mul__ = __rmul__ = map
__rshift__ = bind
__and__ = lambda other, self: Reader.apply(self, other)
class Curried(Reader[Env, T]):
def __call__(self, *args):
return reduce(lambda f, x: f(x), args, self.function)
def curry(f: Callable):
def wrapped(args, remaining):
if remaining == 0:
return f(*args)
else:
curried = lambda x: wrapped(args + [x], remaining - 1)
return Curried(update_wrapper(curried, f))
return wrapped([], f.__code__.co_argcount)

16
tests/test_currying.py Normal file
View file

@ -0,0 +1,16 @@
from monads.currying import curry
def test_curry() -> None:
def add3(a: int, b: int, c: int) -> int:
return a + b + c
assert add3(1, 2, 3) == curry(add3)(1)(2)(3)
def test_call_curried_function_with_multiple_arguments() -> None:
@curry
def add5(a: int, b: int, c: int, d: int, e: int) -> int:
return a + b + c + d + e
assert add5(1)(2)(3)(4)(5) == add5(1, 2, 3, 4, 5)

View file

@ -2,7 +2,6 @@ import pytest # type: ignore
from typing import Any, Callable, List, TypeVar
from monads import Functor, Applicative, Future
from monads.reader import curry
T = TypeVar("T")
S = TypeVar("S")

View file

@ -2,7 +2,6 @@ import pytest # type: ignore
from typing import Any, Callable, List, TypeVar
from monads import Functor, Applicative, Reader
from monads.reader import curry
T = TypeVar("T")
S = TypeVar("S")
@ -87,18 +86,3 @@ def test_monad_associativity() -> None:
return Reader.pure(n + 5)
assert m.bind(f).bind(g)(0) == m.bind(lambda x: f(x).bind(g))(0)
def test_curry() -> None:
def add3(a, b, c):
return a + b + c
assert add3(1, 2, 3) == curry(add3)(1)(2)(3)
def test_call_curried_function_with_multiple_arguments() -> None:
@curry
def add3(a, b, c):
return a + b + c
assert add3(1, 2)(3) == add3(1, 2, 3)