diff --git a/include/monads.lfe b/include/monads.lfe index ffc8913..b8a0aa2 100644 --- a/include/monads.lfe +++ b/include/monads.lfe @@ -4,16 +4,25 @@ (monad:do-transform monad statements))) (defmacro >>= (monad m f) - `(: ,monad >>= ,m ,f)) + (if (: lfe-utils atom? monad) + `(call ',monad '>>= ,m ,f) + `(call ,monad '>>= ,m ,f))) (defmacro >> (monad m1 m2) - `(: ,monad >>= ,m1 (lambda (_) ,m2))) + (let ((f `(lambda (_) ,m2))) + (if (: lfe-utils atom? monad) + `(call ',monad '>>= ,m1 ,f) + `(call ,monad '>>= ,m1 ,f)))) (defmacro return (monad expr) - `(: ,monad return ,expr)) + (if (: lfe-utils atom? monad) + `(call ',monad 'return ,expr) + `(call ,monad 'return ,expr))) (defmacro fail (monad expr) - `(: ,monad fail ,expr)) + (if (: lfe-utils atom? monad) + `(call ',monad 'fail ,expr) + `(call ,monad 'fail ,expr))) (defmacro sequence (monad list) `(: lists foldr diff --git a/src/monad.lfe b/src/monad.lfe index 00e7dd2..291ec8f 100644 --- a/src/monad.lfe +++ b/src/monad.lfe @@ -10,7 +10,7 @@ (defun do-transform ((monad (cons h '())) h) - ((monad (cons (list f '<- m) t)) (list ': monad '>>= + ((monad (cons (list f '<- m) t)) (list '>>= monad m (list 'lambda (list f) (do-transform monad t)))) ((monad (cons h t)) (list '>> monad h (do-transform monad t)))