{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
-- | Transform a function based on a mapping from variable to memory and index
-- function: Change every variable in the mapping to its possibly new memory
-- block.
module Futhark.Optimise.MemoryBlockMerging.MemoryUpdater
  ( transformFromVarMemMappings
  ) where

import Control.Applicative ((<|>))
import Control.Arrow (second)
import Control.Monad
import Control.Monad.RWS
import qualified Data.List as L
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe, mapMaybe)

import Futhark.MonadFreshNames
import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory (ExplicitMemorish, ExplicitMemory)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel

import Futhark.Optimise.MemoryBlockMerging.Miscellaneous
import Futhark.Optimise.MemoryBlockMerging.Types

-- | Read-only environment of the rewrite.
data Context = Context
  { ctxVarToMem :: VarMemMappings MemoryLoc
    -- ^ New memory block and index function for each variable to be moved.
  , ctxVarToMemOrig :: VarMemMappings MName
    -- ^ Original memory block of each variable.
  , ctxAllocSizes :: M.Map MName SubExp
    -- ^ Allocation size per memory block (consulted for the new blocks).
  , ctxAllocSizesOrig :: M.Map MName SubExp
    -- ^ Allocation size per original memory block.
  , ctxHasMaxedSize :: Bool
    -- ^ Whether @If@ results returning fresh blocks may get new existential
    -- size results.
  }
  deriving (Show)

-- | The rewrite monad.  It reads a 'Context'.  Its state pairs the fresh-name
-- source with a list mapping memory blocks to size variables that were
-- introduced during the rewrite.  The @lore@ parameter is phantom.
newtype FindM lore a = FindM { unFindM :: RWS Context () (VNameSource, [(MName, VName)]) a }
  deriving (Functor, Applicative, Monad,
            MonadReader Context,
            MonadState (VNameSource, [(MName, VName)]))

instance MonadFreshNames (FindM lore) where
  -- The name source lives in the first component of the state.
  getNameSource = fst <$> get
  putNameSource src = modify (\(_, sizes) -> (src, sizes))

-- | Apply a function to the memory-block-to-size-variable list in the state.
modifyMemSizeMapping :: ([(MName, VName)] -> [(MName, VName)]) -> FindM lore ()
modifyMemSizeMapping = modify . second

-- | Constraints every lore handled by this pass must satisfy.
type LoreConstraints lore = (ExplicitMemorish lore,
                             FullMap lore,
                             BodyAttr lore ~ (),
                             ExpAttr lore ~ ())

-- | Run a computation at a different lore.  This is free because @lore@ is
-- phantom in 'FindM'.
coerce :: FindM flore a -> FindM tlore a
coerce (FindM rws) = FindM rws

-- | Transform a function to use new memory blocks.
transformFromVarMemMappings :: MonadFreshNames m
                            => VarMemMappings MemoryLoc
                               -- ^ New memory block and index function of each
                               -- variable that should be moved.
                            -> VarMemMappings MName
                               -- ^ Original memory block of each variable.
                            -> M.Map MName SubExp
                               -- ^ Allocation size per memory block, looked up
                               -- for the new blocks.
                            -> M.Map MName SubExp
                               -- ^ Allocation size per original memory block.
                            -> Bool
                               -- ^ When 'True', @If@ results returning a new
                               -- block may get new existential size results.
                            -> FunDef ExplicitMemory
                               -- ^ Function to rewrite.
                            -> m (FunDef ExplicitMemory)
transformFromVarMemMappings var_to_mem var_to_mem_orig alloc_sizes alloc_sizes_orig
                            has_maxed_size fundef =
  let m = unFindM $ transformFunDefBody $ funDefBody fundef
      ctx = Context { ctxVarToMem = var_to_mem
                    , ctxVarToMemOrig = var_to_mem_orig
                    , ctxAllocSizes = alloc_sizes
                    , ctxAllocSizesOrig = alloc_sizes_orig
                    , ctxHasMaxedSize = has_maxed_size
                    }
  in modifyNameSource $ \src ->
       -- The memory-size list accumulated in the state is only needed during
       -- the rewrite, so it is discarded here.
       let (body', (src', _), ()) = runRWS m ctx (src, [])
       in (fundef { funDefBody = body' }, src')

-- | Like 'transformBody', but also rewrites the function's result list with
-- 'transformFunDefBodyResult'.
transformFunDefBody :: LoreConstraints lore =>
                       Body lore -> FindM lore (Body lore)
transformFunDefBody (Body () bnds res) = do
  bnds' <- mapM transformStm $ stmsToList bnds
  res' <- transformFunDefBodyResult res
  return $ Body () (stmsFromList bnds') res'

-- | Rewrite the results of a function body.
--
-- A result naming the original memory block of a moved variable is replaced by
-- the variable's new memory block.  A result naming an original size is
-- replaced by the new size of the memory block result that follows it.
transformFunDefBodyResult :: [SubExp] -> FindM lore [SubExp]
transformFunDefBodyResult ses = do
  var_to_mem_orig <- asks ctxVarToMemOrig
  var_to_mem <- asks ctxVarToMem
  mem_to_size_orig <- asks ctxAllocSizesOrig
  mem_to_size <- asks ctxAllocSizes
  mem_to_new_size <- gets snd

  let -- For a moved variable: map its original memory block to the new one.
      -- If both sizes are known, also map the original size (keyed by the
      -- memory block result expected to follow it) to the new size.
      check se
        | Var v <- se
        , Just orig <- M.lookup v var_to_mem_orig
        , Just new <- memLocName <$> M.lookup v var_to_mem
        = ((Var orig, Nothing), Var new) :
          case (M.lookup orig mem_to_size_orig,
                (Var <$> L.lookup new mem_to_new_size) <|> M.lookup new mem_to_size) of
            (Just size_orig, Just size_new) ->
              [((size_orig, Just (Var orig)), size_new)]
            _ -> []
        | otherwise = []

      -- For a memory block result whose size changed: map the old size
      -- (keyed by that memory block) to the new size.
      check_size_only se
        | Var v <- se
        , Just orig <- M.lookup v mem_to_size_orig
        , Just new <- (Var <$> L.lookup v mem_to_new_size) <|> M.lookup v mem_to_size
        , orig /= new
        = [((orig, Just (Var v)), new)]
        | otherwise = []

      mem_orig_to_new1 = concatMap check ses
      mem_orig_to_new2 = concatMap check_size_only ses
      mem_orig_to_new = mem_orig_to_new1 ++ mem_orig_to_new2

      -- FIXME: This assumes that a memory block always comes just after its
      -- size variable.  We ought to instead properly find this information
      -- from the funDefRetType 'ExtSize's.
      replacement se following =
        (se, Nothing) `L.lookup` mem_orig_to_new
        <|> case following of
              (next : _) -> (se, Just next) `L.lookup` mem_orig_to_new
              [] -> Nothing

  -- Pair each result with the results after it.  'L.tails' is never empty,
  -- so 'drop 1' is the total equivalent of the previous 'L.tail'.
  return $ zipWith (\se following -> fromMaybe se (replacement se following))
                   ses (drop 1 (L.tails ses))

-- | Rewrite every statement of a body; the body result is left untouched.
transformBody :: LoreConstraints lore =>
                 Body lore -> FindM lore (Body lore)
transformBody (Body () bnds res) = do
  bnds' <- mapM transformStm $ stmsToList bnds
  return $ Body () (stmsFromList bnds') res

-- | Rewrite every statement of a kernel body; its results are left untouched.
transformKernelBody :: LoreConstraints lore =>
                       KernelBody lore -> FindM lore (KernelBody lore)
transformKernelBody (KernelBody () bnds res) = do
  bnds' <- mapM transformStm $ stmsToList bnds
  return $ KernelBody () (stmsFromList bnds') res

-- | Make an array return annotation refer to the given memory location.  The
-- location's index function is existentialised.  Non-array annotations are
-- returned unchanged.
transformMemInfo :: ExpMem.MemInfo d u ExpMem.MemReturn -> MemoryLoc
                 -> ExpMem.MemInfo d u ExpMem.MemReturn
transformMemInfo meminfo memloc = case meminfo of
  ExpMem.MemArray pt shape u _memreturn ->
    let extixfun = ExpMem.existentialiseIxFun [] $ memLocIxFun memloc
    in ExpMem.MemArray pt shape u
       (ExpMem.ReturnsInBlock (memLocName memloc) extixfun)
  _ -> meminfo

-- | How an @If@ value result is returned after the rewrite.
data BranchReturn
  = ExistingBranchReturn ExpMem.BodyReturns
    -- ^ Keep this return annotation.
  | NewBranchReturn (Int -> ExpMem.BodyReturns) VName VName VName
    -- ^ The fields are, in order:
    --
    -- * the annotation builder, given the context index of the new size;
    -- * the size variable in the then-branch;
    -- * the size variable in the else-branch;
    -- * the memory block of the pattern element.

-- | Rewrite a statement.  Its value pattern and its subexpressions are made to
-- refer to the new memory blocks.  @If@ and @DoLoop@ also need their extra
-- pattern-like information fixed up.
transformStm :: LoreConstraints lore =>
                Stm lore -> FindM lore (Stm lore)
transformStm (Let (Pattern patctxelems patvalelems) aux e) = do
  patvalelems' <- mapM transformPatValElem patvalelems

  e' <- fullMapExpM mapper mapper_kernel e
  var_to_mem <- asks ctxVarToMem
  var_to_mem_orig <- asks ctxVarToMemOrig
  mem_to_size <- asks ctxAllocSizes
  mem_to_new_size <- gets snd
  (e'', patctxelems') <- case e' of
    If cond body_then body_else (IfAttr rets sort) -> do
      let bodyVarMemLocs body =
            map (flip M.lookup var_to_mem <=< subExpVar)
            $ drop (length patctxelems) $ bodyResult body

          -- FIXME: This is a mess.  We try to "reverse-engineer" the origin of
          -- how the If results came to look as they do, so that we can produce
          -- a correct IfAttr.
          --
          -- Take the i'th context element (assumed to name a memory block).
          -- Find the unique value element stored in it.  Return the new memory
          -- block of the corresponding branch result, if there is one.
          findBodyResMem i body_results =
            let imem = patElemName (patctxelems L.!! i)
                matching_var =
                  mapMaybe (\(p, p_i) ->
                              case patElemAttr p of
                                ExpMem.MemArray _ _ _ (ExpMem.ArrayIn vmem _) ->
                                  if imem == vmem then Just p_i else Nothing
                                _ -> Nothing)
                  (zip patvalelems [0..])
            in do
              j <- case matching_var of
                     [t] -> Just t
                     _ -> Nothing
              body_res_var <- subExpVar (body_results L.!! (length patctxelems + j))
              MemoryLoc mem _ixfun <- M.lookup body_res_var var_to_mem
              return mem

          -- Replace each context result of a branch with the memory block
          -- found by 'findBodyResMem', when there is one.
          fixBodyExistentials body =
            body { bodyResult =
                     zipWith (\res i -> if i < length patctxelems
                                        then maybe res Var $ findBodyResMem i (bodyResult body)
                                        else res)
                             (bodyResult body) [0..] }

      let ms_then = bodyVarMemLocs body_then
          ms_else = bodyVarMemLocs body_else

      -- Fix values.  Update the return annotations only when both branches map
      -- their value results to the same new memory locations.
      let rets' = if ms_then == ms_else
                  then zipWith (\r m -> maybe r (transformMemInfo r) m) rets ms_then
                  else rets

      let body_then' = fixBodyExistentials body_then
          body_else' = fixBodyExistentials body_else

      -- Fix existential memory blocks.
      let mem_size mem = L.lookup mem mem_to_new_size
                         <|> (subExpVar =<< M.lookup mem mem_to_size)
          v_size v = do
            mem <- M.lookup v (M.map memLocName var_to_mem) <|> M.lookup v var_to_mem_orig
            mem_size mem

      has_maxed_size <- asks ctxHasMaxedSize
      let branchReturn r pat th el = case (r, pat, th, el) of
            (ExpMem.MemArray pt shape u
               (ExpMem.ReturnsNewBlock space n (Free (Var _size)) extixfun),
             PatElem _ (ExpMem.MemArray _ _ _ (ExpMem.ArrayIn patmem _)),
             Var v_th, Var v_el) ->
              case (v_size v_th, v_size v_el) of
                (Just s_th, Just s_el) ->
                  -- NOTE(review): a size-equality test is commented out on the
                  -- next line, so only has_maxed_size decides here.
                  if not has_maxed_size --s_th == s_el || not has_maxed_size
                  then ExistingBranchReturn r
                  else NewBranchReturn
                       (\nth_ctxelem -> ExpMem.MemArray pt shape u
                         (ExpMem.ReturnsNewBlock space n (Ext nth_ctxelem) extixfun))
                       s_th s_el patmem
                _ -> error ("both branch return arrays should use a memory block with a size: "
                            ++ show v_th ++ " and " ++ show v_el)
            _ -> ExistingBranchReturn r

          rets_branch_returns =
            L.zipWith4 branchReturn rets' patvalelems
                       (drop (length patctxelems) (bodyResult body_then'))
                       (drop (length patctxelems) (bodyResult body_else'))

          isNewBranchReturn = \case
            NewBranchReturn{} -> True
            ExistingBranchReturn{} -> False

      -- One fresh size variable per result that needs a new existential size.
      patctxelems_new <-
        replicateM (length (filter isNewBranchReturn rets_branch_returns))
                   (newVName "new_memory_size")

      -- Number the new sizes consecutively after the existing context
      -- elements.  Each new size consumes one fresh name.  This replaces a
      -- lazy foldl that used partial head/tail and quadratic appends.
      let assignCtxIndex (i, fresh) rb = case rb of
            ExistingBranchReturn r -> ((i, fresh), (r, Nothing))
            NewBranchReturn rf s_th s_el patmem -> case fresh of
              size_name : fresh' ->
                ((i + 1, fresh'), (rf i, Just ((s_th, s_el), (patmem, size_name))))
              [] -> error "transformStm: fewer fresh size names than new branch returns"
          (_, assigned) =
            L.mapAccumL assignCtxIndex (length patctxelems, patctxelems_new) rets_branch_returns
          rets'' = map fst assigned
          (body_ext_new, patmem_to_new_size) = unzip $ mapMaybe snd assigned

      modifyMemSizeMapping (++ patmem_to_new_size)

      let (th_ext_new, el_ext_new) = unzip body_ext_new

          -- Insert the new size results between the context and value results.
          insertNewSizes sizes body =
            let (ctx_res, val_res) = splitAt (length patctxelems) (bodyResult body)
            in body { bodyResult = ctx_res ++ map Var sizes ++ val_res }
          body_then'' = insertNewSizes th_ext_new body_then'
          body_else'' = insertNewSizes el_ext_new body_else'

          -- Point memory context elements at their new size variables.
          replaceSize pe = case pe of
            PatElem name (ExpMem.MemMem _size space)
              | Just size_new <- L.lookup name patmem_to_new_size ->
                  PatElem name (ExpMem.MemMem (Var size_new) space)
            _ -> pe
          patctxelems_replaced = map replaceSize patctxelems
          patctxelems' = patctxelems_replaced
                         ++ map (\v -> PatElem v (ExpMem.MemPrim (IntType Int64))) patctxelems_new
      return (If cond body_then'' body_else'' (IfAttr rets'' sort), patctxelems')

    DoLoop mergectxparams mergevalparams loopform body -> do
      -- More special loop handling because of its extra
      -- pattern-like info.
      mergectxparams' <- mapM (transformMergeCtxParam mergevalparams) mergectxparams
      mergevalparams' <- mapM transformMergeValParam mergevalparams

      -- The body of a loop can return a memory block in its results.  This is
      -- the memory block used by a variable which is also part of the results.
      -- If the memory block of that variable is changed, we need a way to
      -- record that the memory block in the body result also needs to change.
      let zipped = zip [(0::Int)..] (patctxelems ++ patvalelems)

          -- Link the index of the pattern element naming an array's memory
          -- block to the index of the array.
          findMemLinks (i, PatElem _x (ExpMem.MemArray _ _ _ (ExpMem.ArrayIn xmem _))) =
            case L.find (\(_, PatElem ymem _) -> ymem == xmem) zipped of
              Just (j, _) -> Just (j, i)
              Nothing -> Nothing
          findMemLinks _ = Nothing

          mem_links = mapMaybe findMemLinks zipped

          res = bodyResult body

          fixResRecord i se
            | Var _mem <- se
            , Just j <- L.lookup i mem_links
            , Var related_var <- res L.!! j
            , Just mem_new <- M.lookup related_var var_to_mem
              = Var $ memLocName mem_new
            | otherwise = se

          res' = zipWith fixResRecord [(0::Int)..] res
          body' = body { bodyResult = res' }

      loopform' <- case loopform of
        ForLoop i it bound loop_vars ->
          ForLoop i it bound <$> mapM transformForLoopVar loop_vars
        WhileLoop _ -> return loopform
      return (DoLoop mergectxparams' mergevalparams' loopform' body', patctxelems)

    _ -> return (e', patctxelems)
  return (Let (Pattern patctxelems' patvalelems') aux e'')
  where mapper = identityMapper
          { mapOnBody = const transformBody
          , mapOnFParam = transformFParam
          , mapOnLParam = transformLParam
          }
        mapper_kernel = identityKernelMapper
          { mapOnKernelBody = coerce . transformBody
          , mapOnKernelKernelBody = coerce . transformKernelBody
          , mapOnKernelLambda = coerce . transformLambda
          , mapOnKernelLParam = transformLParam
          }

-- | Update the actual memory block referred to by a context (existential)
-- memory block in a loop.
transformMergeCtxParam :: [(FParam ExplicitMemory, SubExp)]
                       -> (FParam ExplicitMemory, SubExp)
                       -> FindM lore (FParam ExplicitMemory, SubExp)
transformMergeCtxParam mergevalparams (param@(Param ctxmem ExpMem.MemMem{}), mem) = do
  var_to_mem <- asks ctxVarToMem

  let usesCtxMem (Param _ (ExpMem.MemArray _ _ _ (ExpMem.ArrayIn pmem _))) = ctxmem == pmem
      usesCtxMem _ = False

      -- If the initial value of a loop merge parameter is a memory block name,
      -- we may have to update that.  If the context memory block is used in an
      -- array in one of the value merge parameters, see if that array variable
      -- refers to an array that has been set to reuse a memory block.
      mem' = fromMaybe mem $ do
        (_, Var orig_var) <- L.find (usesCtxMem . fst) mergevalparams
        orig_mem <- M.lookup orig_var var_to_mem
        return $ Var $ memLocName orig_mem

  return (param, mem')
transformMergeCtxParam _ t = return t

-- | Move a value merge parameter to its new memory location, if it has one.
transformMergeValParam :: (FParam ExplicitMemory, SubExp)
                       -> FindM lore (FParam ExplicitMemory, SubExp)
transformMergeValParam (Param x membound, se) = do
  membound' <- newMemBound membound x
  return (Param x membound', se)

-- | Move a pattern value element to its new memory location, if it has one.
transformPatValElem :: PatElem ExplicitMemory -> FindM lore (PatElem ExplicitMemory)
transformPatValElem (PatElem x membound) =
  PatElem x <$> newMemBound membound x

-- | Move a function parameter to its new memory location, if it has one.
transformFParam :: LoreConstraints lore =>
                   FParam lore -> FindM lore (FParam lore)
transformFParam (Param x membound) =
  Param x <$> newMemBound membound x

-- | Move a lambda parameter to its new memory location, if it has one.
transformLParam :: LoreConstraints lore =>
                   LParam lore -> FindM lore (LParam lore)
transformLParam (Param x membound) =
  Param x <$> newMemBound membound x

-- | Rewrite a lambda's parameters and body.
transformLambda :: LoreConstraints lore =>
                   Lambda lore -> FindM lore (Lambda lore)
transformLambda (Lambda params body types) = do
  params' <- mapM transformLParam params
  body' <- transformBody body
  return $ Lambda params' body' types

-- | Move a for-loop element parameter to its new memory location, if any.
transformForLoopVar :: LoreConstraints lore =>
                       (LParam lore, VName)
                    -> FindM lore (LParam lore, VName)
transformForLoopVar (Param x membound, array) = do
  membound' <- newMemBound membound x
  return (Param x membound', array)

-- | Find a new memory block and index function if they exist.  Only array
-- bounds whose variable has an entry in 'ctxVarToMem' are changed.
newMemBound :: ExpMem.MemBound u -> VName -> FindM lore (ExpMem.MemBound u)
newMemBound membound var = do
  var_to_mem <- asks ctxVarToMem

  let membound'
        | ExpMem.MemArray pt shape u _ <- membound
        , Just (MemoryLoc mem ixfun) <- M.lookup var var_to_mem
          = Just $ ExpMem.MemArray pt shape u $ ExpMem.ArrayIn mem ixfun
        | otherwise = Nothing

  return $ fromMaybe membound membound'