{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
-- | Collect the variable names that occur in an expression and its nested
-- bodies: parameter names, names bound by statement patterns, and the arrays
-- named by some kernel results.
module Futhark.Optimise.MemoryBlockMerging.AllExpVars
  ( findAllExpVars
  ) where

import Control.Monad
import Control.Monad.Writer
import qualified Data.Set as S

import Futhark.Optimise.MemoryBlockMerging.Miscellaneous
import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory (ExplicitMemorish)
import Futhark.Representation.Kernels.Kernel

-- | A writer computation that accumulates a set of names.  The @lore@
-- parameter does not appear in the representation; it only tags which lore
-- the computation is traversing.
newtype FindM lore a = FindM { unFindM :: Writer Names a }
  deriving (Functor, Applicative, Monad, MonadWriter Names)

-- | Everything a lore must support for the traversals in this module.
type LoreConstraints lore = (ExplicitMemorish lore, FullWalk lore)

-- | Re-tag a computation with a different lore.  The accumulated names are
-- untouched.  Used by the kernel walker below, whose callbacks presumably
-- receive code of a different lore than the enclosing one (NOTE(review):
-- confirm against the 'FullWalk' walker types).
retagFindM :: FindM flore a -> FindM tlore a
retagFindM (FindM w) = FindM w

-- | Add a single name to the accumulated set.
emitName :: MonadWriter (S.Set v) m => v -> m ()
emitName = tell . S.singleton

-- | Find the variables that occur in an expression and any nested bodies.
--
-- The walk records:
--
--   * the names of function and lambda parameters,
--   * the names of the /value/ pattern elements of every statement
--     (context pattern elements are skipped), and
--   * the target array of a 'WriteReturn' and the variable of a
--     'KernelInPlaceReturn' in kernel bodies.
--
-- Body results, 'ThreadsReturn' and 'ConcatReturns' contribute nothing, so
-- this is not every free variable.  For an expression with no nested bodies
-- or parameters the result is empty.  It can be non-empty for e.g. If,
-- DoLoop and kernels.
--
-- The result tells us which extra variables need their memory blocks updated
-- when some variable needs updating.  We do not simply collect every variable
-- in the program and search that each time.  A memory block can (at least in
-- theory) appear in two different places, and so with two different variable
-- sets.  The search is therefore limited to the statement that declares a new
-- current use of the memory.
findAllExpVars :: LoreConstraints lore => Exp lore -> Names
findAllExpVars = execWriter . unFindM . lookInExp

-- | Walk an expression with 'fullWalkExpM' (from "Miscellaneous"),
-- dispatching bodies, parameters, kernel bodies and lambdas to the
-- collectors below.
lookInExp :: LoreConstraints lore => Exp lore -> FindM lore ()
lookInExp = fullWalkExpM bodyWalker kernelWalker
  where
    bodyWalker = identityWalker
      { walkOnFParam = lookInFParam
      , walkOnLParam = lookInLParam
      , walkOnBody = lookInBody
      }
    kernelWalker = identityKernelWalker
      { walkOnKernelLParam = lookInLParam
      , walkOnKernelBody = retagFindM . lookInBody
      , walkOnKernelKernelBody = retagFindM . lookInKernelBody
      , walkOnKernelLambda = retagFindM . lookInLambda
      }

-- | Record the name of a function parameter.
lookInFParam :: FParam lore -> FindM lore ()
lookInFParam (Param name _) = emitName name

-- | Record the name of a lambda parameter.
lookInLParam :: LParam lore -> FindM lore ()
lookInLParam (Param name _) = emitName name

-- | Visit every statement of a body.  The body's result is not recorded.
lookInBody :: LoreConstraints lore => Body lore -> FindM lore ()
lookInBody (Body _ stms _) = forM_ stms lookInStm

-- | Visit every statement of a kernel body, then record the names carried by
-- 'WriteReturn' and 'KernelInPlaceReturn' results.
lookInKernelBody :: LoreConstraints lore => KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ stms results) = do
  forM_ stms lookInStm
  mapM_ lookInResult results
  where
    lookInResult ThreadsReturn{} = return ()
    lookInResult (WriteReturn _ target _) = emitName target
    lookInResult ConcatReturns{} = return ()
    lookInResult (KernelInPlaceReturn var) = emitName var

-- | Record the value names bound by a statement's pattern (context elements
-- are ignored), then walk the statement's expression.
lookInStm :: LoreConstraints lore => Stm lore -> FindM lore ()
lookInStm (Let (Pattern _ valueElems) _ e) = do
  mapM_ emitName [ name | PatElem name _ <- valueElems ]
  lookInExp e

-- | Record a lambda's parameter names, then walk its body.
lookInLambda :: LoreConstraints lore => Lambda lore -> FindM lore ()
lookInLambda (Lambda params body _) =
  mapM_ lookInLParam params >> lookInBody body