module DDC.Core.Flow.Transform.Schedule.Scalar
(scheduleScalar)
where
import DDC.Core.Flow.Transform.Schedule.Nest
import DDC.Core.Flow.Transform.Schedule.Error
import DDC.Core.Flow.Transform.Schedule.Base
import DDC.Core.Flow.Procedure
import DDC.Core.Flow.Process
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Prim.OpStore
import DDC.Core.Flow.Exp
import DDC.Core.Flow.Context
-- | Schedule a whole process into a procedure using the scalar operator
--   lowering.
--
--   The process context is turned into a loop nest by 'scheduleContext',
--   with every operator lowered by 'scheduleOperator'. The resulting
--   procedure keeps the process's name and parameter flags. Any scheduling
--   failure is returned as 'Left'.
scheduleScalar :: Process -> Either Error Procedure
scheduleScalar (Process { processName       = name
                        , processParamFlags = bsParams
                        , processContext    = context })
 = fmap mkProcedure
 $ scheduleContext keepFirst (scheduleOperator context) context
 where
        -- Callback for 'scheduleContext': returns its first argument
        -- unchanged and ignores the second.
        keepFirst r _ = return r

        -- Package the scheduled loop nest as the final procedure.
        mkProcedure nest
         = Procedure
                { procedureName       = name
                , procedureParamFlags = bsParams
                , procedureNest       = nest }
-- | Lower a single operator to the statements it contributes to the
--   scalar loop nest.
--
--   The result is a triple of statement lists:
--
--   * 'StmtStart' statements, placed before the loop;
--   * 'StmtBody' statements, run on every iteration;
--   * 'StmtEnd' statements, placed after the loop.
--
--   Parameters:
--
--   * @_ctx@ is the enclosing context (unused here).
--   * @fills@ is queried with 'getAcc' to find whose @count@ accumulator a
--     fill should write through.
--   * @op@ is the operator to lower.
--
--   Operators not handled below yield @Left (ErrorUnsupported op)@.
--
--   NOTE(review): the @let Just x = ...@, @let UName n = ...@ and
--   @let BName n _ = ...@ bindings are lazy irrefutable patterns. A
--   malformed operator crashes when the value is forced, rather than
--   returning 'Left'. Presumably the process builder guarantees these
--   shapes; confirm.
--
--   NOTE(review): @UIx 0@ is assumed to be the index bound by the
--   enclosing loop. Confirm against 'scheduleContext'.
scheduleOperator
        :: Context
        -> FillMap
        -> Operator
        -> Either Error ([StmtStart], [StmtBody], [StmtEnd])
scheduleOperator _ctx fills op
  -- Identity: the result element is the input element.
  | OpId{} <- op
  = do
      let Just bResult = elemBindOfSeriesBind   (opResultSeries op)
      let Just uInput  = elemBoundOfSeriesBound (opInputSeries op)
      return ( []
             , [ BodyStmt bResult (XVar uInput) ]
             , [] )

  -- Series taken from a rate vector: create the series before the loop,
  -- then pull one element per iteration with 'xNext'.
  | OpSeriesOfRateVec{} <- op
  = do
      let tK           = opInputRate op
      let tA           = opElemType op
      let bS           = opResultSeries op
      let uInput       = opInputRateVec op
      let Just uS      = takeSubstBoundOfBind bS
      let Just tP      = procTypeOfSeriesType (typeOfBind bS)
      let Just bResult = elemBindOfSeriesBind bS
      let starts
            = [ StartStmt bS
                  (xSeriesOfRateVec tP tK tA (XVar uInput)) ]
      let bodies
            = [ BodyStmt bResult
                  (xNext tP tK tA (XVar uS) (XVar (UIx 0))) ]
      return (starts, bodies, [])

  -- Series passed in as an argument: it already exists, so only pull one
  -- element per iteration.
  | OpSeriesOfArgument{} <- op
  = do
      let tK           = opInputRate op
      let tA           = opElemType op
      let bS           = opResultSeries op
      let Just uS      = takeSubstBoundOfBind bS
      let Just tP      = procTypeOfSeriesType (typeOfBind bS)
      let Just bResult = elemBindOfSeriesBind bS
      let bodies
            = [ BodyStmt bResult
                  (xNext tP tK tA (XVar uS) (XVar (UIx 0))) ]
      return ([], bodies, [])

  -- Replicate: evaluate the input expression once into '<result>_val'
  -- before the loop, then yield that value on every iteration.
  | OpRep{} <- op
  = do
      let BName nResult _ = opResultSeries op
      let nVal            = NameVarMod nResult "val"
      let uVal            = UName nVal
      let bVal            = BName nVal (opElemType op)
      let Just bResult    = elemBindOfSeriesBind (opResultSeries op)
      let starts = [ StartStmt bVal (opInputExp op) ]
      let bodies = [ BodyStmt bResult (XVar uVal) ]
      return (starts, bodies, [])

  -- Segmented replicate: the result element is the input element.
  | OpReps{} <- op
  = do
      let Just uInput  = elemBoundOfSeriesBound (opInputSeries op)
      let Just bResult = elemBindOfSeriesBind   (opResultSeries op)
      let bodies = [ BodyStmt bResult (XVar uInput) ]
      return ([], bodies, [])

  -- Indices: the result element is the loop index.
  | OpIndices{} <- op
  = do
      let Just bResult = elemBindOfSeriesBind (opResultSeries op)
      let bodies = [ BodyStmt bResult (XVar (UIx 0)) ]
      return ([], bodies, [])

  -- Fill: write the current element into the target vector.
  | OpFill{} <- op
  = do
      let Just uInput = elemBoundOfSeriesBound (opInputSeries op)
      let UName nVec  = opTargetVector op

      -- Accumulator whose "count" this fill writes through, if any.
      -- Looked up once and shared by both uses below.
      let mAcc        = getAcc fills nVec

      -- Write position: the current "count" when there is an accumulator,
      -- otherwise the loop index.
      let index
            | Just n <- mAcc
            = xRead tNat $ XVar $ UName $ NameVarMod n "count"
            | otherwise
            = XVar $ UIx 0

      let bodies
            = [ BodyVecWrite nVec (opElemType op) index (XVar uInput) ]

      -- Only the vector that owns the accumulator (n == nVec) advances
      -- the count. Presumably this is so that fills sharing a counter
      -- write to the same position; confirm.
      let inc
            | Just n <- mAcc
            , n == nVec
            = [ BodyAccWrite (NameVarMod n "count") tNat (xIncrement index) ]
            | otherwise
            = []

      return ([], bodies ++ inc, [])

  -- Gather: read the source vector at the current index element.
  | OpGather{} <- op
  = do
      let Just bResult = elemBindOfSeriesBind   (opResultBind op)
      let Just uIndex  = elemBoundOfSeriesBound (opSourceIndices op)
      let buf = xBufOfRateVec (opVectorRate op) (opElemType op)
                              (XVar $ opSourceVector op)
      let bodies
            = [ BodyStmt bResult
                  (xReadVector (opElemType op) buf (XVar uIndex)) ]
      return ([], bodies, [])

  -- Scatter: write each element to the target vector at its index, then
  -- bind the result to unit after the loop.
  | OpScatter{} <- op
  = do
      let Just uIndex = elemBoundOfSeriesBound (opSourceIndices op)
      let Just uElem  = elemBoundOfSeriesBound (opSourceElems op)
      let bodies
            = [ BodyStmt (BNone tUnit)
                  (xWriteVector (opElemType op)
                      (XVar $ bufOfVectorName $ opTargetVector op)
                      (XVar uIndex)
                      (XVar uElem)) ]
      let ends = [ EndStmt (opResultBind op) xUnit ]
      return ([], bodies, ends)

  -- Map: apply the worker to the current element of each input series.
  | OpMap{} <- op
  = do
      let Just bResult = elemBindOfSeriesBind (opResultSeries op)
      let Just usInput = mapM elemBoundOfSeriesBound (opInputSeriess op)

      -- Wrap the worker body in one lambda per parameter, each applied
      -- immediately to the matching input element. 'zip' pairs parameters
      -- with inputs and stops at the shorter list.
      let xBody
            = foldl (\x (b, p) -> XApp (XLam b x) p)
                    (opWorkerBody op)
                    (zip (opWorkerParams op) (map XVar usInput))

      return ([], [ BodyStmt bResult xBody ], [])

  -- Pack: the result element is the input element.
  | OpPack{} <- op
  = do
      let Just uInput  = elemBoundOfSeriesBound (opInputSeries op)
      let Just bResult = elemBindOfSeriesBind   (opResultSeries op)
      let bodies = [ BodyStmt bResult (XVar uInput) ]
      return ([], bodies, [])

  -- Generate: apply the worker to the loop index.
  | OpGenerate{} <- op
  = do
      let Just bResult = elemBindOfSeriesBind (opResultSeries op)
      let xBody
            = XApp (XLam (opWorkerParamIndex op) (opWorkerBody op))
                   (XVar (UIx 0))
      return ([], [ BodyStmt bResult xBody ], [])

  -- Reduce: fold the input series into the target reference through an
  -- accumulator.
  | OpReduce{} <- op
  = do
      let UName nResult = opTargetRef op
      let nAcc          = NameVarMod nResult "acc"
      let nAccInit      = NameVarMod nResult "init"
      let nAccVal       = NameVarMod nResult "val"
      let nAccRes       = NameVarMod nResult "res"
      let tAcc          = typeOfBind (opWorkerParamAcc op)
      let Just uInput   = elemBoundOfSeriesBound (opInputSeries op)

      -- Before the loop: read the initial value from the target reference
      -- and start the accumulator with it.
      let starts
            = [ StartStmt (BName nAccInit tAcc)
                          (xRead tAcc (XVar $ opTargetRef op))
              , StartAcc nAcc tAcc (XVar (UName nAccInit)) ]

      -- The worker as a two-argument lambda: accumulator, then element.
      let xWorker
            = XLam (opWorkerParamAcc op)
                   (XLam (opWorkerParamElem op) (opWorkerBody op))
      let xBody x1 x2 = XApp (XApp xWorker x1) x2

      -- Each iteration: read the accumulator into '<result>_val', then
      -- write back the worker applied to that value and the current element.
      let bodies
            = [ BodyAccRead  nAcc tAcc (BName nAccVal tAcc)
              , BodyAccWrite nAcc tAcc
                  (xBody (XVar (UName nAccVal)) (XVar uInput)) ]

      -- After the loop: read the final accumulator value into
      -- '<result>_res' and store it back to the target reference.
      let ends
            = [ EndAcc nAccRes tAcc nAcc
              , EndStmt (BNone tUnit)
                  (xWrite tAcc (XVar $ opTargetRef op)
                               (XVar $ UName nAccRes)) ]

      return (starts, bodies, ends)

  | otherwise
  = Left $ ErrorUnsupported op
-- | Build the expression @xx + 1@ at type 'Nat'.
--
--   The primitive arithmetic addition operator is applied to the type
--   argument 'tNat', then to @xx@, then to the natural literal @1@.
xIncrement :: Exp a Name -> Exp a Name
xIncrement xx
 = xApps xAdd [XType tNat, xx, xOne]
 where
        -- The primitive addition operator, annotated with its type.
        xAdd = XVar $ UPrim (NamePrimArith PrimArithAdd)
                            (typePrimArith PrimArithAdd)

        -- The natural-number literal one.
        xOne = XCon (dcNat 1)