{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE LambdaCase #-}
-- | Find memory block interferences.  Maps a memory block to its interference
-- set.
module Futhark.Optimise.MemoryBlockMerging.Liveness.Interference
  ( findInterferences
  ) where

import qualified Data.Set as S
import qualified Data.Map.Strict as M
import qualified Data.List as L
import Data.Maybe (mapMaybe, fromMaybe, catMaybes)
import Control.Monad
import Control.Monad.RWS
import Control.Monad.Writer

import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory (
  ExplicitMemorish, ExplicitMemory, InKernel)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel

import Futhark.Optimise.MemoryBlockMerging.Miscellaneous
import Futhark.Optimise.MemoryBlockMerging.Types


-- Read-only environment of the traversal.  All fields except
-- 'ctxLoopCorrespondingVar' are filled directly from the arguments of
-- 'findInterferences'; 'ctxLoopCorrespondingVar' starts out empty and is
-- extended locally while statements are inspected.
data Context = Context
  { ctxVarToMem :: VarMemMappings MemorySrc
  , ctxMemAliases :: MemAliases
  , ctxFirstUses :: FirstUses
  , ctxLastUses :: LastUses
  , ctxExistentials :: Names
  , ctxLoopCorrespondingVar :: M.Map VName (VName, SubExp)
  }
  deriving (Show)

-- The raw writer output: one entry per (memory block, interfering blocks)
-- observation.  The same key may occur many times.
type InterferencesList = [(MName, MNames)]

-- Collapse the raw observation list into a map, unioning the interference
-- sets of entries that share a key (earlier entries on the left of the
-- union, exactly as a left fold over singleton maps would do).
getInterferencesMap :: InterferencesList -> Interferences
getInterferencesMap = M.fromListWith (flip S.union)

-- Mutable state of the traversal: the memory blocks that are alive right
-- now, and the groups of potential kernel data race interferences collected
-- so far.
data Current = Current
  { curAlive :: MNames
  , curResPotentialKernelInterferences :: PotentialKernelDataRaceInterferences
  }
  deriving (Show)

-- The analysis monad: reads a 'Context', writes an 'InterferencesList', and
-- keeps a 'Current' as state.  The @lore@ parameter is phantom.
newtype FindM lore a = FindM { unFindM :: RWS Context InterferencesList Current a }
  deriving (Monad, Functor, Applicative,
            MonadReader Context,
            MonadWriter InterferencesList,
            MonadState Current)

type LoreConstraints lore = (ExplicitMemorish lore,
                             KernelInterferences lore,
                             SpecialBodyExceptions lore,
                             FullWalk lore)

-- Re-tag a computation with another lore.  This is sound because @lore@ is
-- only a phantom parameter of 'FindM'.
coerce :: FindM flore a -> FindM tlore a
coerce (FindM m) = FindM m

-- Mark a memory block as alive.
awaken :: MName -> FindM lore ()
awaken = modifyCurAlive . S.insert

-- Mark a memory block as no longer alive.
kill :: MName -> FindM lore ()
kill = modifyCurAlive . S.delete

-- Apply a function to the set of currently alive memory blocks.
modifyCurAlive :: (MNames -> MNames) -> FindM lore ()
modifyCurAlive update =
  modify $ \cur -> cur { curAlive = update (curAlive cur) }

-- Append one group to the end of the collected potential kernel data race
-- interferences.
addPotentialKernelInterferenceGroup :: PotentialKernelDataRaceInterferenceGroup
                                    -> FindM lore ()
addPotentialKernelInterferenceGroup grp =
  modify $ \cur ->
    cur { curResPotentialKernelInterferences =
            curResPotentialKernelInterferences cur ++ [grp] }

-- Record that every currently alive memory block interferes with the whole
-- set of currently alive memory blocks.
recordCurrentInterferences :: FindM lore ()
recordCurrentInterferences = do
  alive <- gets curAlive
  -- Interferences are commutative.  Reflect that in the resulting data.
  tell [ (mem, alive) | mem <- S.toList alive ]

-- Record interferences between the currently alive memory blocks and the
-- given memory blocks, in both directions.
recordNewInterferences :: MNames -> FindM lore ()
recordNewInterferences mems_in_stm = do
  alive <- gets curAlive
  -- Interferences are commutative.  Reflect that in the resulting data.
  tell $ [ (mem, mems_in_stm) | mem <- S.toList alive ]
      ++ [ (mem, alive) | mem <- S.toList mems_in_stm ]

-- | Find all memory block interferences in a function definition.
findInterferences :: VarMemMappings MemorySrc -> MemAliases -> FirstUses
                  -> LastUses -> Names -> FunDef ExplicitMemory
                  -> (Interferences, PotentialKernelDataRaceInterferences)
findInterferences var_to_mem mem_aliases first_uses last_uses existentials fundef =
  let context = Context { ctxVarToMem = var_to_mem
                        , ctxMemAliases = mem_aliases
                        , ctxFirstUses = first_uses
                        , ctxLastUses = last_uses
                        , ctxExistentials = existentials
                        , ctxLoopCorrespondingVar = M.empty
                        }
      -- Visit the function parameters first, then the body.
      m = unFindM $ do
        forM_ (funDefParams fundef) lookInFunDefFParam
        lookInBody $ funDefBody fundef
      -- Start with no alive memory blocks and no potential kernel
      -- interferences.
      (cur, interferences_list) = execRWS m context (Current S.empty [])
      interferences = removeEmptyMaps $ removeKeyFromMapElems $ makeCommutativeMap
                      $ getInterferencesMap interferences_list
      potential_kernel_interferences = curResPotentialKernelInterferences cur
  in (interferences, potential_kernel_interferences)

-- Awaken the memory blocks first used by a function parameter, and record
-- the interferences among everything alive at that point.
lookInFunDefFParam :: FParam lore -> FindM lore ()
lookInFunDefFParam (Param var _) = do
  first_uses_var <- lookupEmptyable var <$> asks ctxFirstUses
  mapM_ awaken $ S.toList first_uses_var
  recordCurrentInterferences

-- Visit every statement of a body in order, then its result.
lookInBody :: LoreConstraints lore
           => Body lore -> FindM lore ()
lookInBody (Body _ bnds res) = do
  mapM_ lookInStm bnds
  lookInRes res

-- Like 'lookInBody', but for kernel bodies; the result is reduced to its
-- subexpressions with 'kernelResultSubExp'.
lookInKernelBody :: LoreConstraints lore
                 => KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ bnds res) = do
  mapM_ lookInStm bnds
  lookInRes $ map kernelResultSubExp res

-- Awaken every memory block that has its first use in one of the given
-- pattern elements.
awakenFirstUses :: [PatElem lore] -> FindM lore ()
awakenFirstUses patvalelems =
  forM_ patvalelems $ \(PatElem var _) -> do
    first_uses_var <- lookupEmptyable var <$> asks ctxFirstUses
    mapM_ awaken $ S.toList first_uses_var

-- Is the expression a 'Scratch'?
isNoOp :: Exp lore -> Bool
isNoOp (BasicOp bop) = case bop of
  Scratch{} -> True
  _ -> False
isNoOp _ = False

-- Handle one statement:
--
--   1. Awaken the memory blocks first used by the pattern.
--   2. Compute the exception pairs for expressions with special body
--      indices (see 'specialBodyIndices').
--   3. Record the interferences between the memory of this statement and
--      everything alive, recurse into sub-bodies, and re-emit everything
--      that was recorded with the exception pairs filtered out.
--   4. Record potential kernel data race interferences.
--   5. Kill the memory blocks that have their last use in this statement.
lookInStm :: LoreConstraints lore
          => Stm lore -> FindM lore ()
lookInStm stm@(Let (Pattern _patctxelems patvalelems) _ e)
  | isNoOp e = awakenFirstUses patvalelems
  -- There is no reason to record interferences if the current statement will
  -- not generate any code in the end.  We have this check to use the result
  -- index sharing analysis on loop bodies and not get bogged down by the
  -- result of a Scratch statement hanging around.
  | otherwise = do
      awakenFirstUses patvalelems

      ctx <- ask
      let ctx' = ctx { ctxLoopCorrespondingVar =
                         M.union (ctxLoopCorrespondingVar ctx)
                                 (findLoopCorrespondingVar ctx stm) }
      -- Pairs of memory blocks that must not be reported as interfering for
      -- this statement.  Empty if the expression has no special body
      -- indices.
      let stm_exceptions = fromMaybe [] $ do
            indices <- specialBodyIndices e
            let walker_exc = identityWalker
                  { walkOnBody = \body ->
                      let (body', lcv) = innermostLoopNestBody ctx body
                          ctx'' = ctx' { ctxLoopCorrespondingVar =
                                           M.union (ctxLoopCorrespondingVar ctx') lcv }
                      in tell $ interferenceExceptions ctx''
                           (bodyStms body') (bodyResult body') indices Nothing
                  }
                walker_kernel_exc = identityKernelWalker
                  { walkOnKernelBody = \body ->
                      let (body', lcv) = innermostLoopNestBody ctx body
                          ctx'' = ctx' { ctxLoopCorrespondingVar =
                                           M.union (ctxLoopCorrespondingVar ctx') lcv }
                      in tell $ interferenceExceptions ctx''
                           (bodyStms body') (bodyResult body') indices Nothing
                  , walkOnKernelKernelBody = \kbody ->
                      tell $ interferenceExceptions ctx'
                        (kernelBodyStms kbody)
                        (mapMaybe (\case { ThreadsReturn _ se -> Just se
                                         ; _ -> Nothing })
                           $ kernelBodyResult kbody)
                        indices (specialBodyWriteMems stm)
                  }
            return $ execWriter $ fullWalkExpM walker_exc walker_kernel_exc e

      first_uses <- asks ctxFirstUses
      last_uses <- asks ctxLastUses
      -- All memory blocks with a first use or a last use in this statement.
      let stm_mems = S.unions $ map
            (\pelem -> let v = patElemName pelem
                       in S.union (lookupEmptyable v first_uses)
                                  (lookupEmptyable (FromStm v) last_uses))
            patvalelems

      -- Capture everything written while handling this statement (including
      -- nested bodies) and remove it from the log, so that the filtered
      -- version can be written instead.
      ((), stm_interferences) <- censor (const []) $ listen $ do
        recordNewInterferences stm_mems
        local (const ctx') $ fullWalkExpM walker walker_kernel e
      -- Exceptions are symmetric: a pair is excepted if it occurs in either
      -- order.  A set gives logarithmic lookups instead of a linear scan of
      -- the exception list for every recorded pair.
      let exception_pairs = S.fromList stm_exceptions
          isException k v = (k, v) `S.member` exception_pairs
                            || (v, k) `S.member` exception_pairs
          stm_interferences' =
            map (\(k, vs) -> (k, S.filter (not . isException k) vs))
                stm_interferences
      tell stm_interferences'

      potential_kernel_interferences <- findKernelDataRaceInterferences e
      forM_ potential_kernel_interferences addPotentialKernelInterferenceGroup

      forM_ patvalelems $ \(PatElem var _) -> do
        last_uses_var <- lookupEmptyable (FromStm var) <$> asks ctxLastUses
        mapM_ kill last_uses_var
  where walker = identityWalker { walkOnBody = lookInBody }
        walker_kernel = identityKernelWalker
          { walkOnKernelBody = coerce . lookInBody
          , walkOnKernelKernelBody = coerce . lookInKernelBody
          , walkOnKernelLambda = coerce . lookInBody . lambdaBody
          }

-- For perfectly nested loops.  Make it possible to find the index function for
-- the outer loop.
--
-- For a loop statement, pair up the pattern elements with the loop body
-- results.  If the last statement of the loop body is an 'Update' with a
-- leading 'DimFix' whose pattern lives in the same memory block as the
-- loop pattern element, map the corresponding variable to the loop pattern
-- variable and that fixed index.  The corresponding variable is the copied
-- variable if it resides in that memory block too, and the body result
-- variable otherwise.  Any other statement yields the empty map.
findLoopCorrespondingVar :: LoreConstraints lore
                         => Context -> Stm lore -> M.Map VName (VName, SubExp)
findLoopCorrespondingVar ctx (Let (Pattern _patctxelems patvalelems) _
                              (DoLoop _ _ _ (Body _ stms res))) =
  M.fromList $ catMaybes $ zipWith findIt patvalelems res
  where findIt (PatElem pat_v (ExpMem.MemArray _ _ _ (ExpMem.ArrayIn pat_mem _)))
               (Var res_v) =
          -- Matching on the reversed list makes the "last statement" lookup
          -- total: an empty body simply falls through to 'Nothing'.
          case reverse (stmsToList stms) of
            -- This is how the program looks after coalescing.
            Let (Pattern _ [PatElem _last_v
                            (ExpMem.MemArray _ _ _ (ExpMem.ArrayIn last_stm_mem _))])
                _
                (BasicOp (Update _ (DimFix slice_part : _) (Var copy_v))) : _
              | pat_mem == last_stm_mem ->
                  let copy_in_last_mem =
                        (memSrcName <$> M.lookup copy_v (ctxVarToMem ctx))
                        == Just last_stm_mem
                      corresponding = if copy_in_last_mem then copy_v else res_v
                  in Just (corresponding, (pat_v, slice_part))
            _ -> Nothing
        findIt _ _ = Nothing
findLoopCorrespondingVar _ _ = M.empty

-- Descend through bodies of the shape "Scratch statement, then loop
-- statement" until a body of another shape is found.  Return that body
-- together with the union of the loop-corresponding-variable maps of all
-- loops passed on the way (outer loops take precedence in the union).
innermostLoopNestBody :: LoreConstraints lore
                      => Context -> Body lore
                      -> (Body lore, M.Map VName (VName, SubExp))
innermostLoopNestBody ctx body = case stmsToList $ bodyStms body of
  -- This checks for how perfect nested loops looks like after coalescing.  This
  -- is very brittle.  If it detects such a nesting, it will ask the
  -- interference exception algorithm to look in the innermost body.
  Let _ _ (BasicOp Scratch{}) : loopstm@(Let _ _ (DoLoop _ _ _ body')) : _ ->
    let (body'', loop_corresponding_var) = innermostLoopNestBody ctx body'
    in (body'', M.union (findLoopCorrespondingVar ctx loopstm) loop_corresponding_var)
  _ -> (body, M.empty)

-- Handle a body result: the memory blocks with their last use in the result
-- interfere with everything alive, and are then killed.
lookInRes :: [SubExp] -> FindM lore ()
lookInRes ses = do
  let vs = subExpVars ses
  last_uses <- asks ctxLastUses
  let last_uses_v = S.unions $ map (\v -> lookupEmptyable (FromRes v) last_uses) vs
  recordNewInterferences last_uses_v
  mapM_ kill $ S.toList last_uses_v

-- Collect the kernel first uses of a statement and everything nested in it,
-- given the first uses to look variables up in.
firstUsesInStm :: LoreConstraints lore
               => FirstUses -> Stm lore -> [KernelFirstUse]
firstUsesInStm first_uses stm =
  let m = lookFUInStm stm
  in snd $ evalRWS m first_uses ()

-- Like 'firstUsesInStm', but for an expression, and with the first uses
-- taken from the context.
firstUsesInExp :: LoreConstraints lore
               => Exp lore -> FindM lore [KernelFirstUse]
firstUsesInExp e = do
  let m = lookFUInExp e
  first_uses <- asks ctxFirstUses
  return $ snd $ evalRWS m first_uses ()

-- For every array pattern element of the statement, emit one entry per
-- memory block it first uses, carrying the variable name, the primitive
-- type and the index function.  Then continue in the expression.
lookFUInStm :: LoreConstraints lore
            => Stm lore -> RWS FirstUses [KernelFirstUse] () ()
lookFUInStm (Let (Pattern _patctxelems patvalelems) _ e_stm) = do
  forM_ patvalelems $ \(PatElem patname membound) ->
    case membound of
      ExpMem.MemArray pt _ _ (ExpMem.ArrayIn _ ixfun) -> do
        fus <- lookupEmptyable patname <$> ask
        forM_ fus $ \fu -> tell [(fu, patname, pt, ixfun)]
      _ -> return ()
  lookFUInExp e_stm

-- Apply 'lookFUInStm' to every statement of every body, kernel body and
-- kernel lambda body reachable from the expression.
lookFUInExp :: LoreConstraints lore
            => Exp lore -> RWS FirstUses [KernelFirstUse] () ()
lookFUInExp = fullWalkExpM fu_walker fu_walker_kernel
  where fu_walker = identityWalker
          { walkOnBody = mapM_ lookFUInStm . bodyStms }
        fu_walker_kernel = identityKernelWalker
          { walkOnKernelBody = mapM_ lookFUInStm . bodyStms
          , walkOnKernelKernelBody = mapM_ lookFUInStm . kernelBodyStms
          , walkOnKernelLambda = mapM_ lookFUInStm . bodyStms . lambdaBody
          }

-- Lore-specific detection of potential kernel data race interferences.
class KernelInterferences lore where
  findKernelDataRaceInterferences :: Exp lore
                                  -> FindM lore (Maybe PotentialKernelDataRaceInterferenceGroup)

-- At the top level, a kernel expression yields the group of all first uses
-- found inside it; every other expression yields nothing.
instance KernelInterferences ExplicitMemory where
  findKernelDataRaceInterferences e = case e of
    Op (ExpMem.Inner Kernel{}) -> Just <$> firstUsesInExp e
    _ -> return Nothing

-- Inside a kernel no groups are reported.
instance KernelInterferences InKernel where
  findKernelDataRaceInterferences _ = return Nothing

-- Base info for kernel bodies.
class SpecialBodyExceptions lore where
  -- The index variables of an expression whose body qualifies for the
  -- interference exception analysis, if any.
  specialBodyIndices :: Exp lore -> Maybe [MName]
  -- The memory block, index function and primitive type of every array
  -- pattern element of a statement, if the statement qualifies.
  specialBodyWriteMems :: Stm lore -> Maybe [(MName, ExpMem.IxFun, PrimType)]

instance SpecialBodyExceptions ExplicitMemory where
  -- For kernels: the first components of the kernel space dimensions.
  specialBodyIndices (Op (ExpMem.Inner (Kernel _ kernelspace _ _))) =
    Just $ map fst $ spaceDimensions kernelspace
  specialBodyIndices e = specialBodyIndicesBase e

  -- Only kernel statements have write memories.
  specialBodyWriteMems (Let (Pattern _patctxelems patvalelems) _
                        (Op (ExpMem.Inner Kernel{}))) =
    Just $ mapMaybe
      (\p -> case patElemAttr p of
               ExpMem.MemArray t _ _ (ExpMem.ArrayIn mem ixfun) -> Just (mem, ixfun, t)
               _ -> Nothing)
      patvalelems
  specialBodyWriteMems _ = Nothing

instance SpecialBodyExceptions InKernel where
  specialBodyIndices = specialBodyIndicesBase
  specialBodyWriteMems = const Nothing

-- The lore-independent case: a for-loop contributes its loop index.
specialBodyIndicesBase :: Exp lore -> Maybe [MName]
specialBodyIndicesBase (DoLoop _ _ (ForLoop i _ _ _) _) = Just [i]
specialBodyIndicesBase _ = Nothing

-- Use first use analysis and last use analysis to find any exceptions to the
-- naive interference recorded for a statement.
--
-- Arguments: the context, the statements and result of the body to inspect,
-- the index variables of the enclosing construct, and (for kernel bodies
-- only) the memory blocks, index functions and primitive types written by
-- the kernel.  The result is a list of (first-used memory, killed memory)
-- pairs that need not interfere.
interferenceExceptions :: LoreConstraints lore
                       => Context -> Stms lore -> [SubExp] -> [MName]
                       -> Maybe [(MName, ExpMem.IxFun, PrimType)]
                       -> [(MName, MName)]
interferenceExceptions ctx stms res indices output_mems_may =
  let output_vars = subExpVars res
      indices_slice = map (DimFix . Var) indices
      stms_first_uses = map (\(mem, _, _, _) -> mem)
                        $ concatMap (firstUsesInStm (ctxFirstUses ctx)) stms

      -- For every pattern element of every statement: information about the
      -- array read (if the statement is an 'Index') and about the array
      -- write (if the statement is an 'Update').
      results = concat $ flip map (stmsToList stms) $ \(Let (Pattern _patctxelems patvalelems) _ e) ->
        flip map patvalelems $ \(PatElem v membound) ->
          let fromread = case e of
                BasicOp (Index orig slice) -> do
                  orig_mem <- M.lookup orig $ ctxVarToMem ctx
                  -- These two extra requirements might be superfluous.
                  if memSrcName orig_mem `L.notElem` stms_first_uses
                     && not (memSrcName orig_mem `S.member` ctxExistentials ctx)
                    then return (v, typeOf membound, orig_mem, slice)
                    else Nothing
                _ -> Nothing
              fromwrite = case e of
                BasicOp Update{}
                  | ExpMem.MemArray pt _ _ _ <- membound -> do
                      -- The coalescing pass can have created a program where some
                      -- dependencies are a bit indirect.  We find the core index function.
                      let (orig', slice') = fixpointIterateMay
                            (\(v0, ss0) -> do
                                (v1, s1) <- M.lookup v0 (ctxLoopCorrespondingVar ctx)
                                return (v1, DimFix s1 : ss0))
                            (v, [])
                      orig_mem <- M.lookup orig' $ ctxVarToMem ctx
                      -- These two extra requirements might be superfluous.
                      if memSrcName orig_mem `L.notElem` stms_first_uses
                         && not (memSrcName orig_mem `S.member` ctxExistentials ctx)
                        then return (v, Prim pt, orig_mem, slice')
                        else Nothing
                _ -> Nothing
          in (fromread, fromwrite)
      fromreads = mapMaybe fst results
      fromwrites = mapMaybe snd results
      -- Only writes whose variable is part of the body result count.
      fromwrites' = filter (\(v, _, _, _) -> v `L.elem` output_vars) fromwrites

      fus_input_vars = M.fromList
                       $ map (\(v, _, mem, _) -> (v, S.singleton $ memSrcName mem)) fromreads

      -- The extra last use of the memory of a read: the last statement in
      -- the body that uses it.
      lastUseOfRead (v, typ, mem, _) =
        let usesRead (Let _ _ e_pat) =
              let frees = freeInExp e_pat
              -- We need to handle scalars and arrays differently: A last
              -- use of a scalar variable is the definite last use of the
              -- memory it represents, while the last use of an array can
              -- be distorted by reshapes and other aliasing operations,
              -- so in that case we need to find the last use of the
              -- memory block.
              in case typ of
                   Prim _ -> v `S.member` frees
                   _ -> memSrcName mem
                        `L.elem` mapMaybe ((memSrcName <$>) . (`M.lookup` ctxVarToMem ctx))
                                          (S.toList frees)
        in do
          stm <- L.find usesRead (reverse $ stmsToList stms)
          -- The last use is keyed on the first value pattern element of the
          -- statement.  A statement that binds no values cannot be keyed;
          -- the entry is then dropped instead of crashing on an empty
          -- list.  Dropping is the conservative choice, since it can only
          -- lead to fewer exceptions.
          first_pe <- case patternValueElements (stmPattern stm) of
            pe : _ -> Just pe
            [] -> Nothing
          return (FromStm $ patElemName first_pe, S.singleton $ memSrcName mem)
      lus_input_vars = mapFromListSetUnion $ mapMaybe lastUseOfRead fromreads

      -- 'Just' if in kernel, 'Nothing' otherwise.
      fus_output_vars = mapFromListSetUnion $ case output_mems_may of
        Just _ -> []
        _ -> map (\(v, _, mem, _) -> (v, S.singleton $ memSrcName mem)) fromwrites'
      fus_result = mapFromListSetUnion $ case output_mems_may of
        Just mems -> zip output_vars $ map (S.singleton . (\(mem, _, _) -> mem)) mems
        _ -> []

      -- Extended first uses and last uses.
      fus = M.unionsWith S.union [ctxFirstUses ctx, fus_input_vars, fus_output_vars]
      lus = M.unionsWith S.union [ctxLastUses ctx, lus_input_vars]

      -- Memory-to-slice mappings.
      input_mem_slices = M.fromList
                         $ map (\(_, _, mem, slice) -> (memSrcName mem, slice)) fromreads
      output_mem_slices = M.fromList $ case output_mems_may of
        Just mems -> map (\(mem, _, _) -> (mem, indices_slice)) mems
        _ -> map (\(_, _, mem, slice) -> (memSrcName mem, slice)) fromwrites'
      mem_slices = M.union input_mem_slices output_mem_slices

      -- Memory-to-ixfun mappings.
      input_mem_ixfuns = M.fromList
                         $ map (\(_, _, mem, _) -> (memSrcName mem, memSrcIxFun mem)) fromreads
      output_mem_ixfuns = M.fromList $ case output_mems_may of
        Just mems -> map (\(mem, ixfun, _) -> (mem, ixfun)) mems
        _ -> map (\(_, _, mem, _) -> (memSrcName mem, memSrcIxFun mem)) fromwrites'
      mem_ixfuns = M.union input_mem_ixfuns output_mem_ixfuns

      -- Memory-to-primtype-size mappings.
      input_mem_primtypes = M.fromList
                            $ map (\(_, t, mem, _) -> (memSrcName mem, elemType t)) fromreads
      output_mem_primtypes = M.fromList $ case output_mems_may of
        Just mems -> map (\(mem, _, pt) -> (mem, pt)) mems
        _ -> map (\(_, t, mem, _) -> (memSrcName mem, elemType t)) fromwrites'
      mem_primtypes = M.union input_mem_primtypes output_mem_primtypes

      -- Separation of input memory blocks and output memory blocks.
      mem_ins0 = S.fromList $ map (\(_, _, mem, _) -> memSrcName mem) fromreads
      mem_outs0 = S.fromList $ case output_mems_may of
        Just mems -> map (\(mem, _, _) -> mem) mems
        _ -> map (\(_, _, mem, _) -> memSrcName mem) fromwrites'
      -- An input memory must not be an output memory, and vice versa.
      mem_ins = S.difference mem_ins0 mem_outs0
      mem_outs = S.difference mem_outs0 mem_ins0

      exceptions = snd $ evalRWS
                   (findExceptions fus fus_result lus mem_ins mem_outs
                    mem_slices mem_ixfuns mem_primtypes output_vars) () S.empty
  in exceptions
  where -- Walk the statements in order.  For each statement, first check
        -- its first uses against the memory blocks that have died so far,
        -- then add its last uses to the deaths.  After all statements, do
        -- the same check for the first uses attributed to the result
        -- variables.
        findExceptions :: FirstUses -> FirstUses -> LastUses
                       -> Names -> Names
                       -> M.Map VName (Slice SubExp)
                       -> M.Map VName ExpMem.IxFun
                       -> M.Map VName PrimType
                       -> [VName]
                       -> RWS () [(VName, VName)] LocalDeaths ()
        findExceptions fus fus_result lus mem_ins mem_outs
          mem_slices mem_ixfuns mem_primtypes output_vars = do
          forM_ stms $ \(Let (Pattern _patctxelems patvalelems) _ _) -> do
            let vs = map patElemName patvalelems
                fus_stm = S.unions $ map (`lookupEmptyable` fus) vs
                lus_stm = S.unions $ map ((`lookupEmptyable` lus) . FromStm) vs
            recordNewExceptions mem_ins mem_outs
              mem_slices mem_ixfuns mem_primtypes fus_stm
            modify $ S.union lus_stm
          forM_ output_vars $ \ov -> do
            let fus_ov = lookupEmptyable ov fus_result
            recordNewExceptions mem_ins mem_outs
              mem_slices mem_ixfuns mem_primtypes fus_ov

        -- Emit an exception pair for every combination of a currently
        -- first-used memory block and an already dead memory block that
        -- passes all the checks below.  Combinations for which a slice,
        -- index function or primitive type is unknown are skipped.
        recordNewExceptions :: Names -> Names
                            -> M.Map VName (Slice SubExp)
                            -> M.Map VName ExpMem.IxFun
                            -> M.Map VName PrimType
                            -> Names
                            -> RWS () [(VName, VName)] LocalDeaths ()
        recordNewExceptions mem_ins mem_outs
          mem_slices mem_ixfuns mem_primtypes fus_cur = do
          deaths <- get
          forM_ (S.toList fus_cur) $ \mem_fu ->
            forM_ deaths $ \mem_killed ->
              fromMaybe (return ()) $ do
                slice_fu <- M.lookup mem_fu mem_slices
                slice_killed <- M.lookup mem_killed mem_slices
                ixfun_fu <- M.lookup mem_fu mem_ixfuns
                ixfun_killed <- M.lookup mem_killed mem_ixfuns
                pt_fu <- M.lookup mem_fu mem_primtypes
                pt_killed <- M.lookup mem_killed mem_primtypes
                return $ when (
                  -- Is the killed memory read from and the first use memory
                  -- written to?
                  mem_fu `S.member` mem_outs && mem_killed `S.member` mem_ins
                  -- Same index functions?
                  && ixfun_fu == ixfun_killed -- too conservative?
                  -- Same slices?
                  && slice_fu == slice_killed
                  -- Same primitive type byte sizes?
                  && (primByteSize pt_fu :: Int) == primByteSize pt_killed
                  ) $ tell [(mem_fu, mem_killed)]

-- Memory blocks that have had their last use locally in the body.
type LocalDeaths = Names