module DDC.Core.Flow.Transform.Schedule.Kernel
        ( scheduleKernel
        , Error         (..)
        , Lifting       (..))
where
import DDC.Core.Flow.Transform.Schedule.Nest
import DDC.Core.Flow.Transform.Schedule.Error
import DDC.Core.Flow.Transform.Schedule.Lifting
import DDC.Core.Flow.Transform.Schedule.Base
import DDC.Core.Flow.Process
import DDC.Core.Flow.Procedure
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Exp
import DDC.Core.Flow.Prim
import Control.Monad
import Data.Maybe


-- | Schedule a process kernel into a procedure.
--
--   A process kernel is a process with the following restrictions:
--    1) All input series have the same rate.
--    2) A kernel accumulates data into sinks, rather than allocating new values.
--    3) A kernel can be scheduled into a single loop.
--
--   The process kernel scheduler can produce code for
--      map, reduce, fill, gather, scatter
--
--   But not
--      fold    -- use reduce instead.
--      create  -- use fill instead.
--      pack    -- we don't support SIMD masks.
--
--   Returns 'Left' when the rate parameters are missing or inconsistent,
--   when an element type cannot be lifted, or when one of the operators
--   cannot be scheduled.
--
scheduleKernel :: Lifting -> Process -> Either Error Procedure
scheduleKernel
        lifting
        (Process { processName          = name
                 , processParamTypes    = bsParamTypes
                 , processParamValues   = bsParamValues
                 , processOperators     = operators })
 = do
        -- Check the parameter series all have the same rate.
        tK      <- slurpRateOfParamTypes (map typeOfBind bsParamValues)

        -- Check the primary rate variable matches the rates of the series.
        (case bsParamTypes of
          []    -> Left ErrorNoRateParameters

          BName n k : _
           | k == kRate
           , TVar (UName n) == tK
           -> return ()

          _     -> Left ErrorPrimaryRateMismatch)

        -- Lower rates of series parameters.
        -- The value parameters are expected to be named binders; parameters
        -- whose type cannot be lowered keep their original type.
        let bsParamValues_lowered
                = map (\(BName n t)
                        -> let t' = fromMaybe t $ lowerSeriesRate lifting t
                           in  BName n t')
                $ bsParamValues

        -- The value parameters that are series.
        let bsSeries
                = [ b   | b <- bsParamValues
                        , isSeriesType (typeOfBind b) ]

        let c   = liftingFactor lifting

        -- Body statement that takes the next vec of elements from a named
        -- input series. If the element type of the series cannot be
        -- determined or lifted then report this as an error, rather than
        -- failing later with a pattern match exception.
        let nextElemStmt nS tS
             = do tElem
                   <- maybe (Left (ErrorCannotLiftType tS)) return
                            (elemTypeOfSeriesType tS)

                  tElem_lifted
                   <- maybe (Left (ErrorCannotLiftType tElem)) return
                            (liftType lifting tElem)

                  return $ BodyStmt
                                (BName (NameVarMod nS "elem") tElem_lifted)
                                (xNextC c tK tElem
                                        (XVar (UName nS))
                                        (XVar (UIx 0)))

        -- As before, only series bound by named binders get a statement.
        ssBody  <- sequence
                        [ nextElemStmt nS tS
                        | BName nS tS <- bsSeries ]

        -- Create the initial loop nest of the process rate.
        let nest0
                = NestLoop
                { nestRate      = tDown c tK
                , nestStart     = []
                , nestBody      = ssBody
                , nestInner     = NestEmpty
                , nestEnd       = []
                , nestResult    = xUnit }

        -- Schedule each of the operators into the nest in turn.
        nest'   <- foldM (scheduleOperator lifting bsParamValues)
                         nest0 operators

        return  $ Procedure
                { procedureName         = name
                , procedureParamTypes   = bsParamTypes
                , procedureParamValues  = bsParamValues_lowered
                , procedureNest         = nest' }


-------------------------------------------------------------------------------
-- | Schedule a single series operator into a loop nest.
--
--   Returns 'ErrorUnsupported' for operators other than
--   map, fill, reduce, gather and scatter, and 'ErrorCannotLiftType' when
--   a binder of the operator cannot be lifted.
--
--   NOTE: the remaining @let Just ..@ bindings (element bounds of series,
--   insertion into the nest) are still partial, and will raise a pattern
--   match exception if the corresponding helper returns 'Nothing'.
scheduleOperator
        :: Lifting
        -> ScalarEnv
        -> Nest         -- ^ The current loop nest.
        -> Operator     -- ^ The operator to schedule.
        -> Either Error Nest

scheduleOperator lifting envScalar nest op

 -- Map -----------------------------------------
 | OpMap{}      <- op
 = do   let c           = liftingFactor lifting
        let tK          = opInputRate op
        let tK_down     = tDown c tK

        -- Bind for the result element.
        -- Failure to lift its type is reported as an error.
        let Just bResultE_scalar
                = elemBindOfSeriesBind (opResultSeries op)
        bResultE        <- liftTypeOfBindM lifting bResultE_scalar

        -- Bounds for all the input series.
        let Just usInput
                = mapM elemBoundOfSeriesBound
                $ opInputSeriess op

        -- Bounds for the worker parameters, along with the lifted versions.
        let bsParam     = opWorkerParams op
        bsParam_lifted  <- mapM (liftTypeOfBindM lifting) bsParam
        let envLift     = zip bsParam bsParam_lifted

        xWorker_lifted  <- liftWorker lifting envScalar envLift
                        $  opWorkerBody op

        -- Expression to apply the inputs to the worker.
        -- Each lifted parameter is paired with the element bound of the
        -- corresponding input series.
        let xBody
                = foldl (\x (b, p) -> XApp (XLam b x) p)
                        xWorker_lifted
                        [ (b, XVar u)
                        | (b, u) <- zip bsParam_lifted usInput ]

        let Just nest2
                = insertBody nest tK_down
                $ [ BodyStmt bResultE xBody ]

        return nest2

 -- Fill ----------------------------------------
 | OpFill{}     <- op
 = do   let c           = liftingFactor lifting
        let tK          = opInputRate op
        let tK_down     = tDown c tK

        -- Bound for input element.
        let Just uInput = elemBoundOfSeriesBound $ opInputSeries op

        -- Write to target vector.
        let Just nest2
                = insertBody nest tK_down
                $ [ BodyStmt (BNone tUnit)
                        (xWriteVectorC c
                                (opElemType op)
                                (XVar $ opTargetVector op)
                                (XVar $ UIx 0)
                                (XVar $ uInput)) ]

        -- Bind final unit value.
        let Just nest3
                = insertEnds nest2 tK_down
                $ [ EndStmt (opResultBind op) xUnit ]

        return nest3

 -- Reduce --------------------------------------
 | OpReduce{}   <- op
 = do   let c           = liftingFactor lifting
        let tK          = opInputRate op
        let tK_down     = tDown c tK
        let tA          = typeOfBind $ opWorkerParamElem op

        -- Evaluate the zero value and initialize the vector accumulator.
        let UName nRef  = opTargetRef op

        let nAccZero    = NameVarMod nRef "zero"
        let bAccZero    = BName nAccZero tA
        let uAccZero    = UName nAccZero

        let nAccVec     = NameVarMod nRef "vec"
        let uAccVec     = UName nAccVec

        let Just nest2
                = insertStarts nest tK_down
                $ [ StartStmt bAccZero (opZero op)
                  , StartAcc  nAccVec  (tVec c tA)
                                       (xvRep c tA (XVar uAccZero)) ]

        -- Bound for input element.
        let Just uInput = elemBoundOfSeriesBound $ opInputSeries op

        -- Bound for intermediate accumulator value.
        let nAccVal     = NameVarMod nRef "val"
        let uAccVal     = UName nAccVal
        let bAccVal     = BName nAccVal (tVec c tA)

        -- Lift the worker function.
        let bsParam     = [ opWorkerParamAcc op, opWorkerParamElem op ]
        bsParam_lifted  <- mapM (liftTypeOfBindM lifting) bsParam
        let envLift     = zip bsParam bsParam_lifted

        xWorker_lifted  <- liftWorker lifting envScalar envLift
                        $  opWorkerBody op

        -- Read the current accumulator value and update it with the worker.
        -- NOTE(review): the lambdas here use the original (scalar) worker
        -- binders around the lifted body, whereas the Map case uses the
        -- lifted binders -- confirm this is intended.
        let xBody_lifted x1 x2
                = XApp (XApp ( XLam (opWorkerParamAcc op)
                             $ XLam (opWorkerParamElem op)
                                    xWorker_lifted)
                             x1)
                       x2

        let Just nest3
                = insertBody nest2 tK_down
                $ [ BodyAccRead  nAccVec (tVec c tA) bAccVal
                  , BodyAccWrite nAccVec (tVec c tA)
                        (xBody_lifted (XVar uAccVal) (XVar uInput)) ]

        -- Read back the vector accumulator and do a final fold over its
        -- parts, using the original scalar worker.
        let nAccResult  = NameVarMod nRef "res"
        let bAccResult  = BName nAccResult (tVec c tA)
        let uAccResult  = UName nAccResult

        let bPart (i :: Int) = BName (NameVarMod nAccResult (show i)) tA
        let uPart (i :: Int) = UName (NameVarMod nAccResult (show i))

        let nAccInit    = NameVarMod nRef "init"

        let xBody x1 x2
                = XApp (XApp ( XLam (opWorkerParamAcc op)
                             $ XLam (opWorkerParamElem op)
                                    (opWorkerBody op))
                             x1)
                       x2

        -- Part 0 combines the initial value read from the target reference
        -- with lane 0; each later part combines the previous part with
        -- its own lane.
        let Just nest4
                = insertEnds nest3 tK_down
                $  [ EndStmt bAccResult
                        (xRead (tVec c tA) (XVar uAccVec))
                   , EndStmt (BName nAccInit tA)
                        (xRead tA (XVar $ opTargetRef op)) ]

                ++ [ EndStmt (bPart 0)
                        (xBody  (XVar $ UName nAccInit)
                                (xvProj c 0 tA (XVar uAccResult))) ]

                ++ [ EndStmt (bPart i)
                        (xBody  (XVar (uPart (i - 1)))
                                (xvProj c i tA (XVar uAccResult)))
                   | i <- [1 .. c - 1] ]

        -- Write final value to destination.
        let Just nest5
                = insertEnds nest4 tK_down
                $ [ EndStmt (BNone tUnit)
                        (xWrite tA (XVar $ opTargetRef op)
                                   (XVar $ uPart (c - 1))) ]

        -- Bind final unit value.
        let Just nest6
                = insertEnds nest5 tK_down
                $ [ EndStmt (opResultBind op) xUnit ]

        return $ nest6

 -- Gather --------------------------------------
 | OpGather{}   <- op
 = do   let c           = liftingFactor lifting
        let tK          = opInputRate op
        let tK_down     = tDown c tK

        -- Bind for result element.
        -- Failure to lift its type is reported as an error.
        let Just bResultE_scalar
                = elemBindOfSeriesBind (opResultBind op)
        bResultE        <- liftTypeOfBindM lifting bResultE_scalar

        -- Bound of source index.
        let Just uIndex = elemBoundOfSeriesBound (opSourceIndices op)

        -- Read from vector.
        let Just nest2
                = insertBody nest tK_down
                $ [ BodyStmt bResultE
                        (xvGather c
                                (opElemType op)
                                (XVar $ opSourceVector op)
                                (XVar $ uIndex)) ]

        return nest2

 -- Scatter -------------------------------------
 | OpScatter{}  <- op
 = do   let c           = liftingFactor lifting
        let tK          = opInputRate op
        let tK_down     = tDown c tK

        -- Bound of source index.
        let Just uIndex = elemBoundOfSeriesBound (opSourceIndices op)

        -- Bound of source elements.
        let Just uElem  = elemBoundOfSeriesBound (opSourceElems op)

        -- Scatter the source elements into the target vector.
        let Just nest2
                = insertBody nest tK_down
                $ [ BodyStmt (BNone tUnit)
                        (xvScatter c
                                (opElemType op)
                                (XVar $ opTargetVector op)
                                (XVar $ uIndex)
                                (XVar $ uElem)) ]

        -- Bind final unit value.
        let Just nest3
                = insertEnds nest2 tK_down
                $ [ EndStmt (opResultBind op) xUnit ]

        return nest3

 -- Unsupported ---------------------------------
 | otherwise
 = Left $ ErrorUnsupported op


-- | Lift the type of a binder, returning 'ErrorCannotLiftType' with the
--   original type of the binder if this is not possible.
liftTypeOfBindM :: Lifting -> Bind Name -> Either Error (Bind Name)
liftTypeOfBindM lifting b
 = case liftTypeOfBind lifting b of
        Just b' -> return b'
        _       -> Left $ ErrorCannotLiftType (typeOfBind b)