module DDC.Core.Flow.Transform.Extract
( extractModule
, extractProcedure)
where
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Procedure
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Prim.OpStore
import DDC.Core.Flow.Exp
import DDC.Core.Transform.Annotate
import DDC.Core.Module
-- | Produce a copy of @orig@ whose 'moduleBody' is the code extracted
--   from the given procedures. Every other field of @orig@ is kept.
extractModule :: ModuleF -> [Procedure] -> ModuleF
extractModule orig procs
 = orig { moduleBody = xBody }
 where
        -- Code for all procedures, passed through 'annotate' with a
        -- unit annotation before being stored in the module.
        xBody   = annotate () (extractTop procs)
-- | Collect the bindings extracted from every procedure into a single
--   'LRec' group, and use 'xUnit' as the body of the enclosing let.
extractTop :: [Procedure] -> ExpF
extractTop procs
 = XLet lsProcs xUnit
 where
        lsProcs = LRec [ extractProcedure p | p <- procs ]
-- | Extract the binding for a single procedure.
--
--   The result pairs a named binder for the procedure with the
--   expression that implements it.
extractProcedure :: Procedure -> (Bind Name, ExpF)
extractProcedure (Procedure n params nest)
 = (BName n tProc, xProc)
 where
        -- Type of the procedure, built right-to-left over the parameters
        -- and ending in 'tUnit'.
        tProc   = foldr addParam tUnit params

        -- A parameter whose flag is set is quantified with 'TForall';
        -- any other parameter contributes a function argument of the
        -- binder's type.
        addParam (flag, b) tRest
         | flag         = TForall b tRest
         | otherwise    = tFun (typeOfBind b) tRest

        -- Buffer bindings for the parameters whose flag is not set.
        lsBuffers
                = concat [ vecBuffers b | (False, b) <- params ]

        -- Abstract over all parameters, bind the buffers, then run the
        -- code extracted from the loop nest with 'xUnit' as its result.
        xProc   = makeXLamFlags params
                $ xLets lsBuffers
                $ extractNest nest xUnit
-- | Produce the buffer binding for a vector-typed binder.
--
--   A named binder whose type satisfies 'isVectorType', and for which
--   'takePrimTyConApps' yields exactly one type argument, gives a single
--   let-binding of the name @NameVarMod n "buf"@ at type 'tBuffer' of that
--   argument, defined by applying 'xBufOfVector' to the original variable.
--   Every other binder gives no bindings.
vecBuffers
        :: BindF        -- ^ Binder to inspect.
        -> [LetsF]
vecBuffers bind
 = case bind of
        BName n t
         | isVectorType t
         , Just (_, [tElem]) <- takePrimTyConApps t
         -> let
                bBuf    = BName (NameVarMod n "buf") (tBuffer tElem)
                xBuf    = xBufOfVector tElem (XVar (UName n))
            in  [LLet bBuf xBuf]

        _ -> []
-- | Extract code for a loop nest, placing the bindings produced by
--   'extractLoop' in front of the given result expression.
extractNest
        :: Nest         -- ^ Loop nest to extract.
        -> ExpF         -- ^ Expression used as the body of the bindings.
        -> ExpF
extractNest nest xResult
 = xLets lsNest xResult
 where
        lsNest  = extractLoop nest
-- | Extract the let-bindings that implement a loop nest.
--
--   * 'NestEmpty' gives no bindings.
--
--   * 'NestList' concatenates the bindings of each member nest.
--
--   * 'NestLoop' gives the start statements, then one unit-typed binding
--     that applies the 'OpControlLoop' primitive to the rate and a worker
--     function, then the end statements.
--
--   * 'NestGuard' and 'NestSegment' each give one unit-typed binding that
--     applies 'xGuard' \/ 'xSegment' to the @"elem"@ variable derived from
--     the flags \/ lengths bound and a worker function.
extractLoop :: Nest -> [LetsF]
extractLoop nest
 = case nest of
        NestEmpty
         -> []

        NestList nests
         -> concatMap extractLoop nests

        NestLoop tRate starts bodys inner ends
         -> let
                xLoopOp = XVar (UPrim (NameOpControl OpControlLoop)
                                      (typeOpControl OpControlLoop))

                -- The worker binds an anonymous variable of type 'tNat'.
                xWorker = XLam (BAnon tNat) (xInside bodys inner)

                lLoop   = LLet (BNone tUnit)
                               (xApps xLoopOp [XType tRate, xWorker])
            in  concat
                 [ concatMap extractStmtStart starts
                 , [lLoop]
                 , concatMap extractStmtEnd ends ]

        NestGuard _tRateOuter _tRateInner uFlags stmtsBody nested
         -> let
                -- The guard worker's parameter is a 'BNone' of type 'tUnit'.
                xWorker = XLam (BNone tUnit) (xInside stmtsBody nested)
            in  [ LLet (BNone tUnit) (xGuard (xElemOf uFlags) xWorker) ]

        NestSegment _tRateOuter _tRateInner uLengths stmtsBody nested
         -> let
                -- The segment worker binds an anonymous variable of type 'tNat'.
                xWorker = XLam (BAnon tNat) (xInside stmtsBody nested)
            in  [ LLet (BNone tUnit) (xSegment (xElemOf uLengths) xWorker) ]

 where
        -- Body statements of a context followed by the bindings of the
        -- nest inside it, finished with 'xUnit'.
        xInside stmts sub
         = xLets (concatMap extractStmtBody stmts ++ extractLoop sub) xUnit

        -- Variable named by the given bound with the "elem" modifier.
        -- The pattern is matched only when the name is demanded; a bound
        -- that is not a 'UName' fails at that point.
        xElemOf u
         = let  UName n = u
           in   XVar (UName (NameVarMod n "elem"))
-- | Extract the bindings for a statement that runs before a loop.
extractStmtStart :: StmtStart -> [LetsF]

-- A plain statement becomes a let-binding as-is.
extractStmtStart (StartStmt b x)
 = [LLet b x]

-- A new vector is bound under its own name at type 'tVector' of the
-- element type, and built with 'xNewVectorR' from the element type and rate.
extractStmtStart (StartVecNew nVec tElem tRate')
 = [LLet bVec xVec]
 where
        bVec    = BName nVec (tVector tElem)
        xVec    = xNewVectorR tElem tRate'

-- An accumulator is bound at type 'tRef' of its element type and
-- created with 'xNew' from the given initial expression.
extractStmtStart (StartAcc n t x)
 = [LLet (BName n (tRef t)) (xNew t x)]
-- | Extract the bindings for a statement in the body of a loop.
extractStmtBody :: StmtBody -> [LetsF]

-- A plain statement becomes a let-binding as-is.
extractStmtBody (BodyStmt b x)
 = [LLet b x]

-- A vector write goes through the variable named @NameVarMod nVec "buf"@.
-- NOTE(review): in this file that name is only bound by 'vecBuffers' for
-- vector-typed procedure parameters; 'StartVecNew' binds @nVec@ alone.
-- Confirm that every vector written here has a matching buffer binding.
extractStmtBody (BodyVecWrite nVec tElem xIx xVal)
 = [LLet (BNone tUnit) (xWriteVector tElem xBuf xIx xVal)]
 where
        xBuf    = XVar (UName (NameVarMod nVec "buf"))

-- Reading an accumulator binds the result of 'xRead' to the given binder.
extractStmtBody (BodyAccRead n t bVar)
 = [LLet bVar (xRead t (XVar (UName n)))]

-- Writing an accumulator is a unit-typed binding around 'xWrite'.
extractStmtBody (BodyAccWrite nAcc tElem xWorker)
 = [LLet (BNone tUnit) (xWrite tElem (XVar (UName nAcc)) xWorker)]
-- | Extract the bindings for a statement that runs after a loop.
extractStmtEnd :: StmtEnd -> [LetsF]

-- A plain statement becomes a let-binding as-is.
extractStmtEnd (EndStmt b x)
 = [LLet b x]

-- The final value of an accumulator is read with 'xRead' and bound to
-- the given name at the accumulator's element type.
extractStmtEnd (EndAcc n t nAcc)
 = [LLet (BName n t) (xRead t (XVar (UName nAcc)))]

-- Truncating a vector takes two bindings: the counter is first read
-- into an anonymous 'tNat' binder, and 'xTruncVector' is then applied to
-- @UIx 0@ and the vector variable.
-- NOTE(review): @UIx 0@ is expected to refer to the anonymous binder
-- introduced just above it -- confirm against the binder conventions.
extractStmtEnd (EndVecTrunc nVec tElem uCounter)
 = [ LLet (BAnon tNat)
          (xRead tNat (XVar uCounter))
   , LLet (BNone tUnit)
          (xTruncVector tElem (XVar (UIx 0)) (XVar (UName nVec))) ]