module DDC.Core.Flow.Transform.Schedule.Kernel
        ( scheduleKernel
        , Error   (..)
        , Lifting (..))
where
import DDC.Core.Flow.Transform.Schedule.Nest
import DDC.Core.Flow.Transform.Schedule.Error
import DDC.Core.Flow.Transform.Schedule.Lifting
import DDC.Core.Flow.Transform.Schedule.Base
import DDC.Core.Flow.Process
import DDC.Core.Flow.Procedure
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Exp
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Context


-- | Schedule a process kernel into a procedure.
--
--   A process kernel is a process with the following restrictions:
--    1) All input series have the same rate.
--    2) A kernel accumulates data into sinks, rather than allocating new values.
--    3) A kernel can be scheduled into a single loop.
--
--   The process kernel scheduler can produce code for
--     map, reduce, fill, gather, scatter
--
--   But not
--     fold   -- use reduce instead.
--     create -- use fill instead.
--     pack   -- we don't support SIMD masks.
--
--   Fails with the 'Error' produced by scheduling any operator
--   in the process context.
scheduleKernel :: Lifting -> Process -> Either Error Procedure
scheduleKernel
        lifting
        (Process { processName       = name
                 , processParamFlags = bsParams
                 , processContext    = context })
 = do
        -- Lower the rates of RateVec parameters in the procedure signature.
        -- The scalar environment passed to scheduleOperator is built from the
        -- original, unlowered parameters, so the original RateVec types stay
        -- available there (e.g. in case one is used by a cross or a gather).
        let bsParams_lowered = map lowerParam bsParams
        let bsParamValues    = map snd $ filter (not . fst) bsParams

        -- Each rate in the context is rewritten to `tDown c r`, where c is
        -- the lifting factor.
        let c           = liftingFactor lifting
        let frate r _   = return $ tDown c r
        let fop         = scheduleOperator lifting bsParamValues

        nest    <- scheduleContext frate fop context

        return  $ Procedure
                { procedureName         = name
                , procedureParamFlags   = bsParams_lowered
                , procedureNest         = nest }

 where
        -- Lower the series rate in the type of an unflagged, named parameter.
        -- Flagged parameters, parameters whose type has no rate to lower, and
        -- non-named binders are passed through unchanged (previously a
        -- non-named binder was a pattern-match failure).
        lowerParam (False, BName n t)
         | Just t' <- lowerSeriesRate lifting t
         = (False, BName n t')
        lowerParam param
         = param


-------------------------------------------------------------------------------
-- | Schedule a single series operator into a loop nest.
--
--   The result is the start, body and end statements that the operator
--   contributes to the nest. Fails with 'ErrorCannotLiftType' when a binder
--   cannot be lifted, and with 'ErrorUnsupported' for operators not handled
--   here.
scheduleOperator
        :: Lifting
        -> ScalarEnv
        -> FillMap      -- ^ Map of which operators use which write-to accs
        -> Operator     -- ^ The operator to schedule.
        -> Either Error ([StmtStart], [StmtBody], [StmtEnd])

scheduleOperator lifting envScalar fills op

 -- Id -------------------------------------------
 | OpId{} <- op
 = do   -- Binder for the (lifted) result element, bound to the input element.
        let Just bResultE = elemBindOfSeriesBind (opResultSeries op)
        bResult         <- liftTypeOfBindM lifting bResultE
        let Just uInput = elemBoundOfSeriesBound (opInputSeries op)

        return  ( []
                , [ BodyStmt bResult (XVar uInput) ]
                , [] )

 -- SeriesOfArgument -----------------------------
 | OpSeriesOfArgument{} <- op
 = do   let c           = liftingFactor lifting
        let tK          = opInputRate op
        let tA          = opElemType  op

        -- The argument series, retyped at its lowered rate.
        let BName n t   = opResultSeries op
        let Just t'     = lowerSeriesRate lifting t
        let bS          = BName n t'
        let Just uS     = takeSubstBoundOfBind bS
        let Just tP     = procTypeOfSeriesType (typeOfBind bS)

        -- Binder for the (lifted) element taken from the series.
        let Just bResultE = elemBindOfSeriesBind bS
        bResult         <- liftTypeOfBindM lifting bResultE

        -- Body expression that takes the next element from the series.
        let bodies      = [ BodyStmt bResult
                                (xNextC c tP tK tA (XVar uS) (XVar (UIx 0))) ]

        return ([], bodies, [])

 -- Map ------------------------------------------
 | OpMap{} <- op
 = do   -- Bind for the (lifted) result element.
        let Just bResultE_scalar = elemBindOfSeriesBind (opResultSeries op)
        bResultE        <- liftTypeOfBindM lifting bResultE_scalar

        -- Bounds for all the input series.
        let Just usInput = mapM elemBoundOfSeriesBound $ opInputSeriess op

        -- Binders for the worker parameters, along with the lifted versions.
        let bsParam     = opWorkerParams op
        bsParam_lifted  <- mapM (liftTypeOfBindM lifting) bsParam
        let envLift     = zip bsParam bsParam_lifted
        xWorker_lifted  <- liftWorker lifting envScalar envLift $ opWorkerBody op

        -- Apply the lifted worker to the input elements, pairing each lifted
        -- parameter with the corresponding input series element.
        let xBody       = foldl (\x (b, p) -> XApp (XLam b x) p)
                                xWorker_lifted
                                (zip bsParam_lifted (map XVar usInput))

        return ([], [ BodyStmt bResultE xBody ], [])

 -- Fill -----------------------------------------
 | OpFill{} <- op
 = do   let c           = liftingFactor lifting

        -- Bound for input element.
        let Just uInput = elemBoundOfSeriesBound $ opInputSeries op

        -- If the fill map has an acc for this target vector, write at the
        -- value read from its "count" ref; otherwise write at de Bruijn
        -- index 0 (presumably the loop index -- confirm).
        let UName nVec  = opTargetVector op
        let index
             | Just n <- getAcc fills nVec
             = xRead tNat $ XVar $ UName $ NameVarMod n "count"
             | otherwise
             = XVar $ UIx 0

        -- Write to the target vector's buffer.
        let bodies      = [ BodyStmt (BNone tUnit)
                                (xWriteVectorC c (opElemType op)
                                        (XVar $ bufOfVectorName $ opTargetVector op)
                                        index
                                        (XVar uInput)) ]

        return ([], bodies, [])

 -- Reduce ---------------------------------------
 | OpReduce{} <- op
 = do   let c           = liftingFactor lifting
        let tA          = typeOfBind $ opWorkerParamElem op

        -- Evaluate the zero value and initialize the vector accumulator
        -- from it with xvRep.
        let UName nRef  = opTargetRef op
        let nAccZero    = NameVarMod nRef "zero"
        let bAccZero    = BName nAccZero tA
        let uAccZero    = UName nAccZero
        let nAccVec     = NameVarMod nRef "vec"
        let uAccVec     = UName nAccVec

        let starts      = [ StartStmt bAccZero (opZero op)
                          , StartAcc  nAccVec (tVec c tA) (xvRep c tA (XVar uAccZero)) ]

        -- Bound for input element.
        let Just uInput = elemBoundOfSeriesBound $ opInputSeries op

        -- Bound for intermediate accumulator value.
        let nAccVal     = NameVarMod nRef "val"
        let uAccVal     = UName nAccVal
        let bAccVal     = BName nAccVal (tVec c tA)

        -- Lift the worker function and its two parameters.
        bAcc_lifted     <- liftTypeOfBindM lifting (opWorkerParamAcc  op)
        bElem_lifted    <- liftTypeOfBindM lifting (opWorkerParamElem op)
        let envLift     = [ (opWorkerParamAcc  op, bAcc_lifted)
                          , (opWorkerParamElem op, bElem_lifted) ]
        xWorker_lifted  <- liftWorker lifting envScalar envLift $ opWorkerBody op

        -- Read the current accumulator value and update it with the lifted
        -- worker. The lambdas must bind the *lifted* parameters because they
        -- are applied to vector values, matching the OpMap case.
        let xBody_lifted x1 x2
                = XApp (XApp ( XLam bAcc_lifted
                             $ XLam bElem_lifted xWorker_lifted)
                             x1)
                        x2

        let bodies      = [ BodyAccRead  nAccVec (tVec c tA) bAccVal
                          , BodyAccWrite nAccVec (tVec c tA)
                                (xBody_lifted (XVar uAccVal) (XVar uInput)) ]

        -- Read back the vector accumulator and do a final fold over its parts.
        let nAccResult  = NameVarMod nRef "res"
        let bAccResult  = BName nAccResult (tVec c tA)
        let uAccResult  = UName nAccResult
        let bPart (i :: Int) = BName (NameVarMod nAccResult (show i)) tA
        let uPart (i :: Int) = UName (NameVarMod nAccResult (show i))
        let nAccInit    = NameVarMod nRef "init"

        -- The unlifted worker, used to combine the scalar parts.
        let xBody x1 x2
                = XApp (XApp ( XLam (opWorkerParamAcc op)
                             $ XLam (opWorkerParamElem op) (opWorkerBody op))
                             x1)
                        x2

        let ends        = [ EndStmt bAccResult (xRead (tVec c tA) (XVar uAccVec))
                          , EndStmt (BName nAccInit tA) (xRead tA (XVar $ opTargetRef op)) ]
                       ++ [ EndStmt (bPart 0)
                                (xBody (XVar $ UName nAccInit)
                                       (xvProj c 0 tA (XVar uAccResult))) ]
                       ++ [ EndStmt (bPart i)
                                (xBody (XVar (uPart (i - 1)))
                                       (xvProj c i tA (XVar uAccResult)))
                          | i <- [1 .. c - 1] ]
                          -- Write final value to destination.
                       ++ [ EndStmt (BNone tUnit)
                                (xWrite tA (XVar $ opTargetRef op) (XVar $ uPart (c - 1))) ]
                          -- Bind final unit value.
                       ++ [ EndStmt (opResultBind op) xUnit ]

        return (starts, bodies, ends)

 -- Gather ---------------------------------------
 | OpGather{} <- op
 = do   let c           = liftingFactor lifting

        -- Bind for the (lifted) result element.
        let Just bResultE_scalar = elemBindOfSeriesBind (opResultBind op)
        bResultE        <- liftTypeOfBindM lifting bResultE_scalar

        -- Bound of source index.
        let Just uIndex = elemBoundOfSeriesBound (opSourceIndices op)

        -- Read from vector.
        let bodies      = [ BodyStmt bResultE
                                (xvGather c (opVectorRate op) (opElemType op)
                                        (XVar $ opSourceVector op)
                                        (XVar uIndex)) ]

        return ([], bodies, [])

 -- Scatter --------------------------------------
 | OpScatter{} <- op
 = do   let c           = liftingFactor lifting

        -- Bounds of the source indices and source elements.
        let Just uIndex = elemBoundOfSeriesBound (opSourceIndices op)
        let Just uElem  = elemBoundOfSeriesBound (opSourceElems   op)

        -- Write to vector.
        let bodies      = [ BodyStmt (BNone tUnit)
                                (xvScatter c (opElemType op)
                                        (XVar $ opTargetVector op)
                                        (XVar uIndex)
                                        (XVar uElem)) ]

        -- Bind final unit value.
        let ends        = [ EndStmt (opResultBind op) xUnit ]

        return ([], bodies, ends)

 -- Unsupported ----------------------------------
 | otherwise
 = Left $ ErrorUnsupported op


-- | Lift the type of a binder, failing with 'ErrorCannotLiftType'
--   (carrying the binder's original type) when it cannot be lifted.
liftTypeOfBindM :: Lifting -> Bind Name -> Either Error (Bind Name)
liftTypeOfBindM lifting b
 = case liftTypeOfBind lifting b of
        Just b' -> return b'
        _       -> Left $ ErrorCannotLiftType (typeOfBind b)