{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE ConstraintKinds #-} -- | Find out where allocation sizes are used. For each statement, which sizes -- are in scope? module Futhark.Optimise.MemoryBlockMerging.Reuse.AllocationSizeUses ( findSizeUsesFunDef ) where import qualified Data.Map.Strict as M import qualified Data.Set as S import Data.Maybe (mapMaybe) import Control.Monad import Control.Monad.RWS import Control.Monad.Writer import Futhark.Representation.AST import Futhark.Representation.ExplicitMemory ( ExplicitMemory, ExplicitMemorish) import Futhark.Representation.Kernels.Kernel import Futhark.Optimise.MemoryBlockMerging.Miscellaneous import Futhark.Optimise.MemoryBlockMerging.Reuse.AllocationSizes import Futhark.Optimise.MemoryBlockMerging.PrimExps type SizeVars = Names type DeclarationsSoFar = Names -- The final return value. Describes which size variables are in scope at the -- creation of the key size variable. type UsesBefore = M.Map VName Names newtype FindM lore a = FindM { unFindM :: RWS SizeVars UsesBefore DeclarationsSoFar a } deriving (Monad, Functor, Applicative, MonadReader SizeVars, MonadWriter UsesBefore, MonadState DeclarationsSoFar) type LoreConstraints lore = (ExplicitMemorish lore, FullWalk lore) coerce :: FindM flore a -> FindM tlore a coerce = FindM . unFindM addDeclarations :: Names -> FindM lore () addDeclarations = modify . S.union addUsesBefore :: VName -> Names -> FindM lore () addUsesBefore var declarations_so_far = tell $ M.singleton var declarations_so_far findSizeUsesFunDef :: FunDef ExplicitMemory -> UsesBefore findSizeUsesFunDef fundef = let size_vars = mapMaybe (subExpVar . fst) $ M.elems $ memBlockSizesFunDef fundef var_to_pe = findPrimExpsFunDef fundef -- We want to find 'uses before' for all size vars *and* which variables -- they depend on. This is a compromise between recording the -- relationship for only size variables and all variables. We need this -- compromise for 'sizesCanBeMaxedKernelArray' in Reuse.Core. find_pe_vars v0 = maybe S.empty (S.insert v0 . execWriter . traverse (\v -> do tell $ S.singleton v tell $ find_pe_vars v return v)) $ M.lookup v0 var_to_pe size_vars' = S.unions $ map find_pe_vars size_vars m = unFindM $ do forM_ (funDefParams fundef) lookInFParam lookInBody $ funDefBody fundef res = snd $ evalRWS m size_vars' S.empty in res lookInFParam :: FParam lore -> FindM lore () lookInFParam (Param x _) = lookAtNewDecls $ S.singleton x lookInLParam :: LParam lore -> FindM lore () lookInLParam (Param x _) = lookAtNewDecls $ S.singleton x lookInBody :: LoreConstraints lore => Body lore -> FindM lore () lookInBody (Body _ bnds _res) = mapM_ lookInStm bnds lookInKernelBody :: LoreConstraints lore => KernelBody lore -> FindM lore () lookInKernelBody (KernelBody _ bnds _res) = mapM_ lookInStm bnds lookInStm :: LoreConstraints lore => Stm lore -> FindM lore () lookInStm stm@(Let _ _ e) = do let new_decls = S.fromList $ newDeclarationsStm stm lookAtNewDecls new_decls -- Recursive body walk. fullWalkExpM walker walker_kernel e where walker = identityWalker { walkOnBody = lookInBody , walkOnFParam = lookInFParam , walkOnLParam = lookInLParam } walker_kernel = identityKernelWalker { walkOnKernelBody = coerce . lookInBody , walkOnKernelKernelBody = coerce . lookInKernelBody , walkOnKernelLambda = coerce . lookInLambda , walkOnKernelLParam = lookInLParam } lookInLambda :: LoreConstraints lore => Lambda lore -> FindM lore () lookInLambda (Lambda params body _) = do forM_ params lookInLParam lookInBody body lookAtNewDecls :: Names -> FindM lore () lookAtNewDecls new_decls = do all_size_vars <- ask declarations_so_far <- get let new_size_vars = S.intersection all_size_vars new_decls forM_ new_size_vars $ \var -> addUsesBefore var declarations_so_far addDeclarations new_size_vars