-- | Flattening nested let and case expressions.
module DDC.Core.Transform.Flatten
        (flatten)
where
import DDC.Core.Transform.LiftT
import DDC.Core.Transform.TransformUpX
import DDC.Core.Transform.AnonymizeX
import DDC.Core.Transform.LiftX
import DDC.Core.Exp
import DDC.Core.Compounds
import DDC.Type.Predicates
import Data.Functor.Identity


-- | Flatten the binding structure of a thing.
--
--   The single-node rewrite 'flatten1' is applied throughout the thing
--   by 'transformUpX''. It flattens nested let-expressions, as well as
--   let-bindings whose right-hand side is a single-alternative
--   case-expression.
--
flatten :: (Ord n, TransformUpMX Identity c)
        => c a n -> c a n
flatten xx
 = {-# SCC flatten #-}
   transformUpX' flatten1 xx


-- | Flatten the binding structure at the root of a single expression.
--
--   A let-binding ('LLet') whose right-hand side is itself a let, a
--   @letregion@ ('LPrivate') or a single-alternative case has that inner
--   construct floated outwards, so that it encloses the original let.
--   Any other let only has its body flattened, a case-expression has its
--   scrutinee and alternatives flattened, and every other expression is
--   returned unchanged.
--
--   Floating a binder outwards puts the let body in its scope. Named
--   binders are therefore rewritten with 'anonymizeX' before the inner
--   construct is moved, and the let body is lifted across the (by then
--   anonymous) binders with 'liftAcrossX' / 'liftAcrossT'.
flatten1
        :: Ord n
        => Exp a n
        -> Exp a n

-- Let ----------------------------------------------------
-- Flatten Nested Lets.
-- @
--      let b1 = (let b2 = def2 in x2) in
--      x1
--
--  ==> let b2 = def2 in 
--      let b1 = x2   in
--      x1
-- @
-- 
flatten1 (XLet a1 (LLet b1
            inner@(XLet a2 (LLet b2 def2) x2))
               x1)

 -- A named b2 would scope over x1 after the rewrite and could capture
 -- one of its variables, so anonymize the inner let and try again.
 | isBName b2
 = flatten1
        $ XLet a1 (LLet b1 
               (anonymizeX inner))
               x1

 | otherwise
 = let  -- x1 now sits under b2 as well, so lift it across that binder,
        -- starting at the depth introduced by b1.
        x1'       = liftAcrossX [b1] [b2] x1
   in   XLet a2 (LLet b2 def2) 
      $ flatten1        -- x2 may itself expose another nested binding.
      $ XLet a1 (LLet b1 x2) 
             x1'


-- Drag 'letregion' out of the top-level of a binding.
-- @
--    let b1 = letregion b2 in x2 in
--    x1
--
-- => letregion b2 in 
--    let b1 = x2 in
--    x1
-- @
--
-- NOTE: For region allocation this increases the lifetime of the region.
--       Maybe use a follow on transform to reduce the lifetime again.
--
flatten1 (XLet a1 (LLet b1
            inner@(XLet a2 (LPrivate b2 mt bs2) x2))
               x1)

 -- Every named region binder (b2) and witness binder (bs2) has to be
 -- anonymized before it is floated over x1, so the test must be 'any':
 --  * with 'all', a list mixing named and anonymous binders would take
 --    the 'otherwise' branch and float the named ones over x1 unchanged;
 --  * 'all isBName []' is True, so an empty b2 would select this branch
 --    again after anonymization and the rewrite would never terminate.
 -- NOTE(review): relies on 'anonymizeX' also anonymizing the witness
 --               binders of an 'LPrivate' -- confirm.
 | any isBName b2 || any isBName bs2
 = flatten1
        $ XLet a1 (LLet b1
                  (anonymizeX inner))
               x1

 | otherwise
 = let  -- x1 moves under the witness binders (expression level, below
        -- b1) and under the region binders (type level).
        x1'     = liftAcrossT []   b2
                $ liftAcrossX [b1] bs2 x1
   in   XLet a2 (LPrivate b2 mt bs2) 
      $ flatten1
        -- The type annotation on b1 is erased rather than lifted.
        -- NOTE(review): presumably because it was written outside the
        --               scope of the region binders b2 -- confirm.
      $ XLet a1 (LLet (zapX b1) x2) 
             x1'


-- Flatten single-alt case expressions.
-- @
--     let b1 = case x1 of 
--                P -> x2 
--     in x3
--
--  => case x1 of 
--       P -> let b1 = x2 
--            in x3
-- @
--
-- * binding must be strict because we force evaluation of x1.
--
flatten1 (XLet  a1 (LLet b1 
             inner@(XCase a2 x1 [AAlt p x2]))
                   x3)

 -- The binders of the pattern end up scoping over x3, so none of them
 -- may be named.
 | any isBName $ bindsOfPat p
 = flatten1
        $ XLet  a1 (LLet b1
                   (anonymizeX inner))
                   x3

 | otherwise
 = let  -- x3 moves under the binders of the pattern, below b1.
        x3'     = liftAcrossX [b1] (bindsOfPat p) x3
   in   XCase a2 x1 
           [AAlt p ( flatten1 
                   $ XLet a1 (LLet b1 x2)
                             -- The lifted body is also anonymized here.
                             (anonymizeX x3'))]


-- Any other let: none of the clauses above matched its right-hand side,
-- so there is nothing to float out of it. Just flatten the body.
flatten1 (XLet a1 llet1 x1)
 = XLet a1 llet1 (flatten1 x1)


-- Case ---------------------------------------------------
-- Flatten the scrutinee and all the alternatives of a case-expression.
flatten1 (XCase a x1 alts)
 = XCase a (flatten1 x1) 
           [AAlt p (flatten1 x) | AAlt p x <- alts ]

-- Everything else is left as it is.
flatten1 x = x


-- | Apply 'liftAtDepthX' to an expression, taking both of its numeric
--   arguments from lists of binders.
--
--   Only anonymous ('BAnon') binders are counted. The count for
--   @bsLevels@ is passed as the first argument of 'liftAtDepthX' and
--   the count for @bsDepth@ as the second.
liftAcrossX :: Ord n => [Bind n] -> [Bind n] -> Exp a n -> Exp a n
liftAcrossX bsDepth bsLevels x
 = liftAtDepthX (countAnon bsLevels) (countAnon bsDepth) x
 where  -- Number of anonymous binders in a list.
        countAnon bs = length [() | BAnon _ <- bs]


-- | Apply 'liftAtDepthT' to an expression, taking both of its numeric
--   arguments from lists of binders.
--
--   Only anonymous ('BAnon') binders are counted. The count for
--   @bsLevels@ is passed as the first argument of 'liftAtDepthT' and
--   the count for @bsDepth@ as the second.
liftAcrossT :: Ord n => [Bind n] -> [Bind n] -> Exp a n -> Exp a n
liftAcrossT bsDepth bsLevels x
 = liftAtDepthT (countAnon bsLevels) (countAnon bsDepth) x
 where  -- Number of anonymous binders in a list.
        countAnon bs = length [() | BAnon _ <- bs]


-- | Erase the type of a data binder, by replacing the type it carries
--   with @tBot kData@.
zapX :: Bind n -> Bind n
zapX = replaceTypeOfBind erased
 where  -- The type that stands in for the erased annotation.
        erased = tBot kData