module Curry.ExtendedFlat.UnMutual(unMutualProg) where
import Data.Graph
import Data.Maybe
import Data.List
import Control.Monad.State
import Curry.Base.Position(noRef)
import Curry.ExtendedFlat.Type
import Curry.ExtendedFlat.Goodies
import Curry.ExtendedFlat.MonadicGoodies
-- | One let-binding: the bound variable together with its right-hand side.
type Bind = (VarIndex, Expr)

-- | State threaded through the transformation. 'localCounter' is the last
-- variable index handed out; 'newLocalName' increments it to obtain a
-- fresh variable.
newtype UnMutualState = UnMutualState
  { localCounter :: Int
  }

-- | Monad in which the transformation runs.
type UnMutualMonad = State UnMutualState
-- | Remove mutual recursion from the local let-bindings of every function
-- of a program. Each let-expression is handed to 'rmMutualRecursion' via
-- 'updFuncLetsM'.
unMutualProg :: Prog -> Prog
unMutualProg p = evalState (updProgFuncsM unMutualFunc p) (UnMutualState 1000)
  where
    unMutualFunc fdecl = do
      -- Fresh names are numbered above the largest variable index of the
      -- current function ('newLocalName' increments the counter), so they
      -- cannot clash with its existing variables. The 0 seed keeps
      -- 'maximum' total for functions without any variables, where it was
      -- previously applied to an empty list.
      modify (\st -> st { localCounter = maximum (0 : map idxOf (allVarsInFunc fdecl)) })
      updFuncLetsM rmMutualRecursion fdecl
-- | Remove mutual recursion from one let-binding group: @bs@ are the
-- bindings and @body@ the body of the let (this function is passed to
-- 'updFuncLetsM' by 'unMutualProg').
--
-- * If every right-hand side is in weak head normal form (per 'whnf'), or
--   the group has at most one binding, the let is rebuilt unchanged.
-- * Otherwise the bindings are split into strongly connected components
--   of their dependency graph ('depGraph', 'sortSccs'). 'partitionBinds'
--   turns them into nested non-recursive lets, breaking every cycle at a
--   chosen feedback variable. 'mkSingleLet' then combines all feedback
--   variables into a single let-binding.
--
-- NOTE: 'mdo' requires the RecursiveDo extension (its pragma is not visible
-- here). It is needed because the initial @bound@ expression @mkTuple fbs@
-- mentions the feedback list @fbs@ that this very call of 'partitionBinds'
-- computes. That only works as long as 'partitionBinds' (in the lazy State
-- monad) never forces the initial @bound@ expression.
rmMutualRecursion :: [Bind] -> Expr -> UnMutualMonad Expr
rmMutualRecursion bs body
  | allWhnf bs || length bs <= 1
  = return (Let bs body)
  | otherwise
    -- knot-tying: 'fbs' is produced by 'partitionBinds' and also consumed
    -- by its own initial accumulator 'mkTuple fbs'
  = mdo (body', bound, fbs) <- partitionBinds (fvs body) sccs (body, mkTuple fbs, [])
        mkSingleLet body' bound fbs
  where fvsGraph = depGraph bs
        sccs = sortSccs fvsGraph
-- | Build the final let-expression from the result of 'partitionBinds':
-- @e2@/@body@ is the transformed let body, @e1@/@bound@ the expression
-- built around the tuple of feedback variables (as set up by
-- 'rmMutualRecursion'), and the list holds the feedback variables.
--
-- * With exactly one feedback variable @v@ the result is @let v = e1 in e2@.
-- * Otherwise a fresh variable @rec@ is bound to @bound@, and each feedback
--   variable is re-introduced around both @bound@ and @body@ by selecting
--   its component from @rec@ (see 'mkSel' and 'nonrecLet').
mkSingleLet :: Expr -> Expr -> [VarIndex] -> UnMutualMonad Expr
mkSingleLet e2 e1 [v] =
  return (Let [(v, e1)] e2)
mkSingleLet body bound fbs = do
  recname <- newLocalName fbsType
  bound'  <- mkFbSelectors recname bound
  body'   <- mkFbSelectors recname body
  return (Let [(recname, bound')] body')
  where
    -- Tuple type of the fresh variable. It is 'Nothing' if any feedback
    -- variable has no type; the previous 'fromJust' crashed in that case.
    fbsType = fmap (TCons (mkQName tuplecon)) (mapM typeofVar fbs)
    -- Prelude tuple constructor with (length fbs - 1) commas, e.g. "(,,)".
    -- The original code read `length fbs 1`, which does not type-check.
    tuplecon = ("Prelude", "(" ++ replicate (length fbs - 1) ',' ++ ")")
    -- Wrap an expression in one selector binding per feedback variable.
    mkFbSelectors recname b = foldM (mkSelector recname) b fbs
    mkSelector recname b v = nonrecLet v (mkSel (Var recname) v fbs) b
-- | @nonrecLet x e1 e2@ builds a non-recursive binding of @e1@ around @e2@.
-- If @x@ does not occur in @e1@, the result is simply @let x = e1 in e2@.
-- Otherwise a fresh variable (with @x@'s type) is bound to @e1@ instead,
-- and @x@ is replaced by it in @e2@. Occurrences of @x@ inside @e1@ are then
-- not captured by the new let.
nonrecLet :: VarIndex -> Expr -> Expr -> UnMutualMonad Expr
nonrecLet x e1 e2 =
  if x `notElem` allVars e1
    then return (Let [(x, e1)] e2)
    else do
      fresh <- newLocalName (typeofVar x)
      return (Let [(fresh, e1)] (subst x (Var fresh) e2))
-- | Expression for the tuple of the given variables. A single variable is
-- returned as a plain variable reference; any other count becomes a
-- constructor call of the matching Prelude tuple constructor.
mkTuple :: [VarIndex] -> Expr
mkTuple vars = case vars of
  [single] -> Var single
  _        -> Comb ConsCall (mkTupleConstr vars) (map Var vars)
-- | Qualified name of the Prelude tuple constructor with as many components
-- as the given list has elements, e.g. @\"(,,)\"@ for three elements. Only
-- the list's length is used. Because 'replicate' yields nothing for
-- counts <= 0, an empty or singleton list gives @\"()\"@.
--
-- The original code referred to an unbound name @arity1@; the intended
-- comma count is @length - 1@.
mkTupleConstr :: [a] -> QName
mkTupleConstr components =
  curry mkQName "Prelude" ("(" ++ replicate (length components - 1) ',' ++ ")")
-- | @mkSel e v vs@ selects component @v@ from the tuple expression @e@,
-- whose components are named @vs@. It is a rigid case on @e@ with the
-- single branch @(vs) -> v@.
mkSel :: Expr -> VarIndex -> [VarIndex] -> Expr
mkSel e v vs =
  let tuplePattern = Pattern (mkTupleConstr vs) vs
  in  Case noRef Rigid e [Branch tuplePattern (Var v)]
-- | True when the right-hand side of every binding satisfies 'whnf'
-- (vacuously true for an empty group).
allWhnf :: [Bind] -> Bool
allWhnf binds = and [ whnf rhs | (_, rhs) <- binds ]
-- | Node of the dependency graph in the @(node, key, [key])@ shape used by
-- "Data.Graph". It holds the binding itself, its key (the bound variable)
-- and the keys it points to (the free variables of its right-hand side).
type FvsNode = (Bind, VarIndex, [VarIndex])

-- | Dependency graph of a binding group: one node per binding.
depGraph :: [Bind] -> [FvsNode]
depGraph binds = [ (bind, var, fvs rhs) | bind@(var, rhs) <- binds ]
-- | Strongly connected components of the dependency graph, reversing the
-- reverse-topological order returned by 'stronglyConnCompR'. As a result,
-- components that use a variable come before the component defining it.
sortSccs :: [FvsNode] -> [SCC FvsNode]
sortSccs nodes = reverse (stronglyConnCompR nodes)
-- | Turn the strongly connected components of a binding group into nested
-- non-recursive lets (via 'nonrecLet'), threading an accumulator
-- @(body, bound, fbs)@:
--
-- * @bound@ gets every binding wrapped around it, in processing order
--   (bindings processed later end up further out);
-- * @body@ gets only the acyclic bindings whose variable is in @pull@;
-- * @fbs@ collects the feedback variables chosen to break cycles.
--
-- @pull@ starts as the free variables of the let body (see
-- 'rmMutualRecursion'). It grows by the free variables of every binding
-- that is wrapped around @body@.
--
-- NOTE: the initial @bound@ is knot-tied in 'rmMutualRecursion', so it is
-- only ever wrapped here, never inspected.
partitionBinds :: [VarIndex] -> [SCC FvsNode]
               -> (Expr, Expr, [VarIndex])
               -> UnMutualMonad (Expr, Expr, [VarIndex])
-- An empty cyclic component carries no bindings; skip it.
partitionBinds pull (CyclicSCC []:ds) part
  = partitionBinds pull ds part
-- A cycle: choose a feedback node with 'pickFbNode', bind it around 'bound'
-- only, and record its variable as a feedback variable. The remaining nodes
-- of the cycle are re-split into components and processed before 'ds'.
-- The feedback variable reaches 'body' later through 'mkSingleLet', which
-- binds every feedback variable around the body.
partitionBinds pull (CyclicSCC d:ds) (body, bound, fbs)
  = let (b@(v,e), d') = pickFbNode pull d
        sccs = sortSccs d' ++ ds
    in do l <- nonrecLet v e bound
          partitionBinds pull sccs (body, l, fst b:fbs)
-- A non-recursive binding: always bind it around 'bound'. If the body needs
-- it (its variable is in 'pull'), also bind it around 'body' and add its
-- free variables 'r' to 'pull'.
partitionBinds pull (AcyclicSCC ((x,e),_,r):ds) (body, bound, fbs)
  = do l <- nonrecLet x e bound
       (body', pull') <- if x `elem` pull
                           then do l' <- nonrecLet x e body
                                   return (l', r `union` pull)
                           else return (body, pull)
       partitionBinds pull' ds (body', l, fbs)
-- All components processed: return the accumulator.
partitionBinds _pull [] part
  = return part
-- | Choose the node at which a cyclic component is broken, i.e. the node
-- with the greatest 'weight'. Preference goes first to a node whose
-- right-hand side mentions its own variable, then to one with the most
-- free variables defined in the component, then to one in @pull@.
-- Returns the chosen binding and the other nodes of the component.
-- Expects a non-empty list ('maximumBy' fails on @[]@).
pickFbNode :: [VarIndex] -> [FvsNode] -> (Bind, [FvsNode])
pickFbNode pull defs =
  let groupVars = [ x | (_, x, _) <- defs ]
      heaviest  = maximumBy (compare `on` weight pull groupVars) defs
      (chosen, chosenVar, _) = heaviest
      remaining = filter (\(_, x, _) -> x /= chosenVar) defs
  in  (chosen, remaining)
-- | @(combine `on` proj) x y@ projects both arguments with @proj@ and
-- combines the results with @combine@ (same as "Data.Function"'s @on@).
on :: (b -> b -> c) -> (a -> b) -> a -> a -> c
on combine proj lhs rhs = combine (proj lhs) (proj rhs)
-- | Ranking key used by 'pickFbNode' to choose a feedback node. Tuples
-- compare lexicographically, so the components are ranked in order:
--
-- 1. whether the node's right-hand side mentions its own variable,
-- 2. how many of its free variables are defined in the group @defs@,
-- 3. whether its variable is in @pull@.
weight :: [VarIndex] -> [VarIndex] -> FvsNode -> (Bool, Int, Bool)
weight pull defs (_, x, fv) =
  ( x `elem` fv
  , length [ y | y <- fv, any (y ==) defs ]
  , x `elem` pull
  )
-- | Generate a fresh variable with the given (optional) type. The state's
-- counter is incremented and the new value becomes the variable's index.
newLocalName :: Maybe TypeExpr -> UnMutualMonad VarIndex
newLocalName t = do
  modify (\s -> s { localCounter = localCounter s + 1 })
  next <- gets localCounter
  return (VarIndex t next)
-- | @subst v x e@ replaces the free occurrences of variable @v@ in @e@ by
-- the expression @x@.
--
-- Traversal stops at binders that shadow @v@: a 'Free' declaring it, a
-- 'Let' binding it (treated as recursive, so neither its right-hand sides
-- nor its body are touched), and a case branch whose pattern binds it.
-- @x@ is inserted without any renaming, so callers must ensure none of its
-- variables would be captured by a traversed binder.
subst :: VarIndex -> Expr -> Expr -> Expr
subst v x = go
  where
    go expr = case expr of
      Var w
        | v == w    -> x
        | otherwise -> expr
      Lit _ -> expr
      Comb ct name args -> Comb ct name (map go args)
      Free vs inner
        | v `elem` vs -> expr
        | otherwise   -> Free vs (go inner)
      Let binds inner
        | isNothing (lookup v binds) ->
            Let [ (w, go rhs) | (w, rhs) <- binds ] (go inner)
        | otherwise -> expr
      Or l r -> Or (go l) (go r)
      Case p t scrut branches -> Case p t (go scrut) (map goBranch branches)

    goBranch br@(Branch pat rhs)
      | v `elem` trPattern (\_ args -> args) (const []) pat = br
      | otherwise = Branch pat (go rhs)