module Curry.ExtendedFlat.LiftLetrec(liftLetrecProg) where
import Data.List
import Control.Monad.State
import Data.Maybe
import qualified Data.Map as Map
import qualified Data.Set as Set
import Curry.ExtendedFlat.Type
import Curry.ExtendedFlat.Goodies
import Curry.ExtendedFlat.MonadicGoodies
data LifterState = LifterState { modname :: String,
currentFunc :: String,
globals :: Set.Set QName,
globalCounter :: Map.Map QName Int,
localCounter :: Int,
lifted :: Map.Map QName FuncDecl }
type Bind = (VarIndex, Expr)
type LiftMonad = State LifterState
liftLetrecProg :: Prog -> Prog
liftLetrecProg prog = updProg id id id (++ fdecls) id prog'
where state = LifterState {
modname = progName prog,
currentFunc = "anonymous",
globals = Set.fromList g,
globalCounter = Map.fromList $ zip g (repeat 1),
localCounter = 0,
lifted = Map.empty
}
g = allGlobals prog
(prog', state') = runState (updProgFuncsM run prog) state
fdecls = Map.elems (lifted state')
run fdecl = do
let fname = localName (funcName fdecl)
modify (\st -> st { currentFunc = fname,
localCounter = (maximum . map idxOf . allVarsInFunc) fdecl
})
fdecl' <- updFuncLetsM liftRecursion fdecl
modify (\st -> st {currentFunc = "anonymous"})
return fdecl'
liftRecursion :: [Bind] -> Expr -> LiftMonad Expr
liftRecursion [(b, rhs)] body
| b `elem` fv = do globalcall <- mkLiftedFunction (typeofVar b) b rhs (fv \\ [b])
return (Let [(b, globalcall)] body)
| otherwise = return (Let [(b, rhs)] body)
where fv = fvs rhs
liftRecursion bs body = return (Let bs body)
mkLiftedFunction :: Maybe TypeExpr -> VarIndex -> Expr -> [VarIndex] -> LiftMonad Expr
mkLiftedFunction t v rhs fv
= do name <- newGlobalName t
st <- get
let fcall = (Comb FuncCall name (map Var fv))
let fdecl = Func name (length fv) Private (fromMaybe (TVar 0) t) (Rule fv (Let [(v,fcall)] rhs))
put st { lifted = Map.insert name fdecl (lifted st),
globals = Set.insert name (globals st)
}
return fcall
newGlobalName :: Maybe TypeExpr -> LiftMonad QName
newGlobalName t
= do st <- get
let qn = QName Nothing t (modname st) (currentFunc st)
let counter = Map.findWithDefault 1 qn (globalCounter st)
put st { globalCounter = Map.insert qn (counter + 1) (globalCounter st) }
let qn' = QName Nothing t (modname st) (localName qn ++ "_" ++ show counter)
if qn' `Set.member` globals st
then newGlobalName t
else return qn'
allGlobals :: Prog -> [QName]
allGlobals prog = [n | Func n _ _ _ _ <- fs]
where fs = progFuncs prog