module DDC.Core.Flow.Lower
        ( Config        (..)
        , defaultConfigScalar
        , defaultConfigKernel
        , defaultConfigVector
        , Method        (..)
        , Lifting       (..)
        , lowerModule)
where
import DDC.Core.Flow.Transform.Slurp
import DDC.Core.Flow.Transform.Schedule
import DDC.Core.Flow.Transform.Schedule.Base
import DDC.Core.Flow.Transform.Extract
import DDC.Core.Flow.Process
import DDC.Core.Flow.Procedure
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Profile
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Exp
import DDC.Core.Module
import DDC.Core.Transform.Annotate
import qualified DDC.Core.Simplifier            as C
import qualified DDC.Core.Simplifier.Recipe     as C
import qualified DDC.Core.Transform.Namify      as C
import qualified DDC.Core.Transform.Snip        as Snip
import qualified DDC.Type.Env                   as Env
import qualified Control.Monad.State.Strict     as S
import qualified Data.Monoid                    as M
import Control.Monad


-- | Settings that control how the lower transform generates code.
data Config
        = Config
        { configMethod  :: Method }
        deriving (Eq, Show)


-- | The code generation strategy used when lowering a process.
data Method
        -- | Sequential scalar code, using nested loops.
        = MethodScalar

        -- | Vector kernel code, which only processes data whose length is
        --   an even multiple of the vector width.
        | MethodKernel
        { methodLifting :: Lifting }

        -- | Sequential vector code where possible,
        --   falling back to scalar code if this is not possible.
        | MethodVector
        { methodLifting :: Lifting }
        deriving (Eq, Show)


-- | Lowering config that only uses scalar operations.
defaultConfigScalar :: Config
defaultConfigScalar
        = Config MethodScalar


-- | Lowering config that uses vector operations (with a lifting of 8),
--   where the loops only handle a data size that is an even multiple
--   of the vector width.
defaultConfigKernel :: Config
defaultConfigKernel
        = Config (MethodKernel { methodLifting = Lifting 8 })


-- | Lowering config that uses vector operations (with a lifting of 8),
--   where the loops handle arbitrary data sizes, of any number of elements.
defaultConfigVector :: Config
defaultConfigVector
        = Config { configMethod = MethodVector (Lifting 8) }


-- Lower ----------------------------------------------------------------------
-- | Take a module that contains only well formed series processes defined
--   at top-level, and lower them all into procedures.
--
--   Returns 'ErrorSlurpError' if a process definition cannot be slurped from
--   the module, or the first error produced while lowering a process.
lowerModule :: Config -> ModuleF -> Either Error ModuleF
lowerModule config mm
 = case slurpProcesses mm of
        -- Can't slurp a process definition from one of the top level
        -- series processes.
        Left err
         -> Left (ErrorSlurpError err)

        -- We've got a process definition for all of them.
        Right procs
         -> do  -- Schedule the processes into procedures.
                lets    <- mapM (lowerProcess config) procs

                -- Wrap all the procedures into a new module.
                let mm_lowered
                        = mm { moduleBody = annotate () $ XLet (LRec lets) xUnit }

                -- Clean up the extracted code.
                return  $ cleanModule mm_lowered


-- | Lower a single series process into fused code.
--
--   The method in the config selects the code generator:
--
--   * 'MethodScalar' produces sequential scalar code.
--
--   * 'MethodKernel' produces vector kernel code.
--
--   * 'MethodVector' produces vector code for the main part of the series,
--     stitched together with scalar code for the tail. If the process does
--     not have the rate type parameter and the single RateNat witness this
--     requires, we fall back to scalar code instead. If scheduling the vector
--     kernel itself fails, that error is returned.
--
--   Scheduling errors are returned as 'Left'.
lowerProcess :: Config -> Process -> Either Error (BindF, ExpF)
lowerProcess config process

        -- Scalar lowering ------------------------------
        | MethodScalar          <- configMethod config
        = do    -- Schedule process into scalar code.
                -- Bind monadically so a scheduling failure is returned as
                -- a Left rather than crashing on a refutable pattern.
                proc    <- scheduleScalar process

                -- Extract code for the procedure.
                return  $ extractProcedure proc

        -- Vector lowering -----------------------------
        -- To use the vector method,
        --  the type of the source function needs to have a quantifier for
        --  the rate variable (k), as well as a (RateNat k) witness.
        --
        | MethodVector lifting  <- configMethod config
        , [nRN]                 <- [ nRN | BName nRN tRN <- processParamValues process
                                         , isRateNatType tRN ]
        , bK : _                <- processParamTypes process
        = do
                let c           = liftingFactor lifting

                -- Get the primary rate variable.
                let Just uK     = takeSubstBoundOfBind bK
                let tK          = TVar uK

                -- The RateNat witness.
                let xRN         = XVar (UName nRN)

                -----------------------------------------
                -- Create the vector version of the kernel.
                --  Vector code processes several elements per loop iteration.
                procVec         <- scheduleKernel lifting process
                let (_, xProcVec) = extractProcedure procVec

                -- Window each input series down to the vector rate.
                let bxsDownSeries
                        = [ ( bS
                            , ( BName (NameVarMod n "down") (tSeries (tDown c tK) tE)
                              , xDown c tK tE (XVar (UIx 0)) xS))
                          | bS@(BName n tS) <- processParamValues process
                          , let Just tE     = elemTypeOfSeriesType tS
                          , let Just uS     = takeSubstBoundOfBind bS
                          , let xS          = XVar uS
                          , isSeriesType tS ]

                -- Get a value arg to give to the vector procedure.
                let getDownValArg b
                        | Just (b', _)  <- lookup b bxsDownSeries
                        = liftM XVar $ takeSubstBoundOfBind b'

                        | otherwise
                        = liftM XVar $ takeSubstBoundOfBind b

                let Just xsVecValArgs
                        = mapM getDownValArg (processParamValues process)

                let bRateDown   = BAnon (tRateNat (tDown c (TVar uK)))

                let xProcVec'
                        = XLam bRateDown
                        $ xLets [LLet b x | (_, (b, x)) <- bxsDownSeries]
                        $ xApps (XApp xProcVec (XType (TVar uK)))
                        $ xsVecValArgs

                -----------------------------------------
                -- Create tail version.
                --  Scalar code processes the final elements of the loop.
                procTail        <- scheduleScalar process
                let (bProcTail, xProcTail) = extractProcedure procTail

                -- Window the input series to select the tails.
                let bxsTailSeries
                        = [ ( bS
                            , ( BName (NameVarMod n "tail") (tSeries (tTail c tK) tE)
                              , xTail c tK tE (XVar (UIx 0)) xS))
                          | bS@(BName n tS) <- processParamValues process
                          , let Just tE     = elemTypeOfSeriesType tS
                          , let Just uS     = takeSubstBoundOfBind bS
                          , let xS          = XVar uS
                          , isSeriesType tS ]

                -- Window the output vectors to select the tails.
                let bxsTailVector
                        = [ ( bV
                            , ( BName (NameVarMod n "tail") (tVector tE)
                              , xTailVector c tK tE (XVar (UIx 0)) xV))
                          | bV@(BName n tV) <- processParamValues process
                          , let Just tE     = elemTypeOfVectorType tV
                          , let Just uV     = takeSubstBoundOfBind bV
                          , let xV          = XVar uV
                          , isVectorType tV ]

                -- Get a value arg to give to the scalar procedure.
                let getTailValArg b
                        | Just (b', _)  <- lookup b bxsTailSeries
                        = liftM XVar $ takeSubstBoundOfBind b'

                        | Just (b', _)  <- lookup b bxsTailVector
                        = liftM XVar $ takeSubstBoundOfBind b'

                        | otherwise
                        = liftM XVar $ takeSubstBoundOfBind b

                let Just xsTailValArgs
                        = mapM getTailValArg (procedureParamValues procTail)

                let bRateTail   = BAnon (tRateNat (tTail c (TVar uK)))

                let xProcTail'
                        = XLam bRateTail
                        $ xLets [LLet b x | (_, (b, x)) <- bxsTailSeries]
                        $ xLets [LLet b x | (_, (b, x)) <- bxsTailVector]
                        $ xApps (XApp xProcTail (XType (tTail c (TVar uK))))
                        $ xsTailValArgs

                ------------------------------------------
                -- Stitch the vector and scalar versions together.
                let xProc
                        = foldr XLAM
                            (foldr XLam xBody (processParamValues process))
                            (processParamTypes process)

                    xBody
                        = XLet  (LLet (BNone tUnit)
                                      (xSplit c (TVar uK) xRN xProcVec' xProcTail'))
                                xUnit

                -- Reconstruct a binder for the whole procedure / process.
                let bProc       = BName (processName process) (typeOfBind bProcTail)

                return (bProc, xProc)

        -- Vector lowering fallback --------------------
        -- The vector method was requested, but the process does not have
        -- the rate type parameter and single RateNat witness that the
        -- vector path above needs, so produce scalar code instead.
        | MethodVector _        <- configMethod config
        = lowerProcess (config { configMethod = MethodScalar }) process

        -- Kernel lowering -----------------------------
        | MethodKernel lifting  <- configMethod config
        = do    -- Schedule process into vector kernel code.
                proc    <- scheduleKernel lifting process

                -- Extract code for the kernel.
                return  $ extractProcedure proc

        -- Unreachable: every Method constructor is handled above. Kept so
        -- the guards are visibly exhaustive.
        | otherwise
        = error $ "ddc-core-flow.lowerProcess: invalid lowering method"


-- Clean ----------------------------------------------------------------------
-- | Do some beta-reductions to ensure that arguments to worker functions
--   are inlined, then normalize nested applications.
--   When snipping, leave lambda abstractions in place so the worker functions
--   applied to our loop combinators aren't moved.
cleanModule :: ModuleF -> ModuleF
cleanModule mm
 = C.result
 $ S.evalState
        (C.applySimplifier profile Env.empty Env.empty (C.Fix 4 clean) mm)
        0
 where
        -- Rename binders using the freshT / freshX name generators.
        namify  = C.Trans (C.Namify (C.makeNamifier freshT) (C.makeNamifier freshX))

        -- Forward let-bound values into their use sites.
        forward = C.Trans C.Forward

        -- Snip, but keep lambda abstractions where they are.
        snip    = C.Trans (C.Snip (Snip.configZero { Snip.configPreserveLambdas = True }))

        flatten = C.Trans C.Flatten

        -- One cleaning pass, applied by 'C.Fix' with a limit of 4;
        -- the state passed to 'S.evalState' starts at 0.
        clean   = namify M.<> forward M.<> C.beta M.<> snip M.<> flatten