module DDC.Core.Flow.Transform.Schedule.Scalar
        (scheduleScalar)
where
import DDC.Core.Flow.Transform.Schedule.Nest
import DDC.Core.Flow.Transform.Schedule.Error
import DDC.Core.Flow.Transform.Schedule.Base
import DDC.Core.Flow.Procedure
import DDC.Core.Flow.Process
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Prim.OpStore
import DDC.Core.Flow.Exp
import DDC.Core.Flow.Context


-- | Schedule a process into a procedure, producing scalar code.
--
--   The name and parameter flags of the process are copied across
--   unchanged; the loop nest is built by scheduling every operator
--   in the process context with 'scheduleOperator'.
scheduleScalar :: Process -> Either Error Procedure
scheduleScalar
       (Process { processName           = name
                , processParamFlags     = bsParams
                , processContext        = context })
 = do
        nest    <- scheduleContext
                        (\r _ -> return r)
                        (scheduleOperator context)
                        context

        return  $ Procedure
                { procedureName         = name
                , procedureParamFlags   = bsParams
                , procedureNest         = nest }


-------------------------------------------------------------------------------
-- | Schedule a single series operator into a loop nest.
--
--   The result holds the statements to run before the loop, the statements
--   for the loop body, and the statements to run after the loop.
--
--   Every lookup that can fail (element binders and bounds, substitution
--   bounds, process types, named binders and bounds) is matched in the
--   pattern guards of its alternative. If such a lookup fails the
--   alternative is skipped and we fall through to the final branch, which
--   reports 'ErrorUnsupported' instead of building a result that crashes
--   when it is forced.
scheduleOperator
        :: Context      -- ^ Context of all operators
        -> FillMap      -- ^ Map of which operators use which write-to accs
        -> Operator     -- ^ Operator to schedule.
        -> Either Error ([StmtStart], [StmtBody], [StmtEnd])

scheduleOperator _ctx fills op

 -- Id -------------------------------------------
 | OpId{}       <- op
 -- Binder for the result element, and bound of the input element.
 , Just bResult <- elemBindOfSeriesBind   (opResultSeries op)
 , Just uInput  <- elemBoundOfSeriesBound (opInputSeries  op)
 = return ( []
          , [ BodyStmt bResult (XVar uInput) ]
          , [] )

 -- SeriesOfRateVec ------------------------------
 | OpSeriesOfRateVec{} <- op
 , let bS       = opResultSeries op
 , Just uS      <- takeSubstBoundOfBind bS
 , Just tP      <- procTypeOfSeriesType (typeOfBind bS)
 , Just bResult <- elemBindOfSeriesBind bS
 = do   let tK          = opInputRate op
        let tA          = opElemType op
        let uInput      = opInputRateVec op

        -- Convert the RateVec to a series.
        let starts
                = [ StartStmt bS
                        (xSeriesOfRateVec tP tK tA (XVar uInput)) ]

        -- Take the next element from the series, using the loop index ^0.
        let bodies
                = [ BodyStmt bResult
                        (xNext tP tK tA (XVar uS) (XVar (UIx 0))) ]

        return (starts, bodies, [])

 -- SeriesOfArgument -----------------------------
 | OpSeriesOfArgument{} <- op
 , let bS       = opResultSeries op
 , Just uS      <- takeSubstBoundOfBind bS
 , Just tP      <- procTypeOfSeriesType (typeOfBind bS)
 , Just bResult <- elemBindOfSeriesBind bS
 = do   let tK          = opInputRate op
        let tA          = opElemType op

        -- Take the next element from the series, using the loop index ^0.
        -- Could be different to RateVec above, since could be from other source?
        let bodies
                = [ BodyStmt bResult
                        (xNext tP tK tA (XVar uS) (XVar (UIx 0))) ]

        return ([], bodies, [])

 -- Rep -----------------------------------------
 | OpRep{}              <- op
 , BName nResult _      <- opResultSeries op
 -- Binder for the use of the value in the replicated context.
 , Just bResult         <- elemBindOfSeriesBind (opResultSeries op)
 = do   -- Make a binder for the replicated element.
        let nVal        = NameVarMod nResult "val"
        let uVal        = UName nVal
        let bVal        = BName nVal (opElemType op)

        -- Evaluate the expression to be replicated once,
        -- before the main loop.
        let starts      = [ StartStmt bVal (opInputExp op) ]

        -- Use the expression for each iteration of the loop.
        let bodies      = [ BodyStmt bResult (XVar uVal) ]

        return (starts, bodies, [])

 -- Reps ----------------------------------------
 | OpReps{}     <- op
 -- Bound of the input element.
 , Just uInput  <- elemBoundOfSeriesBound (opInputSeries op)
 -- Binder for the result element.
 , Just bResult <- elemBindOfSeriesBind (opResultSeries op)
 = do   -- Set the result to point to the input element.
        let bodies      = [ BodyStmt bResult (XVar uInput) ]
        return ([], bodies, [])

 -- Indices --------------------------------------
 | OpIndices{}  <- op
 , Just bResult <- elemBindOfSeriesBind (opResultSeries op)
 = do   -- In a segment context the variable ^0 is the index into
        -- the current segment.
        let bodies      = [ BodyStmt bResult (XVar (UIx 0)) ]
        return ([], bodies, [])

 -- Fill -----------------------------------------
 | OpFill{}     <- op
 -- Bound of the input element.
 , Just uInput  <- elemBoundOfSeriesBound (opInputSeries op)
 -- Name of the vector being written to.
 , UName nVec   <- opTargetVector op
 = do   -- When the fill map has an accumulator for this vector, the write
        -- index is read from that accumulator's "count" variable;
        -- otherwise the loop index ^0 is used.
        let index
                | Just n <- getAcc fills nVec
                = xRead tNat $ XVar $ UName $ NameVarMod n "count"

                | otherwise
                = XVar $ UIx 0

        -- Write the current element to the vector.
        let bodies
                = [ BodyVecWrite
                        nVec                    -- destination vector
                        (opElemType op)         -- series elem type
                        index                   -- index
                        (XVar uInput) ]         -- value

        -- The "count" accumulator is only incremented when the accumulator
        -- recorded in the fill map is this operator's own target vector.
        let inc
                | Just n <- getAcc fills nVec
                , n == nVec
                = [ BodyAccWrite (NameVarMod n "count") tNat
                        (xIncrement index) ]

                | otherwise
                = []

        return ([], bodies ++ inc, [])

 -- Gather ---------------------------------------
 | OpGather{}   <- op
 -- Bind for result element.
 , Just bResult <- elemBindOfSeriesBind (opResultBind op)
 -- Bound of source index.
 , Just uIndex  <- elemBoundOfSeriesBound (opSourceIndices op)
 = do   let buf = xBufOfRateVec
                        (opVectorRate op)
                        (opElemType op)
                        (XVar $ opSourceVector op)

        -- Read from the vector.
        let bodies
                = [ BodyStmt bResult
                        (xReadVector (opElemType op) buf (XVar uIndex)) ]

        return ([], bodies, [])

 -- Scatter --------------------------------------
 | OpScatter{}  <- op
 -- Bound of source index.
 , Just uIndex  <- elemBoundOfSeriesBound (opSourceIndices op)
 -- Bound of source elements.
 , Just uElem   <- elemBoundOfSeriesBound (opSourceElems op)
 = do   -- Write the element to the target vector at the given index.
        let bodies
                = [ BodyStmt (BNone tUnit)
                        (xWriteVector
                                (opElemType op)
                                (XVar $ bufOfVectorName $ opTargetVector op)
                                (XVar uIndex)
                                (XVar uElem)) ]

        -- Bind final unit value.
        let ends
                = [ EndStmt (opResultBind op) xUnit ]

        return ([], bodies, ends)

 -- Maps -----------------------------------------
 | OpMap{}      <- op
 -- Bind for the result element.
 , Just bResult <- elemBindOfSeriesBind (opResultSeries op)
 -- Bounds for all the input elements.
 , Just usInput <- mapM elemBoundOfSeriesBound (opInputSeriess op)
 = do   -- Apply input element vars into the worker body.
        -- Each worker parameter is paired with the matching input element.
        let xBody
                = foldl (\x (b, p) -> XApp (XLam b x) p)
                        (opWorkerBody op)
                        (zip (opWorkerParams op) (map XVar usInput))

        let bodies      = [ BodyStmt bResult xBody ]
        return ([], bodies, [])

 -- Pack ----------------------------------------
 | OpPack{}     <- op
 -- Bound of the input element.
 , Just uInput  <- elemBoundOfSeriesBound (opInputSeries op)
 -- Binder for the result element.
 , Just bResult <- elemBindOfSeriesBind (opResultSeries op)
 = do   -- Set the result to point to the input element.
        let bodies      = [ BodyStmt bResult (XVar uInput) ]
        return ([], bodies, [])

 -- Generate -------------------------------------
 | OpGenerate{} <- op
 -- Bind for the result element.
 , Just bResult <- elemBindOfSeriesBind (opResultSeries op)
 = do   -- Apply loop index into the worker body.
        let xBody
                = XApp  (XLam (opWorkerParamIndex op) (opWorkerBody op))
                        (XVar (UIx 0))          -- index

        let bodies      = [ BodyStmt bResult xBody ]
        return ([], bodies, [])

 -- Reduce --------------------------------------
 | OpReduce{}           <- op
 -- Name of the reference that holds the reduction result.
 , UName nResult        <- opTargetRef op
 -- Bound of the input element.
 , Just uInput          <- elemBoundOfSeriesBound (opInputSeries op)
 = do   -- Initialize the accumulator from the current value of the
        -- target reference.
        let nAcc        = NameVarMod nResult "acc"
        let tAcc        = typeOfBind (opWorkerParamAcc op)
        let nAccInit    = NameVarMod nResult "init"

        let starts
                = [ StartStmt (BName nAccInit tAcc)
                              (xRead tAcc (XVar $ opTargetRef op))
                  , StartAcc  nAcc tAcc (XVar (UName nAccInit)) ]

        -- Bind for intermediate accumulator value.
        let nAccVal     = NameVarMod nResult "val"
        let uAccVal     = UName nAccVal
        let bAccVal     = BName nAccVal tAcc

        -- Substitute input and accumulator vars into worker body.
        let xBody x1 x2
                = XApp  (XApp   ( XLam (opWorkerParamAcc  op)
                                $ XLam (opWorkerParamElem op)
                                       (opWorkerBody op))
                                x1)
                        x2

        -- Update the accumulator in the loop body.
        let bodies
                = [ BodyAccRead  nAcc tAcc bAccVal
                  , BodyAccWrite nAcc tAcc
                        (xBody (XVar uAccVal) (XVar uInput)) ]

        -- Read back the final value after the loop has finished and
        -- write it to the destination.
        let nAccRes     = NameVarMod nResult "res"
        let ends
                = [ EndAcc  nAccRes tAcc nAcc
                  , EndStmt (BNone tUnit)
                            (xWrite tAcc (XVar $ opTargetRef op)
                                         (XVar $ UName nAccRes)) ]

        return (starts, bodies, ends)

 -- Unsupported ----------------------------------
 -- Reached for operators with no case above, and for operators whose
 -- binders or bounds did not have the shape required by their case.
 | otherwise
 = Left $ ErrorUnsupported op


-- | Build an expression that increments a natural.
--
--   Applies the primitive addition operator at type 'tNat' to the given
--   expression and the literal natural 1.
xIncrement :: Exp a Name -> Exp a Name
xIncrement xx
        = xApps (XVar   (UPrim  (NamePrimArith PrimArithAdd)
                                (typePrimArith PrimArithAdd)))
                [ XType tNat, xx, XCon (dcNat 1) ]