module Flite.Let(inlineLinearLet, inlineSimpleLet, liftLet) where

import Flite.Syntax
import Flite.Traversals
import Flite.Descend
import Flite.Fresh
import List

mkLet :: [Binding] -> Exp -> Exp
mkLet [] e = e
mkLet bs e = Let bs e

inlineLetWhen :: ([Binding] -> Exp -> Binding -> Bool) -> Prog -> Fresh Prog
inlineLetWhen f p = onExpM freshen p >>= return . onExp inline
  where
    inline (Let bs e) = mkLet (zip vs1 (map inline es1')) (inline e')
      where (vs, es) = unzip bs
            (bs0, bs1) = partition (f bs e) bs
            (vs1, es1) = unzip bs1
            (e':es1') = foldr (\(v, e) -> map (subst e v)) (e:es1) bs0
    inline e = descend inline e

inlineLinearLet :: Prog -> Fresh Prog
inlineLinearLet = inlineLetWhen linear
  where
    linear bs e (v, _) = refs v (e:map snd bs) <= 1
    refs v es = sum (map (varRefs v) es)

inlineSimpleLet :: Prog -> Fresh Prog
inlineSimpleLet = inlineLetWhen simple
  where
    simple _ _ (_, rhs) = simp rhs
    simp (App e []) = simp e
    simp (App e es) = False
    simp (Case e as) = False
    simp _ = True

liftLet :: Prog -> Fresh Prog
liftLet p = do p' <- onExpM freshen p
               return (onExp lift p')
  where
    lift e = mkLet [(v, liftInCase rhs) | (v, rhs) <- binds e]
                   (liftInCase (dropBinds e))

    liftInCase (Case e as) = Case e [(p, lift e) | (p, e) <- as]
    liftInCase e = descend liftInCase e

    dropBinds (Let bs e) = dropBinds e
    dropBinds (Case e as) = Case (dropBinds e) as
    dropBinds e = descend dropBinds e

    binds (Let bs e) = binds e ++ [(v, dropBinds e) | (v, e) <- bs]
                               ++ concatMap (binds . snd) bs
    binds (Case e as) = binds e
    binds e = extract binds e