module DDC.Core.Flow.Transform.Prep
        (prepModule)
where
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Prim.TyConPrim
import DDC.Core.Compounds
import DDC.Core.Module
import DDC.Core.Exp
import Control.Monad.State.Strict
import Data.Map                 (Map)
import qualified Data.Map       as Map
import DDC.Type.Env             (TypeEnv)
import qualified DDC.Type.Env   as Env


-- | Prepare a module for lowering.
--
--   Worker functions passed to flow operators are eta-expanded so that
--   their parameters are explicit. The parameters introduced by this pass
--   are anonymous (de Bruijn) binders.
--   NOTE(review): an earlier version of this comment said the parameters
--   should "have real names"; presumably a later pass names them — confirm.
--
--   Also returns the final pass state: a map from each detected worker name
--   to the parameter types recorded for it.
prepModule 
        ::  Module a Name 
        -> (Module a Name, Map Name [Type Name])

prepModule mm
 = runState (prepModuleM mm) noWorkers
 where  noWorkers = Map.empty


-- | Prepare the body of a module, starting from an empty type environment.
prepModuleM :: Module a Name -> PrepM (Module a Name)
prepModuleM mm
 = liftM withBody (prepX Env.empty (moduleBody mm))
 where  withBody xBody' = mm { moduleBody = xBody' }


-- | Prepare an expression for lowering.
--
--   This is a bottom-up rewrite. On the way up we remember the names of
--   variables that are passed as workers to flow operators, together with
--   the parameter types they should take; 'prepLts' then eta-expands the
--   bindings with those names.
--
--   The type environment records the value and witness binders of the
--   enclosing lets. It is used to decide whether a map worker is expanded at
--   its definition (let-bound) or in place at the call site (not let-bound).
prepX   :: TypeEnv Name -> Exp a Name -> PrepM (Exp a Name)
prepX tenv xx
 = let down     = prepX tenv
   in  case xx of
        -- MapN with a let-bound worker: record the worker so that its
        -- definition is expanded to take the first n element types.
        XApp{}
         | Just (XVar _ u, xsArgs)              <- takeXApps xx
         , UPrim (NameOpFlow (OpFlowMap n)) _   <- u
         , _xTR : xsArgs2                       <- xsArgs
         , (xsA, xsArgs3)                       <- splitAt (n + 1) xsArgs2
         , tsA                                  <- [t | XType t <- xsA]
         , XVar _ (UName nWorker) : _           <- xsArgs3
         , Env.member (UName nWorker) tenv
         -> do  addWorkerArgs nWorker (take n tsA)
                return xx

        -- Worker passed to map, but not let-bound.
        -- Eta-expand in-place.
        --   The argument list has a fixed shape, so destructure it directly
        --   rather than splicing with take/last.
        XApp{}
         | Just ( xmap@(XVar _ u)
                , [xK, xTA@(XType tA), xTB@(XType _), f@(XVar a _), xS])
                                                <- takeXApps xx
         , UPrim (NameOpFlow (OpFlowMap 1)) _   <- u
         -> do  let f'    = xEtaExpand a f [tA]
                return $ xApps a xmap [xK, xTA, xTB, f', xS]

        -- Detect workers passed to folds.
        XApp{}
         | Just (XVar _ u, [_, XType tA, XType tB, XVar _ (UName n), _, _])
                                               <- takeXApps xx
         , UPrim (NameOpFlow OpFlowFold) _     <- u
         -> do   addWorkerArgs n [tA, tB]
                 return xx

        -- FoldIndex: the worker additionally takes the index as an Int.
        XApp{}
         | Just (XVar _ u, [_, XType tA, XType tB, XVar _ (UName n), _, _])
                                                <- takeXApps xx
         , UPrim (NameOpFlow OpFlowFoldIndex) _ <- u
         -> do   addWorkerArgs n [tInt, tA, tB]
                 return xx

        -- Detect workers passed to mkSels.
        --   Recorded with no parameter types, so 'prepLts' does not expand them.
        XApp{}
         | Just (XVar _ u, [XType _tK1, XType _tA, _, XVar _ (UName n)])
                                                <- takeXApps xx
         , UPrim (NameOpFlow (OpFlowMkSel _)) _ <- u
         -> do  addWorkerArgs n []
                return xx

        -- Bottom-up transform boilerplate.
        XVar{}          -> return xx
        XCon{}          -> return xx
        XLAM  a b x     -> liftM  (XLAM a b)  (down x)
        XLam  a b x     -> liftM  (XLam a b)  (down x)
        XApp  a x1 x2   -> liftM2 (XApp a)    (down x1) (down x2)

        XLet  a lts x   
         -> do  -- Slurp binds from lets, add to tenv
                let tenv' = Env.extends (valwitBindsOfLets lts) tenv

                -- Prepare the body first, so that workers used in it are
                -- recorded before the bindings themselves are prepared.
                x'      <- prepX tenv' x

                -- Use old tenv for the binders
                lts'    <- prepLts tenv a lts
                return  $  XLet a lts' x'

        XCase a x alts  -> liftM2 (XCase a)   (down x)  (mapM (prepAlt tenv) alts)
        XCast a c x     -> liftM  (XCast a c) (down x)
        XType{}         -> return xx
        XWitness{}      -> return xx


-- | Prepare let-bindings for lowering.
--
--   The right-hand side of every value binding is prepared with 'prepX'.
--   A non-recursive binding is eta-expanded when its name was recorded as a
--   worker with a non-empty list of parameter types. Recursive bindings are
--   prepared in an environment extended with their own binders, but are
--   never eta-expanded. Region bindings are returned unchanged.
--
--   The annotation @a@ is attached to the nodes introduced by eta-expansion.
prepLts :: TypeEnv Name -> a -> Lets a Name -> PrepM (Lets a Name)
prepLts tenv a lts
 = case lts of
        LLet b@(BName n _) x
         -> do  x'      <- prepX tenv x
                mArgs   <- lookupWorkerArgs n

                -- Workers recorded with no parameter types are left as-is.
                return  $ case mArgs of
                           Just tsArgs@(_ : _)
                             -> LLet b $ xEtaExpand a x' tsArgs
                           _ -> LLet b x'

        LLet b x
         -> do  x'      <- prepX tenv x
                return  $ LLet b x'

        LRec bxs
         -> do  let (bs, xs) = unzip bxs
                let tenv'    = Env.extends bs tenv
                xs'     <- mapM (prepX tenv') xs
                return  $ LRec $ zip bs xs'

        LLetRegions{}   -> return lts
        LWithRegion{}   -> return lts


-- | Prepare a case alternative for lowering.
--   Only the right-hand side is rewritten. Variables bound by the pattern
--   are not added to the type environment.
prepAlt :: TypeEnv Name -> Alt a Name -> PrepM (Alt a Name)
prepAlt tenv (AAlt pat xBody)
 = do   xBody'  <- prepX tenv xBody
        return  $ AAlt pat xBody'


-- | Eta-expand an expression to take one anonymous parameter per given type.
--
--   @xEtaExpand a x [t1, .., tn]@ builds @\(^:t1) .. (^:tn). x ^(n-1) .. ^0@.
--   The outermost new binder is the furthest away, so it is referred to by
--   the highest de Bruijn index.
--
--   NOTE(review): any free de Bruijn indices already in @x@ are not lifted
--   past the new binders, so they would be captured. This appears safe only
--   if @x@ uses named variables — confirm.
xEtaExpand :: a -> Exp a Name -> [Type Name] -> Exp a Name
xEtaExpand a x tys
 = let  arity   = length tys
        xsArgs  = [ XVar a (UIx ix) | ix <- reverse [0 .. arity - 1] ]
   in   xLams a (map BAnon tys) (xApps a x xsArgs)


-- State ----------------------------------------------------------------------
-- | The preparation monad: a state monad over 'PrepS'.
type PrepM      = State PrepS

-- | Preparation state: maps the name of each detected worker function to
--   the parameter types it should be eta-expanded to take.
type PrepS      = Map   Name [Type Name]


-- | Record that this name refers to a worker function whose eta-expanded
--   form takes parameters of the given types.
--   Any earlier record for the same name is replaced.
addWorkerArgs   :: Name -> [Type Name] -> PrepM ()
addWorkerArgs name tsParam
 = do   workers <- get
        put     $ Map.insert name tsParam workers


-- | Look up the parameter types recorded for a worker function.
--   Returns 'Nothing' if the name was never recorded with 'addWorkerArgs'.
--   A recorded list may be empty.
lookupWorkerArgs    :: Name -> PrepM (Maybe [Type Name])
lookupWorkerArgs name
        = gets (Map.lookup name)