module DDC.Core.Flow.Transform.Prep
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Prim.TyConPrim
import DDC.Core.Compounds
import DDC.Core.Module
import DDC.Core.Exp
import Control.Monad.State.Strict
import Data.Map                 (Map)
import qualified Data.Map       as Map
import DDC.Type.Env             (TypeEnv)
import qualified DDC.Type.Env   as Env

-- | Prepare a module for lowering.
--   We need all worker functions passed to flow operators to be eta-expanded
--   and for their parameters to have real names.
--
--   Returns the rewritten module together with the final worker table: a map
--   from each recorded worker name to the parameter types recorded for it.
prepModule
        ::  Module a Name 
        -> (Module a Name, Map Name [Type Name])

prepModule mm
 = runState (prepModuleM mm) Map.empty

-- | Monadic worker for 'prepModule': rewrite the module body starting from
--   an empty type environment, then store the result back in the module.
prepModuleM :: Module a Name -> PrepM (Module a Name)
prepModuleM mm
 = liftM setBody (prepX Env.empty (moduleBody mm))
 where  setBody xBody' = mm { moduleBody = xBody' }

-- | Prepare an expression for lowering.
--
--   Do a bottom-up rewrite. On the way up, remember the names of variables
--   that are passed as workers to flow operators, so that 'prepLts' can
--   eta-expand the bindings with those names.
--
--   The type environment records the let-bound names in scope. It is used
--   to decide whether a worker is eta-expanded in its definition (when it is
--   let-bound) or directly at the call site (when it is not).
prepX   :: TypeEnv Name -> Exp a Name -> PrepM (Exp a Name)
prepX tenv xx
 = let down     = prepX tenv
   in  case xx of
        -- MapN
        -- The worker is a let-bound name: record the first n element types
        -- so its binding gets eta-expanded, and leave the call untouched.
        XApp{}
         | Just (XVar _ u, xsArgs)              <- takeXApps xx
         , UPrim (NameOpFlow (OpFlowMap n)) _   <- u
         , _xTR : xsArgs2                       <- xsArgs
         , (xsA, xsArgs3)                       <- splitAt (n + 1) xsArgs2
         , tsA                                  <- [t | XType t <- xsA]
         , XVar _ (UName nWorker) : _           <- xsArgs3
         , Env.member (UName nWorker) tenv
         -> do  addWorkerArgs nWorker (take n tsA)
                return xx

        -- Worker passed to map, but not let-bound.
        -- Eta-expand in-place.
        XApp{}
         | Just (xmap@(XVar _ u), [xR, xTA@(XType tA), xTB@(XType _), f@(XVar a _), xS])
                                                <- takeXApps xx
         , UPrim (NameOpFlow (OpFlowMap 1)) _   <- u
         -> return $ xApps a xmap [xR, xTA, xTB, xEtaExpand a f [tA], xS]

        -- Detect workers passed to folds.
        XApp{}
         | Just (XVar _ u, [_, XType tA, XType tB, XVar _ (UName n), _, _])
                                                <- takeXApps xx
         , UPrim (NameOpFlow OpFlowFold) _      <- u
         -> do  addWorkerArgs n [tA, tB]
                return xx

        -- FoldIndex
        -- As for fold, but the worker takes an extra leading Int parameter.
        XApp{}
         | Just (XVar _ u, [_, XType tA, XType tB, XVar _ (UName n), _, _])
                                                <- takeXApps xx
         , UPrim (NameOpFlow OpFlowFoldIndex) _ <- u
         -> do  addWorkerArgs n [tInt, tA, tB]
                return xx

        -- Detect workers passed to mkSels
        -- Recorded with no parameter types, so 'prepLts' will not
        -- eta-expand the binding.
        XApp{}
         | Just (XVar _ u, [XType _tK1, XType _tA, _, XVar _ (UName n)])
                                                <- takeXApps xx
         , UPrim (NameOpFlow (OpFlowMkSel _)) _ <- u
         -> do  addWorkerArgs n []
                return xx

        -- Bottom-up transform boilerplate.
        XVar{}          -> return xx
        XCon{}          -> return xx
        XLAM  a b x     -> liftM  (XLAM a b)  (down x)
        XLam  a b x     -> liftM  (XLam a b)  (down x)
        XApp  a x1 x2   -> liftM2 (XApp a)    (down x1) (down x2)

        XLet  a lts x   
         -> do  -- Slurp binds from lets, add to tenv
                let tenv' = Env.extends (valwitBindsOfLets lts) tenv

                -- The body is rewritten before the bindings, so that any
                -- workers used in the body are already recorded when
                -- 'prepLts' looks them up.
                x'      <- prepX tenv' x

                -- Use old tenv for the binders
                lts'    <- prepLts tenv a lts
                return  $  XLet a lts' x'

        XCase a x alts  -> liftM2 (XCase a)   (down x)  (mapM (prepAlt tenv) alts)
        XCast a c x     -> liftM  (XCast a c) (down x)
        XType{}         -> return xx
        XWitness{}      -> return xx

-- | Prepare let bindings for lowering.
--
--   The right-hand sides are rewritten with 'prepX'. A non-recursive
--   binding whose name was recorded as a worker with at least one parameter
--   type is additionally eta-expanded over those types.
prepLts :: TypeEnv Name -> a -> Lets a Name -> PrepM (Lets a Name)
prepLts tenv a lts
 = case lts of
        LLet b x
         -> do  -- Rewrite the right-hand side first, then consult the
                -- worker table.
                x'      <- prepX tenv x
                mTys    <- case b of
                                BName n _ -> lookupWorkerArgs n
                                _         -> return Nothing

                return  $ LLet b
                        $ case mTys of
                                Just tsArgs@(_ : _) -> xEtaExpand a x' tsArgs
                                _                   -> x'

        LRec bxs
         -> do  -- Recursive binders are in scope in their own definitions.
                let bs       = map fst bxs
                let tenv'    = Env.extends bs tenv
                xs'     <- mapM (prepX tenv' . snd) bxs
                return  $ LRec (zip bs xs')

        LLetRegions{}   -> return lts
        LWithRegion{}   -> return lts

-- | Prepare a case alternative for lowering: the pattern is kept as-is
--   and the right-hand side is rewritten with 'prepX'.
prepAlt :: TypeEnv Name -> Alt a Name -> PrepM (Alt a Name)
prepAlt tenv (AAlt w x)
 = do   x'      <- prepX tenv x
        return  $ AAlt w x'

-- | Eta-expand an expression: wrap it in one anonymous lambda per given
--   type and apply it to the de Bruijn indices of those binders, outermost
--   binder first.
--
--   NOTE(review): free de Bruijn indices inside the expression are not
--   lifted past the new binders here; confirm callers only pass expressions
--   without them.
xEtaExpand :: a -> Exp a Name -> [Type Name] -> Exp a Name
xEtaExpand a x tys
 = xLams a (map BAnon tys) (xApps a x xsArgs)
 where  arity   = length tys

        -- The outermost binder has the highest index, the innermost has 0.
        xsArgs  = [ XVar a (UIx ix) | ix <- reverse [0 .. arity - 1] ]

-- State ----------------------------------------------------------------------
-- | State of the prep transform: maps the name of each recorded worker
--   function to the parameter types stored for it by 'addWorkerArgs'.
type PrepS      = Map   Name [Type Name]
-- | Monad the prep transform runs in: a state monad over 'PrepS'.
type PrepM      = State PrepS

-- | Record this name as being of a worker function, together with its
--   parameter types. An existing entry for the same name is replaced.
addWorkerArgs   :: Name -> [Type Name] -> PrepM ()
addWorkerArgs name tsParam
 = do   workers <- get
        put     $ Map.insert name tsParam workers

-- | Check whether this name corresponds to a worker function.
--   Yields the recorded parameter types if it does, 'Nothing' otherwise.
lookupWorkerArgs    :: Name -> PrepM (Maybe [Type Name])
lookupWorkerArgs name
 = gets (Map.lookup name)