Adjust signatures of curried functions

This commit is contained in:
Correl Roush 2019-01-08 14:52:19 -05:00
parent a45af9b5b0
commit 375e5a47e1
2 changed files with 40 additions and 7 deletions

View file

@ -1,4 +1,5 @@
from functools import reduce, update_wrapper
import inspect
from typing import Callable, Generic, NewType, TypeVar, overload
from .reader import Reader
@ -120,22 +121,31 @@ def curry(
def curry(f):
signature = inspect.signature(f)
parameters = list(signature.parameters.values())
def wrapped(args, remaining):
if remaining < 1:
raise ValueError("Function must take one or more positional arguments")
elif remaining == 1:
curried = lambda x: f(*(args + [x]))
return CurriedUnary(update_wrapper(curried, f))
curried = update_wrapper(lambda x: f(*(args + [x])), f)
curried.__signature__ = signature.replace(
parameters=parameters[-remaining:]
)
return CurriedUnary(curried)
else:
curried = lambda x: wrapped(args + [x], remaining - 1)
curried = update_wrapper(lambda x: wrapped(args + [x], remaining - 1), f)
curried.__signature__ = signature.replace(
parameters=parameters[-remaining:]
)
if remaining == 2:
return CurriedBinary(update_wrapper(curried, f))
return CurriedBinary(curried)
elif remaining == 3:
return CurriedTernary(update_wrapper(curried, f))
return CurriedTernary(curried)
elif remaining == 4:
return CurriedQuaternary(update_wrapper(curried, f))
return CurriedQuaternary(curried)
elif remaining == 5:
return CurriedQuinary(update_wrapper(curried, f))
return CurriedQuinary(curried)
else:
raise ValueError("Cannot curry a function with more than 5 arguments")

View file

@ -1,3 +1,4 @@
import inspect
from monads.currying import curry
@ -14,3 +15,25 @@ def test_call_curried_function_with_multiple_arguments() -> None:
return a + b + c + d + e
assert add5(1)(2)(3)(4)(5) == add5(1, 2, 3, 4, 5)
def test_curried_function_annotation_matches_original_function() -> None:
def add3(a: int, b: int, c: int) -> int:
return a + b + c
assert inspect.signature(add3) == inspect.signature(curry(add3))
def test_curried_function_annotation_drops_arguments_as_it_is_applied() -> None:
def add3(a: int, b: int, c: int) -> int:
return a + b + c
assert inspect.Signature(
[
inspect.Parameter(
param, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=int
)
for param in ["b", "c"]
],
return_annotation=int,
) == inspect.signature(curry(add3)(1))