module DDC.Core.Flow.Transform.Schedule.Kernel
( scheduleKernel
, Error (..)
, Lifting (..))
where
import DDC.Core.Flow.Transform.Schedule.Nest
import DDC.Core.Flow.Transform.Schedule.Error
import DDC.Core.Flow.Transform.Schedule.Lifting
import DDC.Core.Flow.Transform.Schedule.Base
import DDC.Core.Flow.Process
import DDC.Core.Flow.Procedure
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Exp
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Context
-- | Schedule a process into a procedure, using the given lifting
--   configuration.
--
--   The procedure keeps the name of the process. Parameters whose flag is
--   'False' have their types passed through 'lowerSeriesRate'; the same
--   unflagged parameters (with their original types) are handed to
--   'scheduleOperator' as its scalar environment.
--
--   Any error produced by 'scheduleContext' is propagated as 'Left'.
scheduleKernel :: Lifting -> Process -> Either Error Procedure
scheduleKernel lifting
        (Process { processName       = name
                 , processParamFlags = bsParams
                 , processContext    = context })
 = do   nest    <- scheduleContext rateOfContext scheduleOp context

        return  $ Procedure
                { procedureName       = name
                , procedureParamFlags = map lowerParam bsParams
                , procedureNest       = nest }
 where
        -- Lifting factor taken from the lifting configuration.
        factor  = liftingFactor lifting

        -- Rewrite the type of an unflagged parameter with 'lowerSeriesRate',
        -- keeping the original type when that returns 'Nothing'.
        -- Flagged parameters are passed through unchanged.
        -- Only named binders are matched here.
        lowerParam (flag, BName n t)
         | flag         = (flag, BName n t)
         | otherwise    = (flag, BName n (maybe t id (lowerSeriesRate lifting t)))

        -- Binders of the unflagged parameters, with their original types.
        valueParams
                = [ b | (flag, b) <- bsParams, not flag ]

        -- Rate function given to 'scheduleContext': applies 'tDown' with
        -- the lifting factor to the rate and ignores the second argument.
        rateOfContext r _
                = return $ tDown factor r

        -- Operator scheduler given to 'scheduleContext'.
        scheduleOp
                = scheduleOperator lifting valueParams
-- | Schedule a single operator, producing the start, body and end
--   statements that it contributes.
--
--   Handles 'OpId', 'OpSeriesOfArgument', 'OpMap', 'OpFill', 'OpReduce',
--   'OpGather' and 'OpScatter'. Any other operator yields
--   @Left (ErrorUnsupported op)@. Failure to lift the type of a worker
--   parameter yields the error from 'liftTypeOfBindM', and errors from
--   'liftWorker' are propagated.
--
--   NOTE(review): several bindings use irrefutable @let Just x = ...@ and
--   @let UName n = ...@ patterns. These raise a pattern match failure when
--   demanded if the helper returns something else; presumably earlier
--   stages guarantee the expected forms -- confirm against callers.
scheduleOperator
        :: Lifting      -- ^ Lifting configuration.
        -> ScalarEnv    -- ^ Scalar environment, passed on to 'liftWorker'.
        -> FillMap      -- ^ Queried with 'getAcc' when scheduling fills.
        -> Operator     -- ^ Operator to schedule.
        -> Either Error ([StmtStart], [StmtBody], [StmtEnd])

scheduleOperator lifting envScalar fills op

 -- Id -------------------------------------------
 | OpId{}       <- op
 = do   -- Binder for an element of the result series, with lifted type.
        let Just bResult = elemBindOfSeriesBind (opResultSeries op)
                         >>= liftTypeOfBind lifting

        -- Bound for an element of the input series.
        let Just uInput  = elemBoundOfSeriesBound (opInputSeries op)

        return  ( []
                , [ BodyStmt bResult (XVar uInput) ]
                , [] )

 -- SeriesOfArgument -----------------------------
 | OpSeriesOfArgument{} <- op
 = do   let c           = liftingFactor lifting
        let tK          = opInputRate op
        let tA          = opElemType op

        -- Rebuild the result series binder with its type passed through
        -- 'lowerSeriesRate'.
        let BName n t   = opResultSeries op
        let Just t'     = lowerSeriesRate lifting t
        let bS          = BName n t'

        let Just uS     = takeSubstBoundOfBind bS
        let Just tP     = procTypeOfSeriesType (typeOfBind bS)

        -- Binder for an element of the lowered series, with lifted type.
        let Just bResult = elemBindOfSeriesBind bS
                         >>= liftTypeOfBind lifting

        let bodies
                = [ BodyStmt bResult
                        (xNextC c tP tK tA (XVar uS) (XVar (UIx 0))) ]

        return  ( []
                , bodies
                , [] )

 -- Map ------------------------------------------
 | OpMap{}      <- op
 = do   -- Binder for an element of the result series, with lifted type.
        let Just bResultE = elemBindOfSeriesBind (opResultSeries op)
                          >>= liftTypeOfBind lifting

        -- Bounds for elements of each of the input series.
        let Just usInput  = sequence
                          $ map elemBoundOfSeriesBound
                          $ opInputSeriess op

        -- Lift the types of the worker parameters, then the worker body.
        let bsParam     = opWorkerParams op
        bsParam_lifted  <- mapM (liftTypeOfBindM lifting) bsParam
        let envLift     = zip bsParam bsParam_lifted

        xWorker_lifted  <- liftWorker lifting envScalar envLift
                        $  opWorkerBody op

        -- Wrap the lifted worker in one lambda-application per parameter,
        -- pairing each lifted parameter with the matching input element.
        let xBody       = foldl (\x (b, p) -> XApp (XLam b x) p)
                                (xWorker_lifted)
                                [(b, XVar u)
                                        | b <- bsParam_lifted
                                        | u <- usInput ]

        let bodies      = [ BodyStmt bResultE xBody ]

        return ([], bodies, [])

 -- Fill -----------------------------------------
 | OpFill{}     <- op
 = do   let c           = liftingFactor lifting

        -- Bound for an element of the input series.
        let Just uInput = elemBoundOfSeriesBound
                        $ opInputSeries op

        let UName nVec  = opTargetVector op

        -- Write index: when 'getAcc' finds an entry for the target vector,
        -- read the corresponding "count" variable, otherwise use index 0.
        let index
                | Just n <- getAcc fills nVec
                = xRead tNat
                $ XVar $ UName $ NameVarMod n "count"

                | otherwise
                = XVar $ UIx 0

        let bodies      = [ BodyStmt (BNone tUnit)
                                (xWriteVectorC c
                                        (opElemType op)
                                        (XVar $ bufOfVectorName $ opTargetVector op)
                                        index
                                        (XVar $ uInput)) ]

        return ([], bodies, [])

 -- Reduce ---------------------------------------
 | OpReduce{}   <- op
 = do   let c           = liftingFactor lifting
        let tA          = typeOfBind $ opWorkerParamElem op

        let UName nRef  = opTargetRef op

        -- Names derived from the target reference for the zero value
        -- and the vector accumulator.
        let nAccZero    = NameVarMod nRef "zero"
        let bAccZero    = BName nAccZero tA
        let uAccZero    = UName nAccZero

        let nAccVec     = NameVarMod nRef "vec"
        let uAccVec     = UName nAccVec

        -- Bind the zero value, then start the vector accumulator
        -- from a replicated copy of it.
        let starts
                = [ StartStmt bAccZero (opZero op)
                  , StartAcc  nAccVec
                                (tVec c tA)
                                (xvRep c tA (XVar uAccZero)) ]

        -- Bound for an element of the input series.
        let Just uInput = elemBoundOfSeriesBound
                        $ opInputSeries op

        -- Name for the accumulator value read in the body.
        let nAccVal     = NameVarMod nRef "val"
        let uAccVal     = UName nAccVal
        let bAccVal     = BName nAccVal (tVec c tA)

        -- Lift the types of the worker parameters, then the worker body.
        let bsParam     = [ opWorkerParamAcc op, opWorkerParamElem op ]
        bsParam_lifted  <- mapM (liftTypeOfBindM lifting) bsParam
        let envLift     = zip bsParam bsParam_lifted

        xWorker_lifted  <- liftWorker lifting envScalar envLift
                        $  opWorkerBody op

        -- Apply the lifted worker body to an accumulator and an element.
        let xBody_lifted x1 x2
                = XApp (XApp ( XLam (opWorkerParamAcc   op)
                             $ XLam (opWorkerParamElem  op)
                                    (xWorker_lifted))
                             x1)
                       x2

        -- Read the vector accumulator, combine it with the input element
        -- and write the result back.
        let bodies
                = [ BodyAccRead  nAccVec (tVec c tA) bAccVal
                  , BodyAccWrite nAccVec (tVec c tA)
                        (xBody_lifted (XVar uAccVal) (XVar uInput)) ]

        -- Names for the final vector result and for the partial results
        -- produced while folding its components, numbered from 0.
        let nAccResult  = NameVarMod nRef "res"
        let bAccResult  = BName nAccResult (tVec c tA)
        let uAccResult  = UName nAccResult

        let bPart (i :: Int) = BName (NameVarMod nAccResult (show i)) tA
        let uPart (i :: Int) = UName (NameVarMod nAccResult (show i))

        let nAccInit    = NameVarMod nRef "init"

        -- Apply the original (unlifted) worker body to an accumulator
        -- and an element.
        let xBody x1 x2
                = XApp (XApp ( XLam (opWorkerParamAcc   op)
                             $ XLam (opWorkerParamElem  op)
                                    (opWorkerBody op))
                             x1)
                       x2

        let ends
                =  [ EndStmt bAccResult
                        (xRead (tVec c tA) (XVar uAccVec))

                   , EndStmt (BName nAccInit tA)
                        (xRead tA (XVar $ opTargetRef op)) ]

                -- Part 0 combines the value read from the target
                -- reference with component 0 of the vector result.
                ++ [ EndStmt (bPart 0)
                        (xBody  (XVar $ UName nAccInit)
                                (xvProj c 0 tA (XVar uAccResult))) ]

                -- Part i combines the previous part (i - 1) with
                -- component i, for the remaining components 1 .. c - 1.
                ++ [ EndStmt (bPart i)
                        (xBody  (XVar (uPart (i - 1)))
                                (xvProj c i tA (XVar uAccResult)))
                   | i <- [1 .. c - 1] ]

                -- Write the last part (index c - 1) back to the target
                -- reference.
                ++ [ EndStmt (BNone tUnit)
                        (xWrite tA (XVar $ opTargetRef op)
                                   (XVar $ uPart (c - 1))) ]

                -- Bind the result of the operator to unit.
                ++ [ EndStmt (opResultBind op)
                        xUnit ]

        return (starts, bodies, ends)

 -- Gather ---------------------------------------
 | OpGather{}   <- op
 = do   let c           = liftingFactor lifting

        -- Binder for an element of the result series, with lifted type.
        let Just bResultE = elemBindOfSeriesBind (opResultBind op)
                          >>= liftTypeOfBind lifting

        -- Bound for an element of the index series.
        let Just uIndex   = elemBoundOfSeriesBound (opSourceIndices op)

        let bodies      = [ BodyStmt bResultE
                                (xvGather c
                                        (opVectorRate op)
                                        (opElemType op)
                                        (XVar $ opSourceVector op)
                                        (XVar $ uIndex)) ]

        return ([], bodies, [])

 -- Scatter --------------------------------------
 | OpScatter{}  <- op
 = do   let c           = liftingFactor lifting

        -- Bounds for elements of the index and element series.
        let Just uIndex = elemBoundOfSeriesBound (opSourceIndices op)
        let Just uElem  = elemBoundOfSeriesBound (opSourceElems op)

        let bodies      = [ BodyStmt (BNone tUnit)
                                (xvScatter c
                                        (opElemType op)
                                        (XVar $ opTargetVector op)
                                        (XVar $ uIndex) (XVar $ uElem)) ]

        -- Bind the result of the operator to unit.
        let ends        = [ EndStmt (opResultBind op)
                                xUnit ]

        return ([], bodies, ends)

 -- Anything else is not supported.
 | otherwise
 = Left $ ErrorUnsupported op
-- | Lift the type of a binder with 'liftTypeOfBind', reporting failure
--   in the 'Either' monad.
--
--   When 'liftTypeOfBind' returns 'Nothing' the result is
--   'ErrorCannotLiftType' carrying the type of the original binder.
liftTypeOfBindM :: Lifting -> Bind Name -> Either Error (Bind Name)
liftTypeOfBindM lifting b
 = maybe (Left $ ErrorCannotLiftType (typeOfBind b)) Right
 $ liftTypeOfBind lifting b