module Flite.LambdaLift (lambdaLift) where
import Flite.Syntax
import Flite.Traversals
import Flite.Descend
import Flite.WriterState
import Control.Monad
lambdaLift :: Prog -> Prog
lambdaLift = concatMap liftDecl
type Lift a = WriterState Decl Int a
liftDecl :: Decl -> [Decl]
liftDecl (Func f args rhs) = Func f args rhs' : ds
where
(_, ds, rhs') = runWS (lift f rhs) 0
lift :: Id -> Exp -> Lift Exp
lift f (Lam [] e) = lift f e
lift f (Lam vs e) =
do let ws = filter (`notElem` vs) (freeVars e)
i <- get
set (i+1)
let f' = f ++ "^" ++ show i
e' <- lift f e
write (Func f' (map Var (ws ++ vs)) e')
return (App (Fun f') (map Var ws))
lift f e = descendM (lift f) e