mirror of
https://github.com/correl/typesafe-monads.git
synced 2024-11-24 11:09:58 +00:00
Adjust signatures of curried functions
This commit is contained in:
parent
a45af9b5b0
commit
375e5a47e1
2 changed files with 40 additions and 7 deletions
|
@ -1,4 +1,5 @@
|
|||
from functools import reduce, update_wrapper
|
||||
import inspect
|
||||
from typing import Callable, Generic, NewType, TypeVar, overload
|
||||
|
||||
from .reader import Reader
|
||||
|
@ -120,22 +121,31 @@ def curry(
|
|||
|
||||
|
||||
def curry(f):
|
||||
signature = inspect.signature(f)
|
||||
parameters = list(signature.parameters.values())
|
||||
|
||||
def wrapped(args, remaining):
|
||||
if remaining < 1:
|
||||
raise ValueError("Function must take one or more positional arguments")
|
||||
elif remaining == 1:
|
||||
curried = lambda x: f(*(args + [x]))
|
||||
return CurriedUnary(update_wrapper(curried, f))
|
||||
curried = update_wrapper(lambda x: f(*(args + [x])), f)
|
||||
curried.__signature__ = signature.replace(
|
||||
parameters=parameters[-remaining:]
|
||||
)
|
||||
return CurriedUnary(curried)
|
||||
else:
|
||||
curried = lambda x: wrapped(args + [x], remaining - 1)
|
||||
curried = update_wrapper(lambda x: wrapped(args + [x], remaining - 1), f)
|
||||
curried.__signature__ = signature.replace(
|
||||
parameters=parameters[-remaining:]
|
||||
)
|
||||
if remaining == 2:
|
||||
return CurriedBinary(update_wrapper(curried, f))
|
||||
return CurriedBinary(curried)
|
||||
elif remaining == 3:
|
||||
return CurriedTernary(update_wrapper(curried, f))
|
||||
return CurriedTernary(curried)
|
||||
elif remaining == 4:
|
||||
return CurriedQuaternary(update_wrapper(curried, f))
|
||||
return CurriedQuaternary(curried)
|
||||
elif remaining == 5:
|
||||
return CurriedQuinary(update_wrapper(curried, f))
|
||||
return CurriedQuinary(curried)
|
||||
else:
|
||||
raise ValueError("Cannot curry a function with more than 5 arguments")
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import inspect
|
||||
from monads.currying import curry
|
||||
|
||||
|
||||
|
@ -14,3 +15,25 @@ def test_call_curried_function_with_multiple_arguments() -> None:
|
|||
return a + b + c + d + e
|
||||
|
||||
assert add5(1)(2)(3)(4)(5) == add5(1, 2, 3, 4, 5)
|
||||
|
||||
|
||||
def test_curried_function_annotation_matches_original_function() -> None:
|
||||
def add3(a: int, b: int, c: int) -> int:
|
||||
return a + b + c
|
||||
|
||||
assert inspect.signature(add3) == inspect.signature(curry(add3))
|
||||
|
||||
|
||||
def test_curried_function_annotation_drops_arguments_as_it_is_applied() -> None:
|
||||
def add3(a: int, b: int, c: int) -> int:
|
||||
return a + b + c
|
||||
|
||||
assert inspect.Signature(
|
||||
[
|
||||
inspect.Parameter(
|
||||
param, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=int
|
||||
)
|
||||
for param in ["b", "c"]
|
||||
],
|
||||
return_annotation=int,
|
||||
) == inspect.signature(curry(add3)(1))
|
||||
|
|
Loading…
Reference in a new issue