module DDC.Core.Flow.Transform.Extract ( extractModule , extractProcedure) where import DDC.Core.Flow.Compounds import DDC.Core.Flow.Procedure import DDC.Core.Flow.Prim import DDC.Core.Flow.Exp import DDC.Core.Transform.Annotate import DDC.Core.Module -- | Extract a core module from some stream procedures. -- This produces vanilla core code again. extractModule :: ModuleF -> [Procedure] -> ModuleF extractModule orig procs = orig { moduleBody = annotate () $ extractTop procs } -- | Extract a top level binding from a procedure. extractTop :: [Procedure] -> ExpF extractTop procs = XLet (LRec (map extractProcedure procs)) xUnit -- | Extract code for a whole procedure. extractProcedure :: Procedure -> (Bind Name, ExpF) extractProcedure (Procedure n bsParam xsParam nest) = let tBody = foldr tFun tUnit $ map typeOfBind xsParam tQuant = foldr TForall tBody $ bsParam in ( BName n tQuant , xLAMs bsParam $ xLams xsParam $ extractNest nest xUnit ) ------------------------------------------------------------------------------- -- | Extract code for a loop nest. extractNest :: Nest -- ^ Loops to run in sequence. -> ExpF -- ^ Final result of procedure. -> ExpF extractNest nest xResult = xLets (extractLoop nest) xResult ------------------------------------------------------------------------------- -- | Extract code for a possibly nested loop. extractLoop :: Nest -> [LetsF] -- Code in the top-level loop context. extractLoop (NestLoop tRate starts bodys inner ends _result) = let -- Starting statements. lsStart = concatMap extractStmtStart starts -- The loop itself. lLoop = LLet (BNone tUnit) (xApps (XVar (UPrim (NameOpControl OpControlLoop) (typeOpControl OpControlLoop))) [ XType tRate -- loop rate , xBody ]) -- loop body -- The worker passed to the loop# combinator. xBody = XLam (BAnon tNat) -- loop counter. $ xLets (lsBody ++ lsInner) xUnit -- Process the elements. lsBody = concatMap extractStmtBody bodys -- Handle inner contexts. lsInner = extractLoop inner -- Ending statements lsEnd = concatMap extractStmtEnd ends in lsStart ++ [lLoop] ++ lsEnd -- Code in a guard context. extractLoop (NestGuard _tRateOuter tRateInner uFlags stmtsBody nested) = let -- Get the name of a single flag from the series of flags. UName nFlags = uFlags nFlag = NameVarMod nFlags "elem" xFlag = XVar (UName nFlag) -- Get the name of the entry counter. TVar (UName nK) = tRateInner uCounter = UName (NameVarMod nK "count") xBody = xGuard (XVar uCounter) xFlag ( XLam (BAnon tNat) $ xLets (lsBody ++ lsNested) xUnit) -- Statements in the guard context. lsBody = concatMap extractStmtBody stmtsBody -- Nested contexts. lsNested = extractLoop nested in [LLet (BNone tUnit) xBody] -- Code in a segment context. extractLoop (NestSegment _tRateOuter tRateInner uLengths stmtsBody nested) = let -- Get the name of a single segment length from the series of lengths. UName nLengths = uLengths nLength = NameVarMod nLengths "elem" xLength = XVar (UName nLength) -- Get the name of the entry counter. TVar (UName nK) = tRateInner uCounter = UName (NameVarMod nK "count") xBody = xSegment (XVar uCounter) xLength ( XLam (BAnon tNat) -- Index into current segment. $ XLam (BAnon tNat) -- Index into overall result series. $ xLets (lsBody ++ lsNested) xUnit) -- Statements in the segment context. lsBody = concatMap extractStmtBody stmtsBody -- Nested contexts. lsNested = extractLoop nested in [LLet (BNone tUnit) xBody] extractLoop NestEmpty = [] extractLoop (NestList nests) = concatMap extractLoop nests ------------------------------------------------------------------------------- -- | Extract loop starting code. -- This comes before the main loop. extractStmtStart :: StmtStart -> [LetsF] extractStmtStart ss = case ss of -- Evaluate a pure expression. StartStmt b x -> [LLet b x] -- Allocate a new vector. StartVecNew nVec tElem tRate' -> [LLet (BName nVec (tVector tElem)) (xNewVectorR tElem tRate') ] -- Initialise the accumulator for a reduction operation. StartAcc n t x -> [LLet (BName n (tRef t)) (xNew t x)] ------------------------------------------------------------------------------- -- | Extract loop body code. extractStmtBody :: StmtBody -> [LetsF] extractStmtBody sb = case sb of BodyStmt b x -> [ LLet b x ] -- Write to a vector. BodyVecWrite nVec tElem xIx xVal -> [ LLet (BNone tUnit) (xWriteVector tElem (XVar (UName nVec)) xIx xVal)] -- Read from an accumulator. BodyAccRead n t bVar -> [ LLet bVar (xRead t (XVar (UName n))) ] -- Accumulate an element from a stream. BodyAccWrite nAcc tElem xWorker -> [ LLet (BNone tUnit) (xWrite tElem (XVar (UName nAcc)) xWorker)] ------------------------------------------------------------------------------- -- | Extract loop ending code. -- This comes after the main loop. extractStmtEnd :: StmtEnd -> [LetsF] extractStmtEnd se = case se of EndStmt b x -> [LLet b x] -- Read the accumulator of a reduction operation. EndAcc n t nAcc -> [LLet (BName n t) (xRead t (XVar (UName nAcc))) ] -- Truncate a vector down to its final size. EndVecTrunc nVec tElem tRate -> let -- Get the name of the counter. TVar (UName nK) = tRate uCounter = UName (NameVarMod nK "count") xCounter = xRead tNat (XVar uCounter) xVec = XVar (UName nVec) -- Read the counter in a let since it will need to be threaded in [ LLet (BAnon tNat) xCounter , LLet (BNone tUnit) (xTruncVector tElem (XVar (UIx 0)) xVec) ]