{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ConstraintKinds #-}
-- | Find safety condition 5 for all statements.
--
-- For every memory block that a statement uses for the first time, this
-- records the set of variables that had already been declared when that
-- statement was reached.
module Futhark.Optimise.MemoryBlockMerging.Coalescing.SafetyCondition5
  ( findSafetyCondition5FunDef
  ) where

import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Control.Monad
import Control.Monad.RWS

import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory
       (InKernel, ExplicitMemory, ExplicitMemorish)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel

import Futhark.Optimise.MemoryBlockMerging.Types
import Futhark.Optimise.MemoryBlockMerging.Miscellaneous


-- | Every variable declared so far during the traversal.
type DeclarationsSoFar = Names

-- | For each memory block: the variables declared before its first use.
type VarsInUseBeforeMem = M.Map MName Names

-- | Traversal monad.  It reads the first-use information, writes the
-- per-memory-block result, and threads the declarations seen so far as
-- state.  The @lore@ parameter does not occur in the underlying 'RWS' type.
-- It only lets signatures mention the lore being walked; 'coerce' switches
-- between lores.
newtype FindM lore a = FindM
  { unFindM :: RWS FirstUses VarsInUseBeforeMem DeclarationsSoFar a }
  deriving ( Monad
           , Functor
           , Applicative
           , MonadReader FirstUses
           , MonadWriter VarsInUseBeforeMem
           , MonadState DeclarationsSoFar
           )

-- | Everything a lore must support to be walked by this module.
type LoreConstraints lore = ( ExplicitMemorish lore
                            , ExtractKernelDefVars lore
                            , FullWalk lore
                            )

-- | Reinterpret a computation at a different lore.  This is safe because
-- @lore@ is not used by the wrapped 'RWS' computation.
coerce :: FindM flore a -> FindM tlore a
coerce (FindM m) = FindM m

-- | Walk a function definition.  For every memory block that some statement
-- first uses (according to the given 'FirstUses'), report the set of
-- variables declared before that statement.  The function's parameters
-- count as declared from the very start.
findSafetyCondition5FunDef :: FunDef ExplicitMemory -> FirstUses
                           -> VarsInUseBeforeMem
findSafetyCondition5FunDef fundef first_uses = in_use_before_mem
  where (_, in_use_before_mem) = evalRWS (unFindM search) first_uses S.empty
        search = do
          mapM_ lookInFParam $ funDefParams fundef
          lookInBody $ funDefBody fundef

-- | Record the name of an 'FParam' as declared.
lookInFParam :: FParam lore -> FindM lore ()
lookInFParam (Param name _) = modify (S.insert name)

-- | Record the name of an 'LParam' as declared.
lookInLParam :: LParam lore -> FindM lore ()
lookInLParam (Param name _) = modify (S.insert name)

-- | Visit every statement of a body, in order.
lookInBody :: LoreConstraints lore => Body lore -> FindM lore ()
lookInBody (Body _ stms _) = forM_ stms lookInStm

-- | Visit every statement of a kernel body, in order.
lookInKernelBody :: LoreConstraints lore => KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ stms _) = forM_ stms lookInStm

-- | Process one statement in three steps:
--
--   1. For every memory block that a variable declared here uses for the
--      first time, record the declarations seen so far.  This does not yet
--      include the statement's own declarations.
--   2. Add to the declared set: the statement's own declarations, a loop's
--      index (for-loops) or condition variable (while-loops), and any
--      variables from 'extractKernelDefVars'.
--   3. Recurse into nested bodies, parameters and lambdas.
lookInStm :: LoreConstraints lore => Stm lore -> FindM lore ()
lookInStm stm@(Let _ _ e) = do
  let fresh = newDeclarationsStm stm
  declared_before <- get
  mems <- asks $ \first_uses ->
    S.unions [ lookupEmptyable x first_uses | x <- fresh ]
  -- Data.Map's Monoid instance is a left-biased union, so a memory block
  -- recorded more than once keeps its earliest entry.
  tell $ M.fromSet (const declared_before) mems

  -- Special loop handling: extract useful variables that are in use.
  let loop_vars = case e of
        DoLoop _ _ (ForLoop i _ _ _) _ -> S.singleton i
        DoLoop _ _ (WhileLoop c) _ -> S.singleton c
        _ -> S.empty
  modify $ S.union $ S.unions [ S.fromList fresh
                              , loop_vars
                              , extractKernelDefVars e
                              ]

  -- Recursive body walk.
  fullWalkExpM exp_walker kernel_walker e
  where exp_walker = identityWalker
          { walkOnBody = lookInBody
          , walkOnFParam = lookInFParam
          , walkOnLParam = lookInLParam
          }
        -- Kernel contents are walked at a different lore, hence 'coerce'.
        kernel_walker = identityKernelWalker
          { walkOnKernelBody = coerce . lookInBody
          , walkOnKernelKernelBody = coerce . lookInKernelBody
          , walkOnKernelLambda = coerce . lookInLambda
          , walkOnKernelLParam = lookInLParam
          }

-- | Declare a lambda's parameters, then walk its body.
lookInLambda :: LoreConstraints lore => Lambda lore -> FindM lore ()
lookInLambda (Lambda params body _) =
  mapM_ lookInLParam params >> lookInBody body

-- | Lores whose expressions may introduce extra in-scope variables via a
-- kernel definition.
class ExtractKernelDefVars lore where
  -- | Extract variables from a kernel definition.
  extractKernelDefVars :: Exp lore -> Names

-- | A 'Kernel' contributes the global, local and group id variables of its
-- kernel space.  Every other expression contributes nothing.
instance ExtractKernelDefVars ExplicitMemory where
  extractKernelDefVars e = case e of
    Op (ExpMem.Inner (Kernel _ space _ _)) ->
      S.fromList [spaceGlobalId space, spaceLocalId space, spaceGroupId space]
    _ -> S.empty

-- | Expressions in the 'InKernel' lore contribute no variables here.
instance ExtractKernelDefVars InKernel where
  extractKernelDefVars = const S.empty