{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
-- | Find all Alloc statements and associate their memory blocks with the
-- allocation size.
module Futhark.Optimise.MemoryBlockMerging.Reuse.AllocationSizes
  ( memBlockSizesFunDef
  , memBlockSizesParamsBodyNonRec
  , Sizes
  ) where

import qualified Data.Map.Strict as M
import Control.Monad.Writer

import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory
       (ExplicitMemorish, ExplicitMemory, InKernel)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel
import Futhark.Optimise.MemoryBlockMerging.Types
import Futhark.Optimise.MemoryBlockMerging.Miscellaneous


-- | Maps each memory block to its allocation size and the 'Space' it
-- lives in.
type Sizes = M.Map MName (SubExp, Space)

-- | A 'Writer' accumulating 'Sizes'.  The @lore@ index is phantom: it only
-- appears in the type, so a computation can be re-indexed with 'coerce'.
newtype FindM lore a = FindM { unFindM :: Writer Sizes a }
  deriving (Monad, Functor, Applicative, MonadWriter Sizes)

-- | Everything the traversal below needs from a lore.
type LoreConstraints lore = ( ExplicitMemorish lore
                            , AllocSizeUtils lore
                            , FullWalk lore
                            )

-- | Change the phantom lore index of a 'FindM' computation without
-- touching the wrapped 'Writer'.
coerce :: FindM flore a -> FindM tlore a
coerce (FindM w) = FindM w

-- | Emit a one-entry 'Sizes' map.  Outputs are combined with the 'M.Map'
-- monoid (a left-biased union), so when the same block is recorded more
-- than once, the earliest recording is the one that survives.
recordMapping :: VName -> (SubExp, Space) -> FindM lore ()
recordMapping var = tell . M.singleton var

-- | Sizes of all memory blocks in a function: its memory-typed
-- parameters, plus everything found by walking the body recursively
-- (see 'lookInStmRec').
memBlockSizesFunDef :: LoreConstraints lore => FunDef lore -> Sizes
memBlockSizesFunDef fundef = execWriter . unFindM $ do
  forM_ (funDefParams fundef) lookInFParam
  lookInBody (funDefBody fundef)

-- | Like 'memBlockSizesFunDef', but for an explicit parameter list and
-- body, and only inspecting the body's top-level statements with
-- 'lookInStm': nested bodies, lambdas and kernel bodies are not visited.
memBlockSizesParamsBodyNonRec :: LoreConstraints lore
                              => [FParam lore] -> Body lore -> Sizes
memBlockSizesParamsBodyNonRec params body = execWriter . unFindM $ do
  forM_ params lookInFParam
  forM_ (bodyStms body) lookInStm

-- | Record a function parameter if its type is a memory block.
lookInFParam :: LoreConstraints lore => FParam lore -> FindM lore ()
lookInFParam param = case param of
  Param mem (ExpMem.MemMem size space) -> recordMapping mem (size, space)
  _ -> return ()

-- | Record a lambda parameter if its type is a memory block.
lookInLParam :: LoreConstraints lore => LParam lore -> FindM lore ()
lookInLParam param = case param of
  Param mem (ExpMem.MemMem size space) -> recordMapping mem (size, space)
  _ -> return ()

-- | Inspect every statement of a body, recursing into nested constructs.
lookInBody :: LoreConstraints lore => Body lore -> FindM lore ()
lookInBody = mapM_ lookInStmRec . bodyStms

-- | Inspect every statement of a kernel body, recursing into nested
-- constructs.
lookInKernelBody :: LoreConstraints lore => KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ stms _) = forM_ stms lookInStmRec

-- | Inspect a single statement without descending into its expression:
--
-- * if the pattern has exactly one value element and 'lookForAllocSize'
--   recognises the expression, that element is recorded with the returned
--   size and space;
--
-- * every memory-typed context element of the pattern is recorded.
lookInStm :: LoreConstraints lore => Stm lore -> FindM lore ()
lookInStm (Let (Pattern ctx_elems val_elems) _ e) =
  recordAlloc val_elems >> mapM_ lookInPatCtxElem ctx_elems
  where recordAlloc [PatElem mem _] =
          maybe (return ()) (recordMapping mem) (lookForAllocSize e)
        recordAlloc _ = return ()

-- | Inspect a statement with 'lookInStm', then walk its expression so that
-- nested bodies, parameters, lambdas and kernel bodies are inspected too.
lookInStmRec :: LoreConstraints lore => Stm lore -> FindM lore ()
lookInStmRec stm@(Let _ _ e) =
  lookInStm stm >> fullWalkExpM expWalker kernelWalker e
  where expWalker = identityWalker
          { walkOnBody = lookInBody
          , walkOnFParam = lookInFParam
          , walkOnLParam = lookInLParam
          }
        -- The kernel callbacks run the lookIn* functions at whatever lore
        -- the kernel walker supplies and re-index the result via 'coerce'.
        kernelWalker = identityKernelWalker
          { walkOnKernelBody = coerce . lookInBody
          , walkOnKernelKernelBody = coerce . lookInKernelBody
          , walkOnKernelLambda = coerce . lookInLambda
          , walkOnKernelLParam = lookInLParam
          }

-- | Record a pattern context element if its type is a memory block.
lookInPatCtxElem :: LoreConstraints lore => PatElem lore -> FindM lore ()
lookInPatCtxElem pe = case pe of
  PatElem mem (ExpMem.MemMem size space) -> recordMapping mem (size, space)
  _ -> return ()

-- | Record a lambda's memory-typed parameters, then walk its body.
lookInLambda :: LoreConstraints lore => Lambda lore -> FindM lore ()
lookInLambda (Lambda params body _) =
  mapM_ lookInLParam params >> lookInBody body

-- | Recognise an allocation expression, yielding its size and space.
class AllocSizeUtils lore where
  lookForAllocSize :: Exp lore -> Maybe (SubExp, Space)

instance AllocSizeUtils ExplicitMemory where
  lookForAllocSize e = case e of
    Op (ExpMem.Alloc size space) -> Just (size, space)
    _ -> Nothing

instance AllocSizeUtils InKernel where
  lookForAllocSize e = case e of
    Op (ExpMem.Alloc size space) -> Just (size, space)
    _ -> Nothing