{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | The simplification engine is only willing to hoist allocations
-- out of loops if the memory block resulting from the allocation is
-- dead at the end of the loop.  If it is not, we may cause data
-- hazards.
--
-- This module rewrites loops with memory block merge parameters such
-- that each memory block is copied at the end of the iteration, thus
-- ensuring that any allocation inside the loop is dead at the end of
-- the loop.  This is only possible for allocations whose size is
-- loop-invariant, although the initial size may differ from the size
-- produced by the loop result.
--
-- Additionally, inside parallel kernels we also copy the initial
-- value.  This has the effect of making the memory block returned by
-- the array non-existential, which is important for later memory
-- expansion to work.
module Futhark.Optimise.DoubleBuffer
       ( doubleBuffer )
       where

import           Control.Monad.State
import           Control.Monad.Writer
import           Control.Monad.Reader
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import           Data.Maybe
import           Data.List

import           Futhark.MonadFreshNames
import           Futhark.Representation.AST
import           Futhark.Representation.ExplicitMemory
                 hiding (Prog, Body, Stm, Pattern, PatElem,
                         BasicOp, Exp, Lambda, FunDef, FParam, LParam, RetType)
import           Futhark.Pass

-- | The double-buffering pass.  Its pass function hands
-- 'optimiseFunDef' to 'intraproceduralTransformation'.
doubleBuffer :: Pass ExplicitMemory ExplicitMemory
doubleBuffer =
  Pass { passName = "Double buffer"
       , passDescription = "Perform double buffering for merge parameters of sequential loops."
       , passFunction = intraproceduralTransformation optimiseFunDef
       }

-- This pass is written in a slightly weird way because we want to
-- apply essentially the same transformation both outside and inside
-- kernel bodies, which are different (but similar) representations.
-- Thus, the environment is parametrised by the lore and contains the
-- function used to transform 'Op's for the lore.

-- | Transform the body of a single function definition, threading the
-- name source through 'modifyNameSource'.
--
-- The initial environment has an empty scope and uses
-- @doNotTouchLoop@, which hands every loop back unchanged.  Only when
-- an 'Inner' kernel operation is entered is a second environment
-- built, and that one uses 'optimiseLoop' as its loop transformer.
optimiseFunDef :: FunDef ExplicitMemory -> PassM (FunDef ExplicitMemory)
optimiseFunDef fundec = modifyNameSource $ \src ->
  let m = runDoubleBufferM $ inScopeOf fundec $ optimiseBody $ funDefBody fundec
      (body', src') = runState (runReaderT m env) src
  in (fundec { funDefBody = body' }, src')
  where -- Environment used outside of kernels.
        env = Env mempty optimiseKernelOp doNotTouchLoop

        -- An 'Inner' op is traversed with 'mapKernelM' in a fresh run
        -- of the monad whose environment carries the current scope
        -- (converted with 'castScope'), @optimiseInKernelOp@ and
        -- 'optimiseLoop'.  The name source is threaded in and out via
        -- 'modifyNameSource'.
        optimiseKernelOp (Inner k) = do
          scope <- castScope <$> askScope
          modifyNameSource $
            runState (runReaderT (runDoubleBufferM $ Inner <$> optimiseKernel k) $
                      Env scope optimiseInKernelOp optimiseLoop)
          where optimiseKernel =
                  mapKernelM identityKernelMapper
                  { mapOnKernelBody = optimiseBody
                  , mapOnKernelKernelBody = optimiseKernelBody
                  , mapOnKernelLambda = optimiseLambda
                  }
        -- Any other op is returned as it is.
        optimiseKernelOp op = return op

        -- Inside a kernel, only the lambda of a 'GroupStream' is
        -- transformed; every other op is returned as it is.
        optimiseInKernelOp (Inner (GroupStream w maxchunk lam accs arrs)) = do
          lam' <- optimiseGroupStreamLambda lam
          return $ Inner $ GroupStream w maxchunk lam' accs arrs
        optimiseInKernelOp op = return op

        -- Loop transformer that adds no statements and returns the
        -- merge parameters and body it was given.
        doNotTouchLoop ctx val body = return (mempty, ctx, val, body)

-- | The read-only environment of 'DoubleBufferM'.
data Env lore =
  Env { envScope :: Scope lore
        -- ^ The scope reported by 'askScope' and extended by
        -- 'localScope'.
      , envOptimiseOp :: Op lore -> DoubleBufferM lore (Op lore)
        -- ^ How to transform an 'Op' of this lore.
      , envOptimiseLoop :: OptimiseLoop lore
        -- ^ How to transform the merge parameters and body of a loop.
      }

-- | The monad of the pass: a reader over 'Env' on top of a state
-- monad carrying the 'VNameSource'.
newtype DoubleBufferM lore a =
  DoubleBufferM { runDoubleBufferM :: ReaderT (Env lore) (State VNameSource) a }
  deriving (Functor, Applicative, Monad, MonadReader (Env lore), MonadFreshNames)

-- The scope is read from the 'envScope' field of the environment.
instance Annotations lore => HasScope lore (DoubleBufferM lore) where
  askScope = asks envScope

-- A local scope is appended (with '<>') to the one in the environment
-- for the duration of the given action.
instance Annotations lore => LocalScope lore (DoubleBufferM lore) where
  localScope scope = local $ \env -> env { envScope = envScope env <> scope }

-- | Bunch up all the
-- constraints for less typing.
type LoreConstraints lore inner =
  ( ExpAttr lore ~ ()
  , BodyAttr lore ~ ()
  , ExplicitMemorish lore
  , Op lore ~ MemOp inner
  )

-- | Transform the statements of a body, leaving everything else in
-- the body as it was.
optimiseBody :: LoreConstraints lore inner =>
                Body lore -> DoubleBufferM lore (Body lore)
optimiseBody body =
  replaceStms <$> optimiseStms (stmsToList (bodyStms body))
  where replaceStms stms = body { bodyStms = stmsFromList stms }

-- | Transform a list of statements from front to back.  The names
-- bound by the statements produced for the head are brought into
-- scope while the tail is processed.
optimiseStms :: LoreConstraints lore inner =>
                [Stm lore] -> DoubleBufferM lore [Stm lore]
optimiseStms stms =
  case stms of
    [] -> return []
    stm : rest -> do
      expanded <- optimiseStm stm
      rest' <- localScope (castScope (scopeOf expanded)) (optimiseStms rest)
      return (expanded ++ rest')

-- | Transform a single statement, which may turn into several.
--
-- For a loop, the body is transformed first (with the loop form and
-- all merge parameters in scope), and then the loop transformer from
-- the environment is applied; whatever statements it produces are
-- placed in front of the rebuilt loop.  For any other statement, the
-- bodies and ops nested inside its expression are transformed.
optimiseStm :: forall lore inner.
               LoreConstraints lore inner =>
               Stm lore -> DoubleBufferM lore [Stm lore]
optimiseStm (Let pat aux (DoLoop ctx val form body)) = do
  inner_body <-
    localScope (scopeOf form <> scopeOfFParams (map fst (ctx ++ val)))
               (optimiseBody body)
  rewrite_loop <- asks envOptimiseLoop
  (pre_stms, ctx', val', final_body) <- rewrite_loop ctx val inner_body
  return (pre_stms ++ [Let pat aux (DoLoop ctx' val' form final_body)])
optimiseStm (Let pat aux e) = do
  e' <- mapExpM mapper e
  return [Let pat aux e']
  where mapper = identityMapper { mapOnBody = \_ x ->
                                    -- The explicit annotation is kept
                                    -- on purpose: without it the GHC
                                    -- 8.4 type checker goes nuts.
                                    (optimiseBody x :: DoubleBufferM lore (Body lore))
                                , mapOnOp = optimiseOp
                                }

        -- Look up the op transformer in the environment and apply it.
        optimiseOp :: Op lore -> DoubleBufferM lore (Op lore)
        optimiseOp op = asks envOptimiseOp >>= \f -> f op

-- | Transform the statements of a kernel body.
optimiseKernelBody :: KernelBody InKernel
                   -> DoubleBufferM InKernel (KernelBody InKernel)
optimiseKernelBody kbody =
  replaceStms <$> optimiseStms (stmsToList (kernelBodyStms kbody))
  where replaceStms stms = kbody { kernelBodyStms = stmsFromList stms }

-- | Transform the body of a lambda with the names bound by the lambda
-- in scope.
optimiseLambda :: Lambda InKernel -> DoubleBufferM InKernel (Lambda InKernel)
optimiseLambda lam =
  replaceBody <$>
  localScope (castScope (scopeOf lam)) (optimiseBody (lambdaBody lam))
  where replaceBody body = lam { lambdaBody = body }

-- | Transform the body of a group stream lambda with the names bound
-- by the lambda in scope.
optimiseGroupStreamLambda :: GroupStreamLambda InKernel
                          -> DoubleBufferM InKernel (GroupStreamLambda InKernel)
optimiseGroupStreamLambda lam =
  replaceBody <$>
  localScope (scopeOf lam) (optimiseBody (groupStreamLambdaBody lam))
  where replaceBody body = lam { groupStreamLambdaBody = body }

-- | A loop transformer.  It receives the context merge parameters,
-- the value merge parameters and the loop body, and returns
-- statements to put before the loop together with the new context
-- parameters, value parameters and body.
type OptimiseLoop lore =
  [(FParam lore, SubExp)] -> [(FParam lore, SubExp)] -> Body lore
  -> DoubleBufferM lore ([Stm lore],
                         [(FParam lore, SubExp)],
                         [(FParam lore, SubExp)],
                         Body lore)

-- | The loop transformer that actually performs double buffering.
optimiseLoop :: LoreConstraints lore inner => OptimiseLoop lore
optimiseLoop ctx val body = do
  -- First decide, for each merge parameter, whether and how it is to
  -- be double-buffered.
  buffered <- doubleBufferMergeParams
              (zip (map fst ctx) (bodyResult body))
              merge_params
              (boundInBody body)
  -- Next produce the buffer allocations and the copies of the initial
  -- values, along with the adjusted merge list.
  (merge', allocs) <- allocStms merge buffered
  -- The adjusted merge list is split back into its context and value
  -- parts; the body is rewritten to copy buffered result arrays.
  let (ctx', val') = splitAt (length ctx) merge'
  return (allocs, ctx', val', doubleBufferResult merge_params buffered body)
  where merge = ctx ++ val
        merge_params = map fst merge

-- | The booleans indicate whether we should also play with the
-- initial merge values.
data DoubleBuffer lore = BufferAlloc VName SubExp Space Bool
                         -- ^ Name, size and space of a memory block
                         -- to allocate in front of the loop.
                       | BufferCopy VName IxFun VName Bool
                         -- ^ First name is the memory block to copy to,
                         -- second is the name of the array copy.
                       | NoBuffer
                         -- ^ Leave the merge parameter alone.
                    deriving (Show)

-- | Decide how each of the given parameters is to be buffered; the
-- result has one entry per parameter, in the same order.
--
-- A memory parameter is buffered when its size is a constant, is not
-- the name of one of the parameters in the first argument, or is such
-- a name whose paired subexpression is a constant or a variable that
-- is neither bound in the loop (third argument) nor itself one of
-- those parameters.  In the latter cases the paired subexpression is
-- used as the size and the boolean is 'False'; otherwise the size is
-- used as it is and the boolean is 'True'.
--
-- An array parameter is buffered when the memory block it resides in
-- was buffered by an earlier entry, and inherits that boolean.
doubleBufferMergeParams :: (ExplicitMemorish lore, MonadFreshNames m) =>
                           [(FParam lore,SubExp)] -> [FParam lore] -> Names
                        -> m [DoubleBuffer lore]
doubleBufferMergeParams ctx_and_res val_params bound_in_loop =
  evalStateT (mapM classify val_params) M.empty
  where
    ctx_names = map (paramName . fst) ctx_and_res

    isLoopVariant v = v `S.member` bound_in_loop || v `elem` ctx_names

    invariantSize (Constant v) = Just (Constant v, True)
    invariantSize (Var v) =
      case find ((== v) . paramName . fst) ctx_and_res of
        Nothing -> Just (Var v, True)
        Just (_, Constant val) -> Just (Constant val, False)
        Just (_, Var v')
          | isLoopVariant v' -> Nothing
          | otherwise -> Just (Var v', False)

    -- The state maps the name of every buffered memory parameter to
    -- the name of its fresh buffer and the associated boolean.
    classify fparam =
      case paramType fparam of
        Mem size space
          | Just (size', b) <- invariantSize size -> do
              bufname <- lift (newVName "double_buffer_mem")
              modify (M.insert (paramName fparam) (bufname, b))
              return (BufferAlloc bufname size' space b)
        Array {}
          | MemArray _ _ _ (ArrayIn mem ixfun) <- paramAttr fparam -> do
              known <- gets (M.lookup mem)
              case known of
                Nothing -> return NoBuffer
                Just (bufname, b) -> do
                  copyname <- lift (newVName "double_buffer_array")
                  return (BufferCopy bufname ixfun copyname b)
        _ -> return NoBuffer

-- | Produce the statements that go in front of the loop, and the
-- merge list adjusted to match.
--
-- Every 'BufferAlloc' yields an allocation; if its boolean is set,
-- the merge parameter is retyped to the allocated size and space and
-- initialised with the fresh block.  A 'BufferCopy' whose boolean is
-- set and whose initial value is a variable yields a copy of that
-- variable, placed in the buffer with the index function of the
-- original, and the copy becomes the initial value.  All other
-- entries of the merge list are passed through.
allocStms :: LoreConstraints lore inner =>
             [(FParam lore,SubExp)] -> [DoubleBuffer lore]
          -> DoubleBufferM lore ([(FParam lore, SubExp)], [Stm lore])
allocStms merge buffered = runWriterT (zipWithM allocate merge buffered)
  where
    allocate old@(Param pname _, _) (BufferAlloc name size space b) = do
      tell [Let (Pattern [] [PatElem name (MemMem size space)]) (defAux ())
            (Op (Alloc size space))]
      return $ if b
               then (Param pname (MemMem size space), Var name)
               else old
    allocate (f, Var v) (BufferCopy mem _ _ True) = do
      v_copy <- lift (newVName (baseString v ++ "_double_buffer_copy"))
      (_v_mem, v_ixfun) <- lift (lookupArraySummary v)
      let t = paramType f
          bound = MemArray (elemType t) (arrayShape t) NoUniqueness
                  (ArrayIn mem v_ixfun)
      tell [Let (Pattern [] [PatElem v_copy bound]) (defAux ())
            (BasicOp (Copy v))]
      return (f, Var v_copy)
    allocate unchanged _ = return unchanged

-- | Rewrite the result of a loop body according to the buffering
-- decisions.  The result of a 'BufferAlloc' parameter becomes the
-- buffer itself; a 'BufferCopy' parameter whose result is a variable
-- gets a copy statement appended to the body, and the copy becomes
-- the result.  Everything else is left as it was.
doubleBufferResult :: (ExplicitMemorish lore, ExpAttr lore ~ (), BodyAttr lore ~ ()) =>
                      [FParam lore] -> [DoubleBuffer lore]
                   -> Body lore -> Body lore
doubleBufferResult valparams buffered (Body () bnds res) =
  Body () (bnds <> stmsFromList (catMaybes copy_stms)) (ctx_res ++ val_res')
  where
    (ctx_res, val_res) = splitAt (length res - length valparams) res
    (copy_stms, val_res') = unzip (zipWith3 rewrite valparams buffered val_res)

    rewrite _ (BufferAlloc bufname _ _ _) _ =
      (Nothing, Var bufname)
    rewrite fparam (BufferCopy bufname ixfun copyname _) (Var v) =
      -- The type of the copy is derived from the type of the
      -- parameter, with parameter names in its dimensions replaced
      -- by the corresponding results.
      (Just copy_stm, Var copyname)
      where t = resultType (paramType fparam)
            summary = MemArray (elemType t) (arrayShape t) NoUniqueness
                      (ArrayIn bufname ixfun)
            copy_stm = Let (Pattern [] [PatElem copyname summary]) (defAux ())
                       (BasicOp (Copy v))
    rewrite _ _ se =
      (Nothing, se)

    parammap = M.fromList (zip (map paramName valparams) res)

    resultType t = setArrayDims t (map substitute (arrayDims t))

    substitute se =
      case se of
        Var v -> fromMaybe se (M.lookup v parammap)
        _ -> se