{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
-- | Find safety condition 2 for all statements.
--
-- The result maps each name bound by a visited statement pattern (and the
-- merge value parameters of each loop) to the set of memory blocks that were
-- already known as allocations at that point.  Explicit 'ExpMem.Alloc'
-- statements and unique array function parameters both count as allocations.
module Futhark.Optimise.MemoryBlockMerging.Coalescing.SafetyCondition2
  ( findSafetyCondition2FunDef
  ) where

import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Control.Monad
import Control.Monad.RWS

import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory
       ( ExplicitMemory, InKernel, ExplicitMemorish)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel
import Futhark.Optimise.MemoryBlockMerging.Types
import Futhark.Optimise.MemoryBlockMerging.Miscellaneous

-- | Memory blocks seen as allocated so far during the traversal.
type CurrentAllocatedBlocks = MNames

-- | For each declared name, the memory blocks that were allocated before the
-- declaration was reached.
type AllocatedBlocksBeforeCreation = M.Map VName MNames

-- | Traversal monad: an empty environment, an 'AllocatedBlocksBeforeCreation'
-- log as output, and the 'CurrentAllocatedBlocks' as state.  The @lore@
-- parameter is phantom; it is not used by the representation.
newtype FindM lore a = FindM
  { unFindM :: RWS () AllocatedBlocksBeforeCreation CurrentAllocatedBlocks a }
  deriving (Monad, Functor, Applicative,
            MonadWriter AllocatedBlocksBeforeCreation,
            MonadState CurrentAllocatedBlocks)

-- | Everything a lore must support for this traversal.
type LoreConstraints lore = (ExplicitMemorish lore,
                             IsAlloc lore,
                             FullWalk lore)

-- | Re-tag a computation with a different (phantom) lore.  The wrapped RWS
-- computation is passed through untouched.
coerce :: FindM flore a -> FindM tlore a
coerce (FindM inner) = FindM inner

-- | Run the analysis over a whole function and return the collected
-- name -> previously-allocated-blocks map.  The traversal starts with no
-- allocated blocks.
findSafetyCondition2FunDef :: FunDef ExplicitMemory
                           -> AllocatedBlocksBeforeCreation
findSafetyCondition2FunDef fundef =
  snd $ evalRWS (unFindM traversal) () S.empty
  where traversal :: FindM ExplicitMemory ()
        traversal = do
          -- Parameters come first so that unique array parameters are
          -- already in the state when the body is visited.
          mapM_ lookInFParam $ funDefParams fundef
          lookInBody $ funDefBody fundef

-- | Unique array function parameters also count as "allocations" in which
-- memory can be coalesced, so their memory block is added to the state.
-- Any other parameter is ignored.
lookInFParam :: FParam ExplicitMemory -> FindM lore ()
lookInFParam (Param _ info)
  | ExpMem.MemArray _ _ Unique (ExpMem.ArrayIn block _) <- info =
      modify (S.insert block)
  | otherwise = return ()

-- | Visit every statement of a body, in order.
lookInBody :: LoreConstraints lore => Body lore -> FindM lore ()
lookInBody (Body _ stms _) = forM_ stms lookInStm

-- | Visit every statement of a kernel body, in order.
lookInKernelBody :: LoreConstraints lore => KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ stms _) = forM_ stms lookInStm

-- | Process one statement:
--
-- 1. Every name it declares is logged together with the blocks allocated so
--    far.  The current statement's own allocation is not yet included.
-- 2. If it binds exactly one value and its expression is an allocation
--    (per 'isAlloc'), that bound name is added to the allocated blocks.
-- 3. Nested bodies, loop parameters and kernel bodies are visited
--    recursively.
lookInStm :: LoreConstraints lore => Stm lore -> FindM lore ()
lookInStm (Let (Pattern ctxelems valelems) _ e) = do
  allocatedSoFar <- get
  tell $ M.fromList [ (name, allocatedSoFar) | name <- declared ]
  case valelems of
    [PatElem block _] | isAlloc e -> modify (S.insert block)
    _ -> return ()
  -- RECURSIVE BODY WALK.
  fullWalkExpM bodyWalker kernelWalker e
  where declared = map patElemName (ctxelems ++ valelems) ++ loopDecls

        -- Technically not a declaration for the current expression, but very
        -- close, and hopefully okay to consider it as one.
        loopDecls = case e of
          DoLoop _ mergevals _ _ -> map (paramName . fst) mergevals
          _ -> []

        bodyWalker = identityWalker
          { walkOnBody = lookInBody
          , walkOnFParam = lookInFParam
          }

        -- Kernel bodies may be walked at a different lore than the enclosing
        -- statement; 'coerce' re-tags the phantom lore so the results fit.
        kernelWalker = identityKernelWalker
          { walkOnKernelBody = coerce . lookInBody
          , walkOnKernelKernelBody = coerce . lookInKernelBody
          , walkOnKernelLambda = coerce . lookInBody . lambdaBody
          }

-- | Decide whether an expression counts as a memory allocation.
class IsAlloc lore where
  isAlloc :: Exp lore -> Bool

-- | Only 'ExpMem.Alloc' operations are allocations.
instance IsAlloc ExplicitMemory where
  isAlloc expr = case expr of
    Op ExpMem.Alloc{} -> True
    _ -> False

-- | No expression inside a kernel is treated as an allocation.
instance IsAlloc InKernel where
  isAlloc = const False