-- | This module implements an optimization that migrates host
-- statements into 'GPUBody' kernels to reduce the number of
-- host-device synchronizations that occur when a scalar variable is
-- written to or read from device memory.  Which statements should be
-- migrated is determined by a 'MigrationTable' produced by the
-- "Futhark.Optimise.ReduceDeviceSyncs.MigrationTable" module; this
-- module merely performs the migration and rewriting dictated by that
-- table.
module Futhark.Optimise.ReduceDeviceSyncs (reduceDeviceSyncs) where

import Control.Monad
import Control.Monad.Trans.Class
import qualified Control.Monad.Trans.Reader as R
import Control.Monad.Trans.State.Strict hiding (State)
import Control.Parallel.Strategies (parMap, rpar)
import Data.Foldable
import qualified Data.IntMap.Strict as IM
import Data.List (unzip4, zip4)
import qualified Data.Map.Strict as M
import Data.Sequence ((<|), (><), (|>))
import qualified Data.Text as T
import Futhark.Construct (fullSlice, sliceDim)
import Futhark.Error
import qualified Futhark.FreshNames as FN
import Futhark.IR.GPU
import Futhark.MonadFreshNames (VNameSource, getNameSource, putNameSource)
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable
import Futhark.Pass
import Futhark.Transform.Substitute

-- | An optimization pass that migrates host statements into 'GPUBody' kernels
-- to reduce the number of host-device synchronizations.
reduceDeviceSyncs :: Pass GPU GPU
reduceDeviceSyncs =
  Pass
    "reduce device synchronizations"
    "Move host statements to device to reduce blocking memory operations."
    run
  where
    -- Compute the migration table, run the rewrite with the table as its
    -- read-only environment, and hand the advanced name source back.
    run prog = do
      ns <- getNameSource
      let mt = analyseProg prog
      let st = initialState ns
      let (prog', st') = R.runReader (runStateT (optimizeProgram prog) st) mt
      putNameSource (stateNameSource st')
      pure prog'

--------------------------------------------------------------------------------
--                            AD HOC OPTIMIZATION                             --
--------------------------------------------------------------------------------

-- | Optimize a whole program. The type signatures of top-level functions will
-- remain unchanged.
optimizeProgram :: Prog GPU -> ReduceM (Prog GPU)
optimizeProgram (Prog consts funs) = do
  consts' <- optimizeStms consts
  funs' <- sequence $ parMap rpar optimizeFunDef funs
  pure (Prog consts' funs')

-- | Optimize a function definition. Its type signature will remain unchanged.
optimizeFunDef :: FunDef GPU -> ReduceM (FunDef GPU)
optimizeFunDef fd = do
  let body = funDefBody fd
  -- Only the statements are rewritten; the body result is left untouched.
  stms' <- optimizeStms (bodyStms body)
  pure $ fd {funDefBody = body {bodyStms = stms'}}

-- | Optimize a body. Scalar results may be replaced with single-element arrays.
optimizeBody :: Body GPU -> ReduceM (Body GPU)
optimizeBody (Body _ stms res) = do
  stms' <- optimizeStms stms
  res' <- resolveResult res
  pure (Body () stms' res')

-- | Optimize a sequence of statements.
optimizeStms :: Stms GPU -> ReduceM (Stms GPU)
optimizeStms = foldM optimizeStm mempty

-- | Optimize a single statement, rewriting it into one or more statements to
-- be appended to the provided 'Stms'. Only variables with continued host usage
-- will remain in scope if their statement is migrated.
optimizeStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
optimizeStm out stm = do
  move <- asks (shouldMoveStm stm)
  if move
    then moveStm out stm
    else case stmExp stm of
      BasicOp (Update safety arr slice (Var v))
        | Just _ <- sliceIndices slice -> do
            -- Rewrite the Update if its write value has been migrated. Copying
            -- is faster than doing a synchronous write, so we use the device
            -- array even if the value has been made available to the host.
            dev <- storedScalar (Var v)
            case dev of
              Nothing -> pure (out |> stm)
              Just dst -> do
                -- Transform the single element Update into a slice Update.
                -- The guard above established that the slice consists of
                -- indices only, which the 'DimFix' pattern below relies on.
                let dims = unSlice slice
                let (outer, [DimFix i]) = splitAt (length dims - 1) dims
                let one = intConst Int64 1
                let slice' = Slice $ outer ++ [DimSlice i one one]
                let e = BasicOp (Update safety arr slice' (Var dst))
                let stm' = stm {stmExp = e}
                pure (out |> stm')
      BasicOp (Replicate (Shape dims) (Var v))
        | Pat [PatElem n arr_t] <- stmPat stm -> do
            -- A Replicate can be rewritten to not require its replication value
            -- to be available on host. If its value is migrated the Replicate
            -- thus needs to be transformed.
            --
            -- If the inner dimension of the replication array is one then the
            -- rewrite can be performed more efficiently than the general case.
            v' <- resolveName v
            let v_kept_on_device = v /= v'

            gpubody_ok <- gets stateGPUBodyOk

            case v_kept_on_device of
              False -> pure (out |> stm)
              True
                | all (== intConst Int64 1) dims,
                  Just t' <- peelArray 1 arr_t,
                  gpubody_ok -> do
                    let n' = VName (baseName n `withSuffix` "_inner") 0
                    let pat' = Pat [PatElem n' t']
                    let e' = BasicOp $ Replicate (Shape $ tail dims) (Var v)
                    let stm' = Let pat' (stmAux stm) e'

                    -- `gpu { v }` is slightly faster than `replicate 1 v` and
                    -- can merge with the GPUBody that v was computed by.
                    gpubody <- inGPUBody (rewriteStm stm')
                    pure (out |> gpubody {stmPat = stmPat stm})
              True
                -- The emptiness check keeps 'last' total: an empty shape
                -- falls through to the general case below.
                | not (null dims),
                  last dims == intConst Int64 1 ->
                    let e' = BasicOp $ Replicate (Shape $ init dims) (Var v')
                        stm' = stm {stmExp = e'}
                     in pure (out |> stm')
              True -> do
                n' <- newName n
                -- v_kept_on_device implies that v is a scalar.
                let dims' = dims ++ [intConst Int64 1]
                let arr_t' = Array (elemType arr_t) (Shape dims') NoUniqueness
                let pat' = Pat [PatElem n' arr_t']
                let e' = BasicOp $ Replicate (Shape dims) (Var v')
                let repl = Let pat' (stmAux stm) e'

                -- Drop the extra unit dimension again so that the original
                -- pattern is bound to an array of the original type.
                let aux = StmAux mempty mempty ()
                let slice = map sliceDim (arrayDims arr_t)
                let slice' = slice ++ [DimFix $ intConst Int64 0]
                let idx = BasicOp $ Index n' (Slice slice')
                let index = Let (stmPat stm) aux idx
                pure (out |> repl |> index)
      BasicOp {} ->
        pure (out |> stm)
      Apply {} ->
        pure (out |> stm)
      If cond (Body _ tstms0 tres) (Body _ fstms0 fres) (IfDec btypes sort) -> do
        -- Rewrite branches.
        tstms1 <- optimizeStms tstms0
        fstms1 <- optimizeStms fstms0

        -- Ensure return values and types match if one or both branches
        -- return a result that now reside on device.
        let bmerge (res, tstms, fstms) (pe, tr, fr, bt) = do
              let onHost (Var v) = (v ==) <$> resolveName v
                  onHost _ = pure True

              tr_on_host <- onHost (resSubExp tr)
              fr_on_host <- onHost (resSubExp fr)

              if tr_on_host && fr_on_host
                then -- No result resides on device ==> nothing to do.
                  pure ((pe, tr, fr, bt) : res, tstms, fstms)
                else do
                  -- Otherwise, ensure both results are migrated.
                  let t = patElemType pe
                  (tstms', tarr) <- storeScalar tstms (resSubExp tr) t
                  (fstms', farr) <- storeScalar fstms (resSubExp fr) t

                  pe' <- arrayizePatElem pe
                  let bt' = staticShapes1 (patElemType pe')
                  let tr' = tr {resSubExp = Var tarr}
                  let fr' = fr {resSubExp = Var farr}
                  pure ((pe', tr', fr', bt') : res, tstms', fstms')

        let pes = patElems (stmPat stm)
        let zipped = zip4 pes tres fres btypes
        (zipped', tstms2, fstms2) <- foldM bmerge ([], tstms1, fstms1) zipped
        -- The fold accumulates in reverse order.
        let (pes', tres', fres', btypes') = unzip4 (reverse zipped')

        -- Rewrite statement.
        let tbranch' = Body () tstms2 tres'
        let fbranch' = Body () fstms2 fres'
        let e' = If cond tbranch' fbranch' (IfDec btypes' sort)
        let stm' = Let (Pat pes') (stmAux stm) e'

        -- Read migrated scalars that are used on host.
        foldM addRead (out |> stm') (zip pes pes')
      DoLoop ps lf b -> do
        -- Enable the migration of for-in loop variables.
        (params, lform, body) <- rewriteForIn (ps, lf, b)

        -- Update statement bound variables and parameters if their values
        -- have been migrated to device.
        let lmerge (res, stms) (pe, (Param _ pn pt, pval), MoveToDevice) = do
              -- Rewrite the bound variable.
              pe' <- arrayizePatElem pe

              -- Move the initial value to device if not already there.
              (stms', arr) <- storeScalar stms pval (fromDecl pt)

              -- Rewrite the parameter.
              pn' <- newName pn
              let pt' = toDecl (patElemType pe') Nonunique
              let pval' = Var arr
              let param' = (Param mempty pn' pt', pval')

              -- Record the migration.
              Ident pn (fromDecl pt) `movedTo` pn'

              pure ((pe', param') : res, stms')
            lmerge _ (_, _, UsedOnHost) =
              -- Initial loop parameter value and loop result should have
              -- been made available on host instead.
              compilerBugS "optimizeStm: unhandled host usage of loop param"
            lmerge (res, stms) (pe, param, StayOnHost) =
              pure ((pe, param) : res, stms)

        mt <- ask

        let pes = patElems (stmPat stm)
        let mss = map (\(Param _ n _, _) -> statusOf n mt) params
        (zipped', out') <- foldM lmerge ([], out) (zip3 pes params mss)
        -- The fold accumulates in reverse order.
        let (pes', params') = unzip (reverse zipped')

        -- Rewrite statement.
        body' <- optimizeBody body
        let e' = DoLoop params' lform body'
        let stm' = Let (Pat pes') (stmAux stm) e'

        -- Read migrated scalars that are used on host.
        foldM addRead (out' |> stm') (zip pes pes')
      WithAcc inputs lmd -> do
        let getAcc (Acc a _ _ _) = a
            getAcc _ =
              compilerBugS
                "Type error: WithAcc expression did not return accumulator."

        let accs = zipWith (\t i -> (getAcc t, i)) (lambdaReturnType lmd) inputs
        inputs' <- mapM (uncurry optimizeWithAccInput) accs

        let body = lambdaBody lmd
        stms' <- optimizeStms (bodyStms body)

        let rewrite (SubExpRes certs se, t, pe) = do
              se' <- resolveSubExp se
              if se == se'
                then pure (SubExpRes certs se, t, pe)
                else do
                  pe' <- arrayizePatElem pe
                  let t' = patElemType pe'
                  pure (SubExpRes certs se', t', pe')

        -- Rewrite non-accumulator results that have been migrated.
        --
        -- Accumulator return values do not map to arrays one-to-one but
        -- one-to-many. They are not transformed however and can be mapped
        -- as a no-op.
        let len = length inputs
        let (res0, res1) = splitAt len (bodyResult body)
        let (rts0, rts1) = splitAt len (lambdaReturnType lmd)
        let pes = patElems (stmPat stm)
        let (pes0, pes1) = splitAt (length pes - length res1) pes

        (res1', rts1', pes1') <- unzip3 <$> mapM rewrite (zip3 res1 rts1 pes1)
        let res' = res0 ++ res1'
        let rts' = rts0 ++ rts1'
        let pes' = pes0 ++ pes1'

        -- Rewrite statement.
        let body' = Body () stms' res'
        let lmd' = lmd {lambdaBody = body', lambdaReturnType = rts'}
        let e' = WithAcc inputs' lmd'
        let stm' = Let (Pat pes') (stmAux stm) e'

        -- Read migrated scalars that are used on host.
        foldM addRead (out |> stm') (zip pes pes')
      Op op -> do
        op' <- optimizeHostOp op
        pure (out |> stm {stmExp = Op op'})
  where
    -- A pattern element whose name is unchanged was not migrated; any other
    -- is registered as migrated and read back if it is used on host.
    addRead stms (pe@(PatElem n _), PatElem dev _)
      | n == dev = pure stms
      | otherwise = pe `migratedTo` (dev, stms)

-- | Rewrite a for-in loop such that relevant source array reads can be delayed.
rewriteForIn ::
  ([(FParam GPU, SubExp)], LoopForm GPU, Body GPU) ->
  ReduceM ([(FParam GPU, SubExp)], LoopForm GPU, Body GPU)
rewriteForIn loop@(_, WhileLoop {}, _) =
  pure loop
rewriteForIn (params, ForLoop i t n elems, body) = do
  mt <- ask

  let (elems', stms') = foldr (inline mt) ([], bodyStms body) elems
  pure (params, ForLoop i t n elems', body {bodyStms = stms'})
  where
    -- A for-in element that is not used on host is removed from the loop
    -- form and instead bound by an explicit 'Index' statement placed at the
    -- front of the loop body.
    inline mt (x, arr) (arrs, stms)
      | pn <- paramName x,
        not (usedOnHost pn mt) =
          let pt = typeOf x
              stm = bind (PatElem pn pt) (BasicOp $ index arr pt)
           in (arrs, stm <| stms)
      | otherwise =
          ((x, arr) : arrs, stms)

    -- Index the outermost dimension with the loop counter and take the
    -- remaining dimensions in full.
    index arr of_type =
      Index arr $ Slice $ DimFix (Var i) : map sliceDim (arrayDims of_type)

-- | Optimize an accumulator input. The 'VName' is the accumulator token.
optimizeWithAccInput :: VName -> WithAccInput GPU -> ReduceM (WithAccInput GPU)
optimizeWithAccInput _ (shape, arrs, Nothing) = pure (shape, arrs, Nothing)
optimizeWithAccInput acc (shape, arrs, Just (op, nes)) = do
  device_only <- asks (shouldMove acc)
  if device_only
    then do
      op' <- addReadsToLambda op
      pure (shape, arrs, Just (op', nes))
    else do
      let body = lambdaBody op
      -- To pass type check neither parameters nor results can change.
      --
      -- op may be used on both host and device so we must avoid introducing
      -- any GPUBody statements.
      stms' <- noGPUBody $ optimizeStms (bodyStms body)
      let op' = op {lambdaBody = body {bodyStms = stms'}}
      pure (shape, arrs, Just (op', nes))

-- | Optimize a host operation. 'Index' statements are added to kernel code
-- that depends on migrated scalars.
optimizeHostOp :: HostOp GPU op -> ReduceM (HostOp GPU op)
optimizeHostOp (SegOp (SegMap lvl space types kbody)) =
  SegOp . SegMap lvl space types <$> addReadsToKernelBody kbody
optimizeHostOp (SegOp (SegRed lvl space ops types kbody)) = do
  ops' <- mapM addReadsToSegBinOp ops
  SegOp . SegRed lvl space ops' types <$> addReadsToKernelBody kbody
optimizeHostOp (SegOp (SegScan lvl space ops types kbody)) = do
  ops' <- mapM addReadsToSegBinOp ops
  SegOp . SegScan lvl space ops' types <$> addReadsToKernelBody kbody
optimizeHostOp (SegOp (SegHist lvl space ops types kbody)) = do
  ops' <- mapM addReadsToHistOp ops
  SegOp . SegHist lvl space ops' types <$> addReadsToKernelBody kbody
optimizeHostOp (SizeOp op) =
  pure (SizeOp op)
optimizeHostOp OtherOp {} =
  -- These should all have been taken care of in the unstreamGPU pass.
  compilerBugS "optimizeHostOp: unhandled OtherOp"
optimizeHostOp (GPUBody types body) =
  GPUBody types <$> addReadsToBody body

--------------------------------------------------------------------------------
--                               COMMON HELPERS                               --
--------------------------------------------------------------------------------

-- | Append the given string to a name.
withSuffix :: Name -> String -> Name
withSuffix name sfx = nameFromText $ T.append (nameToText name) (T.pack sfx)

--------------------------------------------------------------------------------
--                             MIGRATION - TYPES                              --
--------------------------------------------------------------------------------

-- | The monad used to perform migration-based synchronization reductions.
type ReduceM = StateT State (R.Reader MigrationTable)

-- | The state used by a 'ReduceM' monad.
data State = State
  { -- | A source to generate new 'VName's from.
    stateNameSource :: VNameSource,
    -- | A table of variables in the original program which have been migrated
    -- to device. Each variable maps to a tuple that describes:
    --
    --   * 'baseName' of the original variable.
    --   * Type of the original variable.
    --   * Name of the single element array holding the migrated value.
    --   * Whether the original variable still can be used on the host.
    stateMigrated :: IM.IntMap (Name, Type, VName, Bool),
    -- | Whether non-migration optimizations may introduce 'GPUBody' kernels at
    -- the current location.
    stateGPUBodyOk :: Bool
  }

--------------------------------------------------------------------------------
--                           MIGRATION - PRIMITIVES                           --
--------------------------------------------------------------------------------

-- | An initial state to use when running a 'ReduceM' monad.
initialState :: VNameSource -> State
initialState ns =
  State
    { stateNameSource = ns,
      stateMigrated = mempty,
      stateGPUBodyOk = True
    }

-- | Retrieve a function of the current environment.
asks :: (MigrationTable -> a) -> ReduceM a
asks = lift . R.asks

-- | Fetch the value of the environment.
ask :: ReduceM MigrationTable
ask = lift R.ask

-- | Perform non-migration optimizations without introducing any GPUBody
-- kernels.
noGPUBody :: ReduceM a -> ReduceM a
noGPUBody m = do
  -- Remember the previous setting so that nested uses restore it correctly.
  prev <- gets stateGPUBodyOk
  modify $ \st -> st {stateGPUBodyOk = False}
  res <- m
  modify $ \st -> st {stateGPUBodyOk = prev}
  pure res

-- | Produce a fresh name, using the given name as a template.
newName :: VName -> ReduceM VName
newName n = do
  st <- get
  let ns = stateNameSource st
  let (n', ns') = FN.newName ns n
  put (st {stateNameSource = ns'})
  pure n'

-- | Create a 'PatElem' that binds the array of a migrated variable binding.
arrayizePatElem :: PatElem Type -> ReduceM (PatElem Type)
arrayizePatElem (PatElem n t) = do
  let name = baseName n `withSuffix` "_dev"
  dev <- newName (VName name 0)
  -- The migrated value is held in an array with a single row.
  let dev_t = t `arrayOfRow` intConst Int64 1
  pure (PatElem dev dev_t)

-- | @x `movedTo` arr@ registers that the value of @x@ has been migrated to
-- @arr[0]@.
movedTo :: Ident -> VName -> ReduceM ()
movedTo = recordMigration False

-- | @x `aliasedBy` arr@ registers that the value of @x@ also is available on
-- device as @arr[0]@.
aliasedBy :: Ident -> VName -> ReduceM ()
aliasedBy = recordMigration True

-- | @recordMigration host x arr@ records the migration of variable @x@ to
-- @arr[0]@. If @host@ then the original binding can still be used on host.
recordMigration :: Bool -> Ident -> VName -> ReduceM ()
recordMigration host (Ident x t) arr =
  modify $ \st ->
    let migrated = stateMigrated st
        entry = (baseName x, t, arr, host)
        -- Entries are keyed by the tag of the original variable.
        migrated' = IM.insert (baseTag x) entry migrated
     in st {stateMigrated = migrated'}

-- | @pe `migratedTo` (dev, stms)@ registers that the variable @pe@ in the
-- original program has been migrated to @dev@ and rebinds the variable if
-- deemed necessary, adding an index statement to the given statements.
migratedTo :: PatElem Type -> (VName, Stms GPU) -> ReduceM (Stms GPU)
migratedTo pe (dev, stms) = do
  used <- asks (usedOnHost $ patElemName pe)
  if used
    then -- Keep the host binding alive by reading the value back from @dev@.
      patElemIdent pe `aliasedBy` dev >> pure (stms |> bind pe (eIndex dev))
    else patElemIdent pe `movedTo` dev >> pure stms

-- | @useScalar stms n@ returns a variable that binds the result bound by @n@
-- in the original program. If the variable has been migrated to device and has
-- not been copied back to host a new variable binding will be added to the
-- provided statements and be returned.
useScalar :: Stms GPU -> VName -> ReduceM (Stms GPU, VName)
useScalar stms n = do
  entry <- IM.lookup (baseTag n) <$> gets stateMigrated
  case entry of
    -- Never migrated: the original name is still valid.
    Nothing ->
      pure (stms, n)
    -- Migrated, but the original binding remains usable.
    Just (_, _, _, True) ->
      pure (stms, n)
    -- Only available as an array element: read it under a fresh name.
    Just (name, t, arr, _) ->
      do
        n' <- newName (VName name 0)
        let stm = bind (PatElem n' t) (eIndex arr)
        pure (stms |> stm, n')

-- | Create an expression that reads the first element of a 1-dimensional array.
eIndex :: VName -> Exp GPU
eIndex arr = BasicOp $ Index arr (Slice [DimFix $ intConst Int64 0])

-- | A shorthand for binding a single variable to an expression.
bind :: PatElem Type -> Exp GPU -> Stm GPU
bind pe = Let (Pat [pe]) (StmAux mempty mempty ())

-- | Returns the array alias of @se@ if it is a variable that has been migrated
-- to device. Otherwise returns @Nothing@.
storedScalar :: SubExp -> ReduceM (Maybe VName)
storedScalar (Constant _) = pure Nothing
storedScalar (Var n) = do
  entry <- IM.lookup (baseTag n) <$> gets stateMigrated
  pure $ fmap (\(_, _, arr, _) -> arr) entry

-- | @storeScalar stms se t@ returns a variable that binds a single element
-- array that contains the value of @se@ in the original program. If @se@ is a
-- variable that has been migrated to device, its existing array alias will be
-- used. Otherwise a new variable binding will be added to the provided
-- statements and be returned. @t@ is the type of @se@.
storeScalar :: Stms GPU -> SubExp -> Type -> ReduceM (Stms GPU, VName)
storeScalar stms se t = do
  entry <- case se of
    Var n -> IM.lookup (baseTag n) <$> gets stateMigrated
    _ -> pure Nothing
  case entry of
    Just (_, _, arr, _) -> pure (stms, arr)
    Nothing -> do
      -- How to most efficiently create an array containing the given value
      -- depends on whether it is a variable or a constant. Creating a constant
      -- array is a runtime copy of static memory, while creating an array that
      -- contains a variable results in a synchronous write. The latter is thus
      -- replaced with either a mergeable GPUBody kernel or a Replicate.
      --
      -- Whether it makes sense to hoist arrays out of bodies to enable CSE is
      -- left to the simplifier to figure out. Duplicates will be eliminated if
      -- a scalar is stored multiple times within a body.
      --
      -- TODO: Enable the simplifier to hoist non-consumed, non-returned arrays
      --       out of top-level function definitions. All constant arrays
      --       produced here are in principle top-level hoistable.
      gpubody_ok <- gets stateGPUBodyOk
      case se of
        Var n
          | gpubody_ok -> do
              n' <- newName n
              let stm = bind (PatElem n' t) (BasicOp $ SubExp se)
              gpubody <- inGPUBody (pure stm)
              -- 'bind' produces a single-element pattern and 'inGPUBody'
              -- keeps one pattern element per input element, so 'head' is
              -- safe here.
              let dev = patElemName $ head $ patElems (stmPat gpubody)
              pure (stms |> gpubody, dev)
        Var n -> do
          pe <- arrayizePatElem (PatElem n t)
          let shape = Shape [intConst Int64 1]
          let stm = bind pe (BasicOp $ Replicate shape se)
          pure (stms |> stm, patElemName pe)
        _ -> do
          let n = VName (nameFromString "const") 0
          pe <- arrayizePatElem (PatElem n t)
          let stm = bind pe (BasicOp $ ArrayLit [se] t)
          pure (stms |> stm, patElemName pe)

-- | Map a variable name to itself or, if the variable no longer can be used on
-- host, the name of a single element array containing its value.
resolveName :: VName -> ReduceM VName
resolveName n = do
  entry <- IM.lookup (baseTag n) <$> gets stateMigrated
  case entry of
    Nothing -> pure n
    Just (_, _, _, True) -> pure n
    Just (_, _, arr, _) -> pure arr

-- | Like 'resolveName' but for a t'SubExp'. Constants are mapped to themselves.
resolveSubExp :: SubExp -> ReduceM SubExp
resolveSubExp (Var n) = Var <$> resolveName n
resolveSubExp cnst = pure cnst

-- | Like 'resolveSubExp' but for a 'SubExpRes'.
resolveSubExpRes :: SubExpRes -> ReduceM SubExpRes
resolveSubExpRes (SubExpRes certs se) =
  -- Certificates are always read back to host.
  SubExpRes certs <$> resolveSubExp se

-- | Apply 'resolveSubExpRes' to a list of results.
resolveResult :: Result -> ReduceM Result
resolveResult = mapM resolveSubExpRes

-- | Migrate a statement to device, ensuring all its bound variables used on
-- host will remain available with the same names.
moveStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
moveStm out (Let pat aux (BasicOp (ArrayLit [se] t')))
  | Pat [PatElem n _] <- pat = do
      -- Save an 'Index' by rewriting the 'ArrayLit' rather than migrating it.
      let n' = VName (baseName n `withSuffix` "_inner") 0
      let pat' = Pat [PatElem n' t']
      let e' = BasicOp (SubExp se)
      let stm' = Let pat' aux e'

      gpubody <- inGPUBody (rewriteStm stm')
      -- The GPUBody result is bound directly to the original pattern.
      pure (out |> gpubody {stmPat = pat})
moveStm out stm = do
  -- Move the statement to device.
  gpubody <- inGPUBody (rewriteStm stm)

  -- Read non-scalars and scalars that are used on host.
  let arrs = zip (patElems $ stmPat stm) (patElems $ stmPat gpubody)
  foldM addRead (out |> gpubody) arrs
  where
    addRead stms (pe@(PatElem _ t), PatElem dev dev_t) =
      let add' e = pure $ stms |> bind pe e
          add = add' . BasicOp
       in case arrayRank dev_t of
            -- Alias non-arrays with their prior name.
            0 -> add $ SubExp (Var dev)
            -- Read all certificates for free.
            1 | t == Prim Unit -> add' (eIndex dev)
            -- Record the device alias of each scalar variable and read them
            -- if used on host.
            1 -> pe `migratedTo` (dev, stms)
            -- Drop the added dimension of multidimensional arrays.
            _ -> add $ Index dev (fullSlice dev_t [DimFix $ intConst Int64 0])

-- | Create a GPUBody kernel that executes a single statement and stores its
-- results in single element arrays.
inGPUBody :: RewriteM (Stm GPU) -> ReduceM (Stm GPU)
inGPUBody m = do
  (stm, st) <- runStateT m initialRState
  -- Statements accumulated while rewriting are placed before the statement
  -- itself inside the kernel body.
  let prologue = rewritePrologue st

  let pes = patElems (stmPat stm)
  pat <- Pat <$> mapM arrayizePatElem pes
  let aux = StmAux mempty mempty ()
  let types = map patElemType pes
  let res = map (SubExpRes mempty . Var . patElemName) pes
  let body = Body () (prologue |> stm) res
  let e = Op (GPUBody types body)
  pure (Let pat aux e)

--------------------------------------------------------------------------------
--                          KERNEL REWRITING - TYPES                          --
--------------------------------------------------------------------------------

-- | The monad used to rewrite (migrated) kernel code.
type RewriteM = StateT RState ReduceM

-- | The state used by a 'RewriteM' monad.
data RState = RState
  { -- | Maps variables in the original program to names to be used by rewrites.
    rewriteRenames :: IM.IntMap VName,
    -- | Statements to be added as a prologue before rewritten statements.
    rewritePrologue :: Stms GPU
  }

--------------------------------------------------------------------------------
--                        KERNEL REWRITING - FUNCTIONS                        --
--------------------------------------------------------------------------------

-- | An initial state to use when running a 'RewriteM' monad.
initialRState :: RState
initialRState =
  RState
    { rewriteRenames = mempty,
      rewritePrologue = mempty
    }

-- | Rewrite 'SegBinOp' dependencies to scalars that have been migrated.
addReadsToSegBinOp :: SegBinOp GPU -> ReduceM (SegBinOp GPU)
addReadsToSegBinOp op = do
  f' <- addReadsToLambda (segBinOpLambda op)
  pure (op {segBinOpLambda = f'})

-- | Rewrite 'HistOp' dependencies to scalars that have been migrated.
addReadsToHistOp :: HistOp GPU -> ReduceM (HistOp GPU)
addReadsToHistOp op = do
  f' <- addReadsToLambda (histOp op)
  pure (op {histOp = f'})

-- | Rewrite generic lambda dependencies to scalars that have been migrated.
addReadsToLambda :: Lambda GPU -> ReduceM (Lambda GPU)
addReadsToLambda f = do
  body' <- addReadsToBody (lambdaBody f)
  pure (f {lambdaBody = body'})

-- | Rewrite generic body dependencies to scalars that have been migrated.
addReadsToBody :: Body GPU -> ReduceM (Body GPU)
addReadsToBody body = do
  (body', prologue) <- addReadsHelper body
  -- The reads must precede the statements that depend on them.
  pure body' {bodyStms = prologue >< bodyStms body'}

-- | Rewrite kernel body dependencies to scalars that have been migrated.
addReadsToKernelBody :: KernelBody GPU -> ReduceM (KernelBody GPU)
addReadsToKernelBody kbody = do
  (kbody', prologue) <- addReadsHelper kbody
  -- The reads must precede the statements that depend on them.
  pure kbody' {kernelBodyStms = prologue >< kernelBodyStms kbody'}

-- | Rewrite migrated scalar dependencies within anything. The returned
-- statements must be added to the scope of the rewritten construct.
addReadsHelper :: (FreeIn a, Substitute a) => a -> ReduceM (a, Stms GPU)
addReadsHelper x = do
  let from = namesToList (freeIn x)
  -- Resolving each free name may add read statements to the prologue.
  (to, st) <- runStateT (mapM rename from) initialRState
  let rename_map = M.fromList (zip from to)
  pure (substituteNames rename_map x, rewritePrologue st)

-- | Create a fresh name, registering which name it is a rewrite of.
rewriteName :: VName -> RewriteM VName
rewriteName n = do
  n' <- lift (newName n)
  modify $ \st -> st {rewriteRenames = IM.insert (baseTag n) n' (rewriteRenames st)}
  pure n'

-- | Rewrite all bindings introduced by a body (to ensure they are unique) and
-- fix any dependencies that are broken as a result of migration or rewriting.
rewriteBody :: Body GPU -> RewriteM (Body GPU)
rewriteBody (Body _ stms res) = do
  stms' <- rewriteStms stms
  res' <- renameResult res
  pure (Body () stms' res')

-- | Rewrite all bindings introduced by a sequence of statements (to ensure they
-- are unique) and fix any dependencies that are broken as a result of migration
-- or rewriting.
rewriteStms :: Stms GPU -> RewriteM (Stms GPU)
rewriteStms = foldM rewriteTo mempty
  where
    rewriteTo out stm = do
      stm' <- rewriteStm stm
      pure $ case stmExp stm' of
        -- A rewritten GPUBody is replaced by its own statements followed by
        -- one binding per result.
        Op (GPUBody _ (Body _ stms res)) ->
          let pes = patElems (stmPat stm')
           in foldl' bnd (out >< stms) (zip pes res)
        _ -> out |> stm'

    -- Bind a former GPUBody result: as a single element array literal if the
    -- pattern element has array type, otherwise as a plain alias.
    bnd :: Stms GPU -> (PatElem Type, SubExpRes) -> Stms GPU
    bnd out (pe, SubExpRes cs se)
      | Just t' <- peelArray 1 (typeOf pe) =
          out |> Let (Pat [pe]) (StmAux cs mempty ()) (BasicOp $ ArrayLit [se] t')
      | otherwise =
          out |> Let (Pat [pe]) (StmAux cs mempty ()) (BasicOp $ SubExp se)

-- | Rewrite all bindings introduced by a single statement (to ensure they are
-- unique) and fix any dependencies that are broken as a result of migration or
-- rewriting.
--
-- NOTE: GPUBody kernels must be rewritten through 'rewriteStms'.
rewriteStm :: Stm GPU -> RewriteM (Stm GPU)
rewriteStm (Let pat aux e) = do
  -- The expression is rewritten before the pattern, so the fresh names of the
  -- pattern are only registered after the expression has been renamed.
  e' <- rewriteExp e
  pat' <- rewritePat pat
  aux' <- rewriteStmAux aux
  pure (Let pat' aux' e')

-- | Rewrite all bindings introduced by a pattern (to ensure they are unique)
-- and fix any dependencies that are broken as a result of migration or
-- rewriting.
rewritePat :: Pat Type -> RewriteM (Pat Type)
rewritePat pat = Pat <$> mapM rewritePatElem (patElems pat)

-- | Rewrite the binding introduced by a single pattern element (to ensure it is
-- unique) and fix any dependencies that are broken as a result of migration or
-- rewriting.
rewritePatElem :: PatElem Type -> RewriteM (PatElem Type)
rewritePatElem (PatElem n t) = do
  n' <- rewriteName n
  t' <- renameType t
  pure (PatElem n' t')

-- | Fix any 'StmAux' certificate references that are broken as a result of
-- migration or rewriting.
rewriteStmAux :: StmAux () -> RewriteM (StmAux ())
rewriteStmAux (StmAux certs attrs _) = do
  certs' <- renameCerts certs
  pure (StmAux certs' attrs ())

-- | Rewrite the bindings introduced by an expression (to ensure they are
-- unique) and fix any dependencies that are broken as a result of migration or
-- rewriting.
rewriteExp :: Exp GPU -> RewriteM (Exp GPU)
rewriteExp =
  mapExpM $
    Mapper
      { mapOnSubExp = renameSubExp,
        mapOnBody = const rewriteBody,
        mapOnVName = rename,
        mapOnRetType = renameExtType,
        mapOnBranchType = renameExtType,
        mapOnFParam = rewriteParam,
        mapOnLParam = rewriteParam,
        mapOnOp = const opError
      }
  where
    -- This indicates that something fundamentally is wrong with the migration
    -- table produced by the
    -- "Futhark.Optimise.ReduceDeviceSyncs.MigrationTable" module.
    opError = compilerBugS "Cannot migrate a host-only operation to device."

-- | Rewrite the binding introduced by a single parameter (to ensure it is
-- unique) and fix any dependencies that are broken as a result of migration or
-- rewriting.
rewriteParam :: Param (TypeBase Shape u) -> RewriteM (Param (TypeBase Shape u))
rewriteParam (Param attrs n t) = do
  n' <- rewriteName n
  t' <- renameType t
  pure (Param attrs n' t')

-- | Return the name to use for a rewritten dependency.
rename :: VName -> RewriteM VName
rename n = do
  st <- get
  let renames = rewriteRenames st
  let idx = baseTag n
  case IM.lookup idx renames of
    Just n' -> pure n'
    _ ->
      do
        -- First use of this name: resolve it, which may add a read of a
        -- migrated scalar to the prologue, and remember the result so that
        -- later uses share it.
        let stms = rewritePrologue st
        (stms', n') <- lift $ useScalar stms n
        modify $ \st' ->
          st'
            { rewriteRenames = IM.insert idx n' renames,
              rewritePrologue = stms'
            }
        pure n'

-- | Update the variable names within a 'Result' to account for migration and
-- rewriting.
renameResult :: Result -> RewriteM Result
renameResult = mapM renameSubExpRes

-- | Update the variable names within a 'SubExpRes' to account for migration and
-- rewriting.
renameSubExpRes :: SubExpRes -> RewriteM SubExpRes
renameSubExpRes (SubExpRes certs se) = do
  certs' <- renameCerts certs
  se' <- renameSubExp se
  pure (SubExpRes certs' se')

-- | Update the variable names of certificates to account for migration and
-- rewriting.
renameCerts :: Certs -> RewriteM Certs
renameCerts cs = Certs <$> mapM rename (unCerts cs)

-- | Update any variable name within a t'SubExp' to account for migration and
-- rewriting.
renameSubExp :: SubExp -> RewriteM SubExp
renameSubExp (Var n) = Var <$> rename n
renameSubExp se = pure se

-- | Update the variable names within a type to account for migration and
-- rewriting.
renameType :: TypeBase Shape u -> RewriteM (TypeBase Shape u)
-- Note: mapOnType also maps the VName token of accumulators
renameType = mapOnType renameSubExp

-- | Update the variable names within an existential type to account for
-- migration and rewriting.
renameExtType :: TypeBase ExtShape u -> RewriteM (TypeBase ExtShape u)
-- Note: mapOnExtType also maps the VName token of accumulators
renameExtType = mapOnExtType renameSubExp