-- |
-- This module implements program analysis to determine which program statements
-- the "Futhark.Optimise.ReduceDeviceSyncs" pass should move into 'GPUBody' kernels
-- to reduce blocking memory transfers between host and device. The results of
-- the analysis are encoded into a 'MigrationTable', which can be queried.
--
-- To reduce blocking scalar reads, the module constructs a data flow
-- dependency graph of program variables (see
-- "Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph") in which
-- it finds a minimum vertex cut that separates array reads of scalars
-- from transitive usage that cannot or should not be migrated to
-- device.
--
-- The variables of each partition are assigned a 'MigrationStatus' that states
-- whether the computation of those variables should be moved to device or
-- remain on host. Due to how the graph is built and the vertex cut is found,
-- all variables bound by a single statement will belong to the same partition.
--
-- The vertex cut contains all variables that will reside in device memory but
-- are required by host operations. These variables must be read from device
-- memory and cannot be reduced further in number merely by migrating
-- statements (subject to the accuracy of the graph model). The model is built
-- to reduce the worst-case number of scalar reads; an optimal migration of
-- statements depends on runtime data.
--
-- Blocking scalar writes are reduced by either turning such writes into
-- asynchronous kernels, as is done with scalar array literals and accumulator
-- updates, or by transforming host-device writing into device-device copying.
--
-- For details on how the graph is constructed and how the vertex cut is found,
-- see the master's thesis "Reducing Synchronous GPU Memory Transfers" by Philip
-- Børgesen (2022).
module Futhark.Optimise.ReduceDeviceSyncs.MigrationTable
  ( -- * Analysis
    analyseProg,

    -- * Types
    MigrationTable,
    MigrationStatus (..),

    -- * Query

    -- | These functions all assume that no parent statement should be migrated.
    -- That is, @shouldMoveStm stm mt@ should return @False@ for every statement
    -- @stm@ with a body that a queried 'VName' or 'Stm' is nested within;
    -- otherwise the query result may be invalid.
    shouldMoveStm,
    shouldMove,
    usedOnHost,
    statusOf,
  )
where

import Control.Monad
import Control.Monad.Trans.Class
import qualified Control.Monad.Trans.Reader as R
import Control.Monad.Trans.State.Strict ()
import Control.Monad.Trans.State.Strict hiding (State)
import Control.Parallel.Strategies (parMap, rpar)
import Data.Bifunctor (first, second)
import Data.Foldable
import qualified Data.IntMap.Strict as IM
import qualified Data.IntSet as IS
import qualified Data.Map.Strict as M
import Data.Maybe (fromJust, fromMaybe, isJust, isNothing)
import qualified Data.Sequence as SQ
import Data.Set (Set, (\\))
import qualified Data.Set as S
import Futhark.Error
import Futhark.IR.GPU
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph
  ( EdgeType (..),
    Edges (..),
    Id,
    IdSet,
    Result (..),
    Routing (..),
    Vertex (..),
  )
import qualified Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph as MG

--------------------------------------------------------------------------------
--                             MIGRATION TABLES                              --
--------------------------------------------------------------------------------

-- | Where the value bound by a name should be computed.
data MigrationStatus
  = -- | The statement that computes the value should be moved to device.
    -- No host usage of the value will be left after the migration.
    MoveToDevice
  | -- | As 'MoveToDevice' but host usage of the value will remain after
    -- migration.
    UsedOnHost
  | -- | The statement that computes the value should remain on host.
    StayOnHost
  deriving (Eq, Ord, Show)

-- | Identifies
--
-- (1) which statements should be moved from host to device to reduce the
--     worst-case number of blocking memory transfers.
--
-- (2) which migrated variables will still be used on the host after all
--     such statements have been moved.
newtype MigrationTable = MigrationTable (IM.IntMap MigrationStatus)

-- | Where should the value bound by this name be computed?
statusOf :: VName -> MigrationTable -> MigrationStatus
statusOf n (MigrationTable mt) =
  fromMaybe StayOnHost $ IM.lookup (baseTag n) mt

-- | Should this whole statement be moved from host to device?
shouldMoveStm :: Stm GPU -> MigrationTable -> Bool
shouldMoveStm (Let (Pat ((PatElem n _) : _)) _ (BasicOp (Index _ slice))) mt =
  statusOf n mt == MoveToDevice || any movedOperand slice
  where
    movedOperand (Var op) = statusOf op mt == MoveToDevice
    movedOperand _ = False
shouldMoveStm (Let (Pat ((PatElem n _) : _)) _ (BasicOp _)) mt =
  statusOf n mt /= StayOnHost
shouldMoveStm (Let (Pat ((PatElem n _) : _)) _ Apply {}) mt =
  statusOf n mt /= StayOnHost
shouldMoveStm (Let _ _ (If (Var n) _ _ _)) mt =
  statusOf n mt == MoveToDevice
shouldMoveStm (Let _ _ (DoLoop _ (ForLoop _ _ (Var n) _) _)) mt =
  statusOf n mt == MoveToDevice
shouldMoveStm (Let _ _ (DoLoop _ (WhileLoop n) _)) mt =
  statusOf n mt == MoveToDevice
-- BasicOp and Apply statements might not bind any variables (shouldn't happen).
-- If statements might use a constant branch condition.
-- For loop statements might use a constant number of iterations.
-- HostOp statements cannot execute on device.
-- WithAcc statements are never moved in their entirety.
shouldMoveStm _ _ = False

-- | Should the value bound by this name be computed on device?
shouldMove :: VName -> MigrationTable -> Bool
shouldMove n mt = statusOf n mt /= StayOnHost

-- | Will the value bound by this name be used on host?
usedOnHost :: VName -> MigrationTable -> Bool
usedOnHost n mt = statusOf n mt /= MoveToDevice

-- | Merges two migration tables that are assumed to be disjoint.
merge :: MigrationTable -> MigrationTable -> MigrationTable
merge (MigrationTable a) (MigrationTable b) = MigrationTable (a `IM.union` b)

--------------------------------------------------------------------------------
--                        HOST-ONLY FUNCTION ANALYSIS                        --
--------------------------------------------------------------------------------

-- | Identifies top-level function definitions that cannot be run on the
-- device. The application of any such function is host-only.
type HostOnlyFuns = Set Name

-- | Returns the names of all top-level functions that cannot be called from
-- the device. The evaluation of such a function is host-only.
hostOnlyFunDefs :: [FunDef GPU] -> HostOnlyFuns
hostOnlyFunDefs funs =
  let names = map funDefName funs
      call_map = M.fromList $ zip names (map checkFunDef funs)
   in S.fromList names \\ keysToSet (removeHostOnly call_map)
  where
    keysToSet = S.fromAscList . M.keys

    removeHostOnly cm =
      let (host_only, cm') = M.partition isHostOnly cm
       in if M.null host_only
            then cm'
            else removeHostOnly $ M.map (checkCalls $ keysToSet host_only) cm'

    isHostOnly = isNothing

    -- A function that calls a host-only function is itself host-only.
checkCalls hostOnlyFuns (Just calls) | hostOnlyFuns `S.disjoint` calls = Just calls checkCalls _ _ = Nothing -- | 'checkFunDef' returns 'Nothing' if this function definition uses arrays or -- HostOps. Otherwise it returns the names of all applied functions, which may -- include user defined functions that could turn out to be host-only. checkFunDef :: FunDef GPU -> Maybe (Set Name) checkFunDef fun = do checkFParams (funDefParams fun) checkRetTypes (funDefRetType fun) checkBody (funDefBody fun) where hostOnly = Nothing ok = Just () check isArr as = if any isArr as then hostOnly else ok checkFParams = check isArray checkLParams = check (isArray . fst) checkRetTypes = check isArrayType checkPats = check isArray checkLoopForm (ForLoop _ _ _ (_ : _)) = hostOnly checkLoopForm _ = ok checkBody = checkStms . bodyStms checkStms stms = S.unions <$> mapM checkStm stms checkStm (Let (Pat pats) _ e) = checkPats pats >> checkExp e -- Any expression that produces an array is caught by checkPats checkExp (BasicOp (Index _ _)) = hostOnly checkExp (WithAcc _ _) = hostOnly checkExp (Op _) = hostOnly checkExp (Apply fn _ _ _) = Just (S.singleton fn) checkExp (If _ tbranch fbranch _) = do calls1 <- checkBody tbranch calls2 <- checkBody fbranch pure (calls1 <> calls2) checkExp (DoLoop params lform body) = do checkLParams params checkLoopForm lform checkBody body checkExp _ = Just S.empty -------------------------------------------------------------------------------- -- MIGRATION ANALYSIS -- -------------------------------------------------------------------------------- -- | HostUsage identifies scalar variables that are used on host. type HostUsage = [Id] nameToId :: VName -> Id nameToId = baseTag -- | Analyses a program to return a migration table that covers all its -- statements and variables. analyseProg :: Prog GPU -> MigrationTable analyseProg (Prog consts funs) = let hof = hostOnlyFunDefs funs mt = analyseConsts hof funs consts mts = parMap rpar (analyseFunDef hof) funs in foldl' merge mt mts -- | Analyses top-level constants. analyseConsts :: HostOnlyFuns -> [FunDef GPU] -> Stms GPU -> MigrationTable analyseConsts hof funs consts = let usage = M.foldlWithKey (f $ freeIn funs) [] (scopeOf consts) in analyseStms hof usage consts where f free usage n t | isScalar t, n `nameIn` free = nameToId n : usage | otherwise = usage -- | Analyses a top-level function definition. analyseFunDef :: HostOnlyFuns -> FunDef GPU -> MigrationTable analyseFunDef hof fd = let body = funDefBody fd usage = foldl' f [] $ zip (bodyResult body) (funDefRetType fd) stms = bodyStms body in analyseStms hof usage stms where f usage (SubExpRes _ (Var n), t) | isScalarType t = nameToId n : usage f usage _ = usage -- | Analyses statements. The 'HostUsage' list identifies which bound scalar -- variables that subsequently may be used on host. All free variables such as -- constants and function parameters are assumed to reside on host. analyseStms :: HostOnlyFuns -> HostUsage -> Stms GPU -> MigrationTable analyseStms hof usage stms = let (g, srcs, _) = buildGraph hof usage stms (routed, unrouted) = srcs (_, g') = MG.routeMany unrouted g -- hereby routed f st' = MG.fold g' visit st' Normal st = foldl' f (initial, MG.none) unrouted (vr, vn, tn) = fst $ foldl' f st routed in -- TODO: Delay reads into (deeper) branches MigrationTable $ IM.unions [ IM.fromSet (const MoveToDevice) vr, IM.fromSet (const MoveToDevice) vn, -- Read by host if not reached by a reversed edge IM.fromSet (const UsedOnHost) tn ] where -- 1) Visited by reversed edge. 
-- 2) Visited by normal edge, no route. -- 3) Visited by normal edge, had route; will potentially be read by host. initial = (IS.empty, IS.empty, IS.empty) visit (vr, vn, tn) Reversed v = let vr' = IS.insert (vertexId v) vr in (vr', vn, tn) visit (vr, vn, tn) Normal v@Vertex {vertexRouting = NoRoute} = let vn' = IS.insert (vertexId v) vn in (vr, vn', tn) visit (vr, vn, tn) Normal v = let tn' = IS.insert (vertexId v) tn in (vr, vn, tn') -------------------------------------------------------------------------------- -- TYPE HELPERS -- -------------------------------------------------------------------------------- isScalar :: Typed t => t -> Bool isScalar = isScalarType . typeOf isScalarType :: TypeBase shape u -> Bool isScalarType (Prim Unit) = False isScalarType (Prim _) = True isScalarType _ = False isArray :: Typed t => t -> Bool isArray = isArrayType . typeOf isArrayType :: ArrayShape shape => TypeBase shape u -> Bool isArrayType = (0 <) . arrayRank -------------------------------------------------------------------------------- -- GRAPH BUILDING -- -------------------------------------------------------------------------------- buildGraph :: HostOnlyFuns -> HostUsage -> Stms GPU -> (Graph, Sources, Sinks) buildGraph hof usage stms = let (g, srcs, sinks) = execGrapher hof (graphStms stms) g' = foldl' (flip MG.connectToSink) g usage in (g', srcs, sinks) -- | Graph a body. graphBody :: Body GPU -> Grapher () graphBody body = do let res_ops = namesIntSet $ freeIn (bodyResult body) body_stats <- captureBodyStats $ incBodyDepthFor (graphStms (bodyStms body) >> tellOperands res_ops) body_depth <- (1 +) <$> getBodyDepth let host_only = IS.member body_depth (bodyHostOnlyParents body_stats) modify $ \st -> let stats = stateStats st hops' = IS.delete body_depth (bodyHostOnlyParents stats) -- If body contains a variable that is required on host the parent -- statement that contains this body cannot be migrated as a whole. stats' = if host_only then stats {bodyHostOnly = True} else stats in st {stateStats = stats' {bodyHostOnlyParents = hops'}} -- | Graph multiple statements. graphStms :: Stms GPU -> Grapher () graphStms = mapM_ graphStm -- | Graph a single statement. graphStm :: Stm GPU -> Grapher () graphStm stm = do let bs = boundBy stm let e = stmExp stm -- IMPORTANT! It is generally assumed that all scalars within types and -- shapes are present on host. Any expression of a type wherein one of its -- scalar operands appears must therefore ensure that that scalar operand is -- marked as a size variable (see the 'hostSize' function). case e of BasicOp (SubExp se) -> do graphSimple bs e one bs `reusesSubExp` se BasicOp (Opaque _ se) -> do graphSimple bs e one bs `reusesSubExp` se BasicOp (ArrayLit arr t) | isScalar t, any (isJust . subExpVar) arr -> -- Migrating an array literal with free variables saves a write for -- every scalar it contains. Under some backends the compiler -- generates asynchronous writes for scalar constants but otherwise -- each write will be synchronous. If all scalars are constants then -- the compiler generates more efficient code that copies static -- device memory. 
          graphAutoMove (one bs)
    BasicOp UnOp {} -> graphSimple bs e
    BasicOp BinOp {} -> graphSimple bs e
    BasicOp CmpOp {} -> graphSimple bs e
    BasicOp ConvOp {} -> graphSimple bs e
    BasicOp Assert {} ->
      -- == OpenCL =============================================================
      --
      -- The next read after the execution of a kernel containing an assertion
      -- will be made asynchronous, followed by an asynchronous read to check
      -- if any assertion failed. The runtime will then block for all enqueued
      -- operations to finish.
      --
      -- Since an assertion only binds a certificate of unit type, an assertion
      -- cannot increase the number of (read) synchronizations that occur. In
      -- this regard it is free to migrate. The synchronization that does occur
      -- is however (presumably) more expensive as the pipeline of GPU work will
      -- be flushed.
      --
      -- Since this cost is difficult to quantify and amortize over assertion
      -- migration candidates (cost depends on ordering of kernels and reads) we
      -- assume it is insignificant. This will likely hold for a system where
      -- multiple threads or processes schedule GPU work, as system-wide
      -- throughput will only decrease if the GPU utilization decreases as a
      -- result.
      --
      -- == CUDA ===============================================================
      --
      -- Under the CUDA backend every read is synchronous and is followed by
      -- a full synchronization that blocks for all enqueued operations to
      -- finish. If any enqueued kernel contained an assertion, another
      -- synchronous read is then made to check if an assertion failed.
      --
      -- Migrating an assertion to save a read may thus introduce new reads, and
      -- the total number of reads can hence either decrease, remain the same,
      -- or even increase, subject to the ordering of reads and kernels that
      -- perform assertions.
      --
      -- Since it is possible to implement the same failure checking scheme as
      -- OpenCL using asynchronous reads (and doing so would be a good idea!)
      -- we consider this to be acceptable.
      --
      -- TODO: Implement the OpenCL failure checking scheme under CUDA. This
      -- should reduce the number of synchronizations per read to one.
      graphSimple bs e
    BasicOp (Index _ slice)
      | isFixed slice ->
          graphRead (one bs)
    BasicOp {}
      | [(_, t)] <- bs,
        dims <- arrayDims t,
        dims /= [], -- i.e. produces an array
        all (== intConst Int64 1) dims ->
          -- An expression that produces an array that only contains a single
          -- primitive value is as efficient to compute and copy as a scalar,
          -- and introduces no size variables.
          --
          -- This is an exception to the inefficiency rules that come next.
          graphSimple bs e
    -- Expressions with a cost sublinear to the size of their result arrays are
    -- risky to migrate as we cannot guarantee that their results are not
    -- returned from a GPUBody, which always copies its return values. Since
    -- this would make the effective asymptotic cost of such statements linear
    -- we block them from being migrated on their own.
    --
    -- The parent statement of an enclosing body may still be migrated as a
    -- whole given that each of its returned arrays either
    -- 1) is backed by memory used by a migratable statement within its body.
    -- 2) contains just a single element.
    -- An array matching either criterion is denoted "copyable memory" because
    -- the asymptotic cost of copying it is less than or equal to the statement
    -- that produced it. This makes the parent of statements with sublinear cost
    -- safe to migrate.
BasicOp (Index arr s) -> do graphInefficientReturn (sliceDims s) e one bs `reuses` arr BasicOp (Update _ arr _ _) -> do graphInefficientReturn [] e one bs `reuses` arr BasicOp (FlatIndex arr s) -> do -- Migrating a FlatIndex leads to a memory allocation error. -- -- TODO: Fix FlatIndex memory allocation error. -- -- Can be replaced with 'graphHostOnly e' to disable migration. -- A fix can be verified by enabling tests/migration/reuse2_flatindex.fut graphInefficientReturn (flatSliceDims s) e one bs `reuses` arr BasicOp (FlatUpdate arr _ _) -> do graphInefficientReturn [] e one bs `reuses` arr BasicOp (Scratch _ s) -> -- Migrating a Scratch leads to a memory allocation error. -- -- TODO: Fix Scratch memory allocation error. -- -- Can be replaced with 'graphHostOnly e' to disable migration. -- A fix can be verified by enabling tests/migration/reuse4_scratch.fut graphInefficientReturn s e BasicOp (Reshape s arr) -> do graphInefficientReturn (newDims s) e one bs `reuses` arr BasicOp (Rearrange _ arr) -> do graphInefficientReturn [] e one bs `reuses` arr BasicOp (Rotate _ arr) -> do -- Migrating a Rotate leads to a memory allocation error. -- -- TODO: Fix Rotate memory allocation error. -- -- Can be replaced with 'graphHostOnly e' to disable migration. -- A fix can be verified by enabling tests/migration/reuse7_rotate.fut graphInefficientReturn [] e one bs `reuses` arr -- Expressions with a cost linear to the size of their result arrays are -- inefficient to migrate into GPUBody kernels as such kernels are single- -- threaded. For sufficiently large arrays the cost may exceed what is saved -- by avoiding reads. We therefore also block these from being migrated, -- as well as their parents. BasicOp ArrayLit {} -> -- An array literal purely of primitive constants can be hoisted out to be -- a top-level constant, unless it is to be returned or consumed. -- Otherwise its runtime implementation will copy a precomputed static -- array and thus behave like a 'Copy'. -- Whether the rows are primitive constants or arrays, without any scalar -- variable operands such ArrayLit cannot directly prevent a scalar read. graphHostOnly e BasicOp Concat {} -> -- Is unlikely to prevent a scalar read as the only SubExp operand in -- practice is a computation of host-only size variables. graphHostOnly e BasicOp Copy {} -> -- Only takes an array operand, so cannot directly prevent a scalar read. graphHostOnly e BasicOp Manifest {} -> -- Takes no scalar operands so cannot directly prevent a scalar read. -- It is introduced as part of the BlkRegTiling kernel optimization and -- is thus unlikely to prevent the migration of a parent which was not -- already blocked by some host-only operation. graphHostOnly e BasicOp Iota {} -> graphHostOnly e BasicOp Replicate {} -> graphHostOnly e -- END BasicOp UpdateAcc {} -> graphUpdateAcc (one bs) e Apply fn _ _ _ -> graphApply fn bs e If cond tbody fbody _ -> graphIf bs cond tbody fbody DoLoop params lform body -> graphLoop bs params lform body WithAcc inputs f -> graphWithAcc bs inputs f Op GPUBody {} -> -- A GPUBody can be migrated into a parent GPUBody by replacing it with -- its body statements and binding its return values inside 'ArrayLit's. tellGPUBody Op _ -> graphHostOnly e where one [x] = x one _ = compilerBugS "Type error: unexpected number of pattern elements." isFixed = isJust . sliceIndices -- new_dims may introduce new size variables which must be present on host -- when this expression is evaluated. 
    graphInefficientReturn new_dims e = do
      mapM_ hostSize new_dims
      graphedScalarOperands e >>= addEdges ToSink

    hostSize (Var n) = hostSizeVar n
    hostSize _ = pure ()

    hostSizeVar = requiredOnHost . nameToId

-- | Bindings for all pattern elements bound by a statement.
boundBy :: Stm GPU -> [Binding]
boundBy = map (\(PatElem n t) -> (nameToId n, t)) . patElems . stmPat

-- | Graph a statement which in itself neither reads scalars from device memory
-- nor forces such scalars to be available on host. Such a statement can be
-- moved to device to eliminate the host usage of its operands, which
-- transitively may depend on a scalar device read.
graphSimple :: [Binding] -> Exp GPU -> Grapher ()
graphSimple bs e = do
  -- Only add vertices to the graph if they have a transitive dependency on
  -- an array read. Transitive dependencies through variables connected to
  -- sinks do not count.
  ops <- graphedScalarOperands e
  let edges = MG.declareEdges (map fst bs)
  unless (IS.null ops) (mapM_ addVertex bs >> addEdges edges ops)

-- | Graph a statement that reads a scalar from device memory.
graphRead :: Binding -> Grapher ()
graphRead b = do
  -- Operands are not important as the source will block routes through b.
  addSource b
  tellRead

-- | Graph a statement that should always be moved to device.
graphAutoMove :: Binding -> Grapher ()
graphAutoMove =
  -- Operands are not important as the source will block routes through b.
  addSource

-- | Graph a statement that is unfit for execution in a GPUBody and thus must
-- be executed on host, requiring all its operands to be made available there.
-- Parent statements of enclosing bodies are also blocked from being migrated.
graphHostOnly :: Exp GPU -> Grapher ()
graphHostOnly e = do
  -- Connect the vertices of all operands to sinks to mark that they are
  -- required on host. Transitive reads that they depend upon can be delayed
  -- no further, and any parent statements cannot be migrated.
  ops <- graphedScalarOperands e
  addEdges ToSink ops
  tellHostOnly

-- | Graph an 'UpdateAcc' statement.
graphUpdateAcc :: Binding -> Exp GPU -> Grapher ()
graphUpdateAcc b e
  | (_, Acc a _ _ _) <- b =
      -- The actual graphing is delayed to the corresponding 'WithAcc' parent.
      modify $ \st ->
        let accs = stateUpdateAccs st
            accs' = IM.alter add (nameToId a) accs
         in st {stateUpdateAccs = accs'}
  where
    add Nothing = Just [(b, e)]
    add (Just xs) = Just $ (b, e) : xs
graphUpdateAcc _ _ =
  compilerBugS "Type error: UpdateAcc did not produce accumulator typed value."

-- | Graph a function application.
graphApply :: Name -> [Binding] -> Exp GPU -> Grapher ()
graphApply fn bs e = do
  hof <- isHostOnlyFun fn
  if hof
    then graphHostOnly e
    else graphSimple bs e

-- | Graph an if statement.
graphIf :: [Binding] -> SubExp -> Body GPU -> Body GPU -> Grapher ()
graphIf bs cond tbody fbody = do
  body_host_only <-
    incForkDepthFor
      ( do
          tstats <- captureBodyStats (graphBody tbody)
          fstats <- captureBodyStats (graphBody fbody)
          pure $ bodyHostOnly tstats || bodyHostOnly fstats
      )
  -- Record aliases for copyable memory backing returned arrays.
  may_copy_results <- reusesBranches bs (results tbody) (results fbody)
  let may_migrate = not body_host_only && may_copy_results
  cond_id <- case (may_migrate, cond) of
    (False, Var n) ->
      -- The migration status of the condition is what determines whether the
      -- statement may be migrated as a whole or not. See 'shouldMoveStm'.
      connectToSink (nameToId n) >> pure IS.empty
    (True, Var n) -> onlyGraphedScalar n
    (_, _) -> pure IS.empty
  tellOperands cond_id
  -- Connect branch results to bound variables to allow delaying reads out of
  -- branches. It might also be beneficial to move the whole statement to
  -- device, to avoid reading the branch condition value. This must be balanced
  -- against the need to read the values bound by the if statement.
  --
  -- By connecting the branch condition to each variable bound by the statement
  -- the condition will only stay on device if
  --
  -- (1) the if statement is not required on host, based on the statements
  --     within its body.
  --
  -- (2) no additional reads will be required to use the variables bound by the
  --     if statement, should the whole statement be migrated.
  --
  -- If the condition is migrated to device and stays there, then the if
  -- statement must necessarily execute on device.
  --
  -- While the graph model built by this module generally migrates no more
  -- statements than necessary to obtain a minimum vertex cut, the branches
  -- of if statements are subject to an inaccuracy. Specifically, the model is
  -- not strong enough to capture their mutual exclusivity and thus encodes
  -- that both branches are taken. While this does not affect the resulting
  -- number of host-device reads it means that some reads may needlessly be
  -- delayed out of branches. The overhead as measured on futhark-benchmarks
  -- appears to be negligible though.
  ret <- zipWithM (comb cond_id) (bodyResult tbody) (bodyResult fbody)
  mapM_ (uncurry createNode) (zip bs ret)
  where
    results = map resSubExp . bodyResult

    comb ci a b = (ci <>) <$> onlyGraphedScalars (toSet a <> toSet b)

    toSet (SubExpRes _ (Var n)) = S.singleton n
    toSet _ = S.empty

-----------------------------------------------------
-- These type aliases are only used by 'graphLoop' --
-----------------------------------------------------
type ReachableBindings = IdSet

type ReachableBindingsCache = MG.Visited (MG.Result ReachableBindings)

type NonExhausted = [Id]

type LoopValue = (Binding, Id, SubExp, SubExp)

-----------------------------------------------------
-----------------------------------------------------

-- | Graph a loop statement.
graphLoop ::
  [Binding] ->
  [(FParam GPU, SubExp)] ->
  LoopForm GPU ->
  Body GPU ->
  Grapher ()
graphLoop [] _ _ _ =
  -- We expect each loop to bind a value or be eliminated.
  compilerBugS "Loop statement bound no variable; should have been eliminated."
graphLoop (b : bs) params lform body = do
  -- Graph loop params and body while capturing statistics.
  g0 <- getGraph
  stats <- captureBodyStats (subgraphId `graphIdFor` graphTheLoop)
  -- Record aliases for copyable memory backing returned arrays.
  -- Does the loop return any arrays which prevent it from being migrated?
  let args = map snd params
  let results = map resSubExp (bodyResult body)
  may_copy_results <- reusesBranches (b : bs) args results
  -- Connect loop condition to a sink if the loop cannot be migrated.
  -- The migration status of the condition is what determines whether the
  -- loop may be migrated as a whole or not. See 'shouldMoveStm'.
  let may_migrate = not (bodyHostOnly stats) && may_copy_results
  unless may_migrate $
    case lform of
      ForLoop _ _ (Var n) _ -> connectToSink (nameToId n)
      WhileLoop n -> connectToSink (nameToId n)
      _ -> pure ()
  -- Connect graphed return values to their loop parameters.
  g1 <- getGraph
  mapM_ (mergeLoopParam g1) loopValues
  -- Route the sources within the loop body in isolation.
  -- The loop graph must not be altered after this point.
srcs <- routeSubgraph subgraphId -- Graph the variables bound by the statement. forM_ loopValues $ \(bnd, p, _, _) -> createNode bnd (IS.singleton p) -- If a device read is delayed from one iteration to the next the -- corresponding variables bound by the statement must be treated as -- sources. g2 <- getGraph let (dbs, rbc) = foldl' (deviceBindings g2) (IS.empty, MG.none) srcs modifySources $ second (IS.toList dbs <>) -- Connect operands to sinks if they can reach a sink within the loop. -- Otherwise connect them to the loop bound variables that they can -- reach and exhaust their normal entry edges into the loop. -- This means a read can be delayed through a loop but not into it if -- that would increase the number of reads done by any given iteration. let ops = IS.filter (`MG.member` g0) (bodyOperands stats) foldM_ connectOperand rbc (IS.elems ops) -- It might be beneficial to move the whole loop to device, to avoid -- reading the (initial) loop condition value. This must be balanced -- against the need to read the values bound by the loop statement. -- -- For more details see the similar description for if statements. when may_migrate $ case lform of ForLoop _ _ n _ -> onlyGraphedScalarSubExp n >>= addEdges (ToNodes bindings Nothing) WhileLoop n | (_, _, arg, _) <- loopValueFor n -> onlyGraphedScalarSubExp arg >>= addEdges (ToNodes bindings Nothing) where subgraphId :: Id subgraphId = fst b loopValues :: [LoopValue] loopValues = let tmp = zip3 (b : bs) params (bodyResult body) tmp' = flip map tmp $ \(bnd, (p, arg), res) -> let i = nameToId (paramName p) in (bnd, i, arg, resSubExp res) in filter (\((_, t), _, _, _) -> isScalar t) tmp' bindings :: IdSet bindings = IS.fromList $ map (\((i, _), _, _, _) -> i) loopValues loopValueFor :: VName -> LoopValue loopValueFor n = fromJust $ find (\(_, p, _, _) -> p == nameToId n) loopValues graphTheLoop :: Grapher () graphTheLoop = do mapM_ graphParam loopValues -- For simplicity we do not currently track memory reuse through merge -- parameters. A parameter does not simply reuse the memory of its -- argument; it must also consider the iteration return value, which in -- turn may depend on other merge parameters. -- -- Situations that would benefit from this tracking is unlikely to occur -- at the time of writing, and if it occurs current compiler limitations -- will prevent successful compilation. -- Specifically it requires the merge parameter argument to reuse memory -- from an array literal, and both it and the loop must occur within an -- if statement branch. Array literals are generally hoisted out of if -- statements however, and when they are not, a memory allocation error -- occurs. -- -- TODO: Track memory reuse through merge parameters. case lform of ForLoop _ _ n elems -> do onlyGraphedScalarSubExp n >>= tellOperands mapM_ graphForInElem elems WhileLoop _ -> pure () graphBody body where graphForInElem (p, arr) = do when (isScalar p) $ addSource (nameToId $ paramName p, typeOf p) when (isArray p) $ (nameToId (paramName p), typeOf p) `reuses` arr graphParam ((_, t), p, arg, _) = do -- It is unknown whether a read can be delayed via the parameter -- from one iteration to the next, so we have to create a vertex -- even if the initial value never depends on a read. 
addVertex (p, t) ops <- onlyGraphedScalarSubExp arg addEdges (MG.oneEdge p) ops mergeLoopParam :: Graph -> LoopValue -> Grapher () mergeLoopParam g (_, p, _, res) | Var n <- res, ret <- nameToId n, ret /= p = if MG.isSinkConnected p g then connectToSink ret else addEdges (MG.oneEdge p) (IS.singleton ret) | otherwise = pure () deviceBindings :: Graph -> (ReachableBindings, ReachableBindingsCache) -> Id -> (ReachableBindings, ReachableBindingsCache) deviceBindings g (rb, rbc) i = let (r, rbc') = MG.reduce g bindingReach rbc Normal i in case r of Produced rb' -> (rb <> rb', rbc') _ -> compilerBugS "Migration graph sink could be reached from source after it\ \ had been attempted routed." bindingReach :: ReachableBindings -> EdgeType -> Vertex Meta -> ReachableBindings bindingReach rb _ v | i <- vertexId v, IS.member i bindings = IS.insert i rb | otherwise = rb connectOperand :: ReachableBindingsCache -> Id -> Grapher ReachableBindingsCache connectOperand cache op = do g <- getGraph case MG.lookup op g of Nothing -> pure cache Just v -> case vertexEdges v of ToSink -> pure cache ToNodes es Nothing -> connectOp g cache op es ToNodes _ (Just nx) -> connectOp g cache op nx where connectOp :: Graph -> ReachableBindingsCache -> Id -> -- operand id IdSet -> -- its edges Grapher ReachableBindingsCache connectOp g rbc i es = do let (res, nx, rbc') = findBindings g (IS.empty, [], rbc) (IS.elems es) case res of FoundSink -> connectToSink i Produced rb -> modifyGraph $ MG.adjust (updateEdges nx rb) i pure rbc' updateEdges :: NonExhausted -> ReachableBindings -> Vertex Meta -> Vertex Meta updateEdges nx rb v | ToNodes es _ <- vertexEdges v = let nx' = IS.fromList nx es' = ToNodes (rb <> es) $ Just (rb <> nx') in v {vertexEdges = es'} | otherwise = v findBindings :: Graph -> (ReachableBindings, NonExhausted, ReachableBindingsCache) -> [Id] -> -- current non-exhausted edges (MG.Result ReachableBindings, NonExhausted, ReachableBindingsCache) findBindings _ (rb, nx, rbc) [] = (Produced rb, nx, rbc) findBindings g (rb, nx, rbc) (i : is) | Just v <- MG.lookup i g, Just gid <- metaGraphId (vertexMeta v), gid == subgraphId -- only search the subgraph = let (res, rbc') = MG.reduce g bindingReach rbc Normal i in case res of FoundSink -> (FoundSink, [], rbc') Produced rb' -> findBindings g (rb <> rb', nx, rbc') is | otherwise = -- don't exhaust findBindings g (rb, i : nx, rbc) is -- | Graph a 'WithAcc' statement. graphWithAcc :: [Binding] -> [WithAccInput GPU] -> Lambda GPU -> Grapher () graphWithAcc bs inputs f = do -- Graph the body, capturing 'UpdateAcc' statements for delayed graphing. graphBody (lambdaBody f) -- Graph each accumulator monoid and its associated 'UpdateAcc' statements. mapM_ graph $ zip (lambdaReturnType f) inputs -- Record aliases for the backing memory of each returned array. -- 'WithAcc' statements are never migrated as a whole and always returns -- arrays backed by memory allocated elsewhere. let arrs = concatMap (\(_, as, _) -> map Var as) inputs let res = drop (length inputs) (bodyResult $ lambdaBody f) _ <- reusesReturn bs (arrs ++ map resSubExp res) -- Connect return variables to bound values. No outgoing edge exists -- from an accumulator vertex so skip those. Note that accumulators do -- not map to returned arrays one-to-one but one-to-many. ret <- mapM (onlyGraphedScalarSubExp . resSubExp) res mapM_ (uncurry createNode) $ zip (drop (length arrs) bs) ret where graph (Acc a _ types _, (_, _, comb)) = do let i = nameToId a delayed <- fromMaybe [] <$> gets (IM.lookup i . 
          stateUpdateAccs)
      modify $ \st -> st {stateUpdateAccs = IM.delete i (stateUpdateAccs st)}
      graphAcc i types (fst <$> comb) delayed
      -- Neutral elements must always be made available on host for 'WithAcc'
      -- to type check.
      mapM_ connectSubExpToSink $ maybe [] snd comb
    graph _ =
      compilerBugS "Type error: WithAcc expression did not return accumulator."

-- Graph the operator and all 'UpdateAcc' statements associated with an
-- accumulator.
--
-- The arguments are the 'Id' for the accumulator token, the element types of
-- the accumulator/operator, its combining function if any, and all associated
-- 'UpdateAcc' statements outside kernels.
graphAcc :: Id -> [Type] -> Maybe (Lambda GPU) -> [Delayed] -> Grapher ()
graphAcc i _ _ [] = addSource (i, Prim Unit) -- Only used on device.
graphAcc i types op delayed = do
  -- Accumulators are intended for use within SegOps but in principle the AST
  -- allows their 'UpdateAcc's to be used outside a kernel. This case handles
  -- that unlikely situation.
  env <- ask
  st <- get
  -- Collect statistics about the operator statements.
  let lambda = fromMaybe (Lambda [] (Body () SQ.empty []) []) op
  let m = graphBody (lambdaBody lambda)
  let stats = R.runReader (evalStateT (captureBodyStats m) st) env
  -- We treat GPUBody kernels as host-only to not bother rewriting them inside
  -- operators and to simplify the analysis. They are unlikely to occur anyway.
  --
  -- NOTE: Performance may degrade if a GPUBody is replaced with its contents
  -- but the containing operator is used on host.
  let host_only = bodyHostOnly stats || bodyHasGPUBody stats
  -- op operands are read from arrays and written back, so if any of the
  -- operands are scalar then a read can be avoided by moving the UpdateAcc
  -- usages to device. If the op itself performs scalar reads its UpdateAcc
  -- usages should also be moved.
  let does_read = bodyReads stats || any isScalar types
  -- Determine which external variables the operator depends upon.
  -- 'bodyOperands' cannot be used as it might exclude operands that were
  -- connected to sinks within the body, so instead we create an artificial
  -- expression to capture graphed operands from.
  ops <- graphedScalarOperands (WithAcc [] lambda)
  case (host_only, does_read) of
    (True, _) -> do
      -- If the operator cannot run well in a GPUBody then all non-kernel
      -- UpdateAcc statements are host-only. The current analysis is ignorant
      -- of what happens inside kernels so we must assume that the operator
      -- is used within a kernel, meaning that we cannot migrate its statements.
      --
      -- TODO: Improve analysis if UpdateAcc ever is used outside kernels.
      mapM_ (graphHostOnly . snd) delayed
      addEdges ToSink ops
    (_, True) -> do
      -- Migrate all accumulator usage to device to avoid reads and writes.
      mapM_ (graphAutoMove . fst) delayed
      addSource (i, Prim Unit)
    _ -> do
      -- Only migrate operator and UpdateAcc statements if it can allow their
      -- operands to be migrated.
      createNode (i, Prim Unit) ops
      forM_ delayed $ \(b, e) ->
        graphedScalarOperands e >>= createNode b . IS.insert i

-- Returns for an expression all scalar operands that must be made available
-- on host to execute the expression there.
graphedScalarOperands :: Exp GPU -> Grapher Operands graphedScalarOperands e = let is = fst $ execState (collect e) initial in IS.intersection is <$> getGraphedScalars where initial = (IS.empty, S.empty) -- scalar operands, accumulator tokens captureName n = modify $ first $ IS.insert (nameToId n) captureAcc a = modify $ second $ S.insert a collectFree x = mapM_ captureName (namesToList $ freeIn x) collect b@BasicOp {} = collectBasic b collect (Apply _ params _ _) = mapM_ (collectSubExp . fst) params collect (If cond tbranch fbranch _) = collectSubExp cond >> collectBody tbranch >> collectBody fbranch collect (DoLoop params lform body) = do mapM_ (collectSubExp . snd) params collectLForm lform collectBody body collect (WithAcc accs f) = collectWithAcc accs f collect (Op op) = collectHostOp op collectBasic (BasicOp (Update _ _ slice _)) = -- Writing a scalar to an array can be replaced with copying a single- -- element slice. If the scalar originates from device memory its read -- can thus be prevented without requiring the 'Update' to be migrated. collectFree slice collectBasic (BasicOp (Replicate shape _)) = -- The replicate of a scalar can be rewritten as a replicate of a single -- element array followed by a slice index. collectFree shape collectBasic e' = -- Note: Plain VName values only refer to arrays. walkExpM (identityWalker {walkOnSubExp = collectSubExp}) e' collectSubExp (Var n) = captureName n collectSubExp _ = pure () collectBody body = do collectStms (bodyStms body) collectFree (bodyResult body) collectStms = mapM_ collectStm collectStm (Let pat _ ua) | BasicOp UpdateAcc {} <- ua, Pat [pe] <- pat, Acc a _ _ _ <- typeOf pe = -- Capture the tokens of accumulators used on host. captureAcc a >> collectBasic ua collectStm stm = collect (stmExp stm) collectLForm (ForLoop _ _ b _) = collectSubExp b -- WhileLoop condition is declared as a loop parameter. collectLForm (WhileLoop _) = pure () -- The collective operands of an operator lambda body are only used on host -- if the associated accumulator is used in an UpdateAcc statement outside a -- kernel. collectWithAcc inputs f = do collectBody (lambdaBody f) used_accs <- gets snd let accs = take (length inputs) (lambdaReturnType f) let used = map (\(Acc a _ _ _) -> S.member a used_accs) accs mapM_ collectAcc (zip used inputs) collectAcc (_, (_, _, Nothing)) = pure () collectAcc (used, (_, _, Just (op, nes))) = do mapM_ collectSubExp nes when used $ collectBody (lambdaBody op) -- Does not collect named operands in -- -- * types and shapes; size variables are assumed available to the host. -- -- * use by a kernel body. -- -- All other operands are conservatively collected even if they generally -- appear to be size variables or results computed by a SizeOp. 
collectHostOp (SegOp (SegMap lvl sp _ _)) = do collectSegLevel lvl collectSegSpace sp collectHostOp (SegOp (SegRed lvl sp ops _ _)) = do collectSegLevel lvl collectSegSpace sp mapM_ collectSegBinOp ops collectHostOp (SegOp (SegScan lvl sp ops _ _)) = do collectSegLevel lvl collectSegSpace sp mapM_ collectSegBinOp ops collectHostOp (SegOp (SegHist lvl sp ops _ _)) = do collectSegLevel lvl collectSegSpace sp mapM_ collectHistOp ops collectHostOp (SizeOp op) = collectFree op collectHostOp (OtherOp op) = collectFree op collectHostOp GPUBody {} = pure () collectSegLevel (SegThread (Count num) (Count size) _) = collectSubExp num >> collectSubExp size collectSegLevel (SegGroup (Count num) (Count size) _) = collectSubExp num >> collectSubExp size collectSegSpace space = mapM_ collectSubExp (segSpaceDims space) collectSegBinOp (SegBinOp _ _ nes _) = mapM_ collectSubExp nes collectHistOp (HistOp _ rf _ nes _ _) = do collectSubExp rf mapM_ collectSubExp nes -------------------------------------------------------------------------------- -- GRAPH BUILDING - PRIMITIVES -- -------------------------------------------------------------------------------- -- | Creates a vertex for the given binding, provided that the set of operands -- is not empty. createNode :: Binding -> Operands -> Grapher () createNode b ops = unless (IS.null ops) (addVertex b >> addEdges (MG.oneEdge $ fst b) ops) -- | Adds a vertex to the graph for the given binding. addVertex :: Binding -> Grapher () addVertex (i, t) = do meta <- getMeta let v = MG.vertex i meta when (isScalar t) $ modifyGraphedScalars (IS.insert i) when (isArray t) $ recordCopyableMemory i (metaBodyDepth meta) modifyGraph (MG.insert v) -- | Adds a source connected vertex to the graph for the given binding. addSource :: Binding -> Grapher () addSource b = do addVertex b modifySources $ second (fst b :) -- | Adds the given edges to each vertex identified by the 'IdSet'. It is -- assumed that all vertices reside within the body that currently is being -- graphed. addEdges :: Edges -> IdSet -> Grapher () addEdges ToSink is = do modifyGraph $ \g -> IS.foldl' (flip MG.connectToSink) g is modifyGraphedScalars (`IS.difference` is) addEdges es is = do modifyGraph $ \g -> IS.foldl' (flip $ MG.addEdges es) g is tellOperands is -- | Ensure that a variable (which is in scope) will be made available on host -- before its first use. requiredOnHost :: Id -> Grapher () requiredOnHost i = do mv <- MG.lookup i <$> getGraph case mv of Nothing -> pure () Just v -> do connectToSink i tellHostOnlyParent (metaBodyDepth $ vertexMeta v) -- | Connects the vertex of the given id to a sink. connectToSink :: Id -> Grapher () connectToSink i = do modifyGraph (MG.connectToSink i) modifyGraphedScalars (IS.delete i) -- | Like 'connectToSink' but vertex is given by a t'SubExp'. This is a no-op if -- the t'SubExp' is a constant. connectSubExpToSink :: SubExp -> Grapher () connectSubExpToSink (Var n) = connectToSink (nameToId n) connectSubExpToSink _ = pure () -- | Routes all possible routes within the subgraph identified by this id. -- Returns the ids of the source connected vertices that were attempted routed. -- -- Assumption: The subgraph with the given id has just been created and no path -- exists from it to an external sink. 
routeSubgraph :: Id -> Grapher [Id] routeSubgraph si = do st <- get let g = stateGraph st let (routed, unrouted) = stateSources st let (gsrcs, unrouted') = span (inSubGraph si g) unrouted let (sinks, g') = MG.routeMany gsrcs g put $ st { stateGraph = g', stateSources = (gsrcs ++ routed, unrouted'), stateSinks = sinks ++ stateSinks st } pure gsrcs -- | @inSubGraph si g i@ returns whether @g@ contains a vertex with id @i@ that -- is declared within the subgraph with id @si@. inSubGraph :: Id -> Graph -> Id -> Bool inSubGraph si g i | Just v <- MG.lookup i g, Just mgi <- metaGraphId (vertexMeta v) = si == mgi inSubGraph _ _ _ = False -- | @b `reuses` n@ records that @b@ binds an array backed by the same memory -- as @n@. If @b@ is not array typed or the backing memory is not copyable then -- this does nothing. reuses :: Binding -> VName -> Grapher () reuses (i, t) n | isArray t = do body_depth <- outermostCopyableArray n case body_depth of Just bd -> recordCopyableMemory i bd Nothing -> pure () | otherwise = pure () reusesSubExp :: Binding -> SubExp -> Grapher () reusesSubExp b (Var n) = b `reuses` n reusesSubExp _ _ = pure () -- @reusesReturn bs res@ records each array binding in @bs@ as reusing copyable -- memory if the corresponding return value in @res@ is backed by copyable -- memory. -- -- If every array binding is registered as being backed by copyable memory then -- the function returns @True@, otherwise it returns @False@. reusesReturn :: [Binding] -> [SubExp] -> Grapher Bool reusesReturn bs res = do body_depth <- metaBodyDepth <$> getMeta foldM (reuse body_depth) True (zip bs res) where reuse :: Int -> Bool -> (Binding, SubExp) -> Grapher Bool reuse body_depth onlyCopyable (b, se) | all (== intConst Int64 1) (arrayDims $ snd b) = -- Single element arrays are immediately recognizable as copyable so -- don't bother recording those. Note that this case also matches -- primitive return values. pure onlyCopyable | (i, t) <- b, isArray t, Var n <- se = do res_body_depth <- outermostCopyableArray n case res_body_depth of Just inner -> do recordCopyableMemory i (min body_depth inner) let returns_free_var = inner <= body_depth pure (onlyCopyable && not returns_free_var) _ -> pure False | otherwise = pure onlyCopyable -- @reusesBranches bs b1 b2@ records each array binding in @bs@ as reusing -- copyable memory if each corresponding return value in the lists @b1@ and @b2@ -- are backed by copyable memory. -- -- If every array binding is registered as being backed by copyable memory then -- the function returns @True@, otherwise it returns @False@. reusesBranches :: [Binding] -> [SubExp] -> [SubExp] -> Grapher Bool reusesBranches bs b1 b2 = do body_depth <- metaBodyDepth <$> getMeta foldM (reuse body_depth) True $ zip3 bs b1 b2 where reuse :: Int -> Bool -> (Binding, SubExp, SubExp) -> Grapher Bool reuse body_depth onlyCopyable (b, se1, se2) | all (== intConst Int64 1) (arrayDims $ snd b) = -- Single element arrays are immediately recognizable as copyable so -- don't bother recording those. Note that this case also matches -- primitive return values. 
pure onlyCopyable | (i, t) <- b, isArray t, Var n1 <- se1, Var n2 <- se2 = do body_depth_1 <- outermostCopyableArray n1 body_depth_2 <- outermostCopyableArray n2 case (body_depth_1, body_depth_2) of (Just bd1, Just bd2) -> do let inner = min bd1 bd2 recordCopyableMemory i (min body_depth inner) let returns_free_var = inner <= body_depth pure (onlyCopyable && not returns_free_var) _ -> pure False | otherwise = pure onlyCopyable -------------------------------------------------------------------------------- -- GRAPH BUILDING - TYPES -- -------------------------------------------------------------------------------- type Grapher = StateT State (R.Reader Env) data Env = Env { -- | See 'HostOnlyFuns'. envHostOnlyFuns :: HostOnlyFuns, -- | Metadata for the current body being graphed. envMeta :: Meta } -- | A measurement of how many bodies something is nested within. type BodyDepth = Int -- | Metadata on the environment that a variable is declared within. data Meta = Meta { -- | How many if statement branch bodies the variable binding is nested -- within. If a route passes through the edge u->v and the fork depth -- -- 1) increases from u to v, then u is within a conditional branch. -- -- 2) decreases from u to v, then v binds the result of two or more -- branches. -- -- After the graph has been built and routed, this can be used to delay -- reads into deeper branches to reduce their likelihood of manifesting. metaForkDepth :: Int, -- | How many bodies the variable is nested within. metaBodyDepth :: BodyDepth, -- | An id for the subgraph within which the variable exists, defined at -- the body level. A read may only be delayed to a point within its own -- subgraph. metaGraphId :: Maybe Id } -- | Ids for all variables used as an operand. type Operands = IdSet -- | Statistics on the statements within a body and their dependencies. data BodyStats = BodyStats { -- | Whether the body contained any host-only statements. bodyHostOnly :: Bool, -- | Whether the body contained any GPUBody kernels. bodyHasGPUBody :: Bool, -- | Whether the body performed any reads. bodyReads :: Bool, -- | All scalar variables represented in the graph that have been used -- as return values of the body or as operands within it, including those -- that are defined within the body itself. Variables with vertices -- connected to sinks may be excluded. bodyOperands :: Operands, -- | Depth of parent bodies with variables that are required on host. Since -- the variables are required on host, the parent statements of these bodies -- cannot be moved to device as a whole. They are host-only. bodyHostOnlyParents :: IS.IntSet } instance Semigroup BodyStats where (BodyStats ho1 gb1 r1 o1 hop1) <> (BodyStats ho2 gb2 r2 o2 hop2) = BodyStats { bodyHostOnly = ho1 || ho2, bodyHasGPUBody = gb1 || gb2, bodyReads = r1 || r2, bodyOperands = IS.union o1 o2, bodyHostOnlyParents = IS.union hop1 hop2 } instance Monoid BodyStats where mempty = BodyStats { bodyHostOnly = False, bodyHasGPUBody = False, bodyReads = False, bodyOperands = IS.empty, bodyHostOnlyParents = IS.empty } type Graph = MG.Graph Meta -- | All vertices connected from a source, partitioned into those that have -- been attempted routed and those which have not. type Sources = ([Id], [Id]) -- | All terminal vertices of routes. type Sinks = [Id] -- | A captured statement for which graphing has been delayed. type Delayed = (Binding, Exp GPU) -- | The vertex handle for a variable and its type. 
type Binding = (Id, Type) -- | Array variables backed by memory segments that may be copied, mapped to the -- outermost known body depths that declares arrays backed by a superset of -- those segments. type CopyableMemoryMap = IM.IntMap BodyDepth data State = State { -- | The graph being built. stateGraph :: Graph, -- | All known scalars that have been graphed. stateGraphedScalars :: IdSet, -- | All variables that directly bind scalars read from device memory. stateSources :: Sources, -- | Graphed scalars that are used as operands by statements that cannot be -- migrated. A read cannot be delayed beyond these, so if the statements -- that bind these variables are moved to device, the variables must be read -- from device memory. stateSinks :: Sinks, -- | Observed 'UpdateAcc' host statements to be graphed later. stateUpdateAccs :: IM.IntMap [Delayed], -- | A map of encountered arrays that are backed by copyable memory. -- Trivial instances such as single element arrays are excluded. stateCopyableMemory :: CopyableMemoryMap, -- | Information about the current body being graphed. stateStats :: BodyStats } -------------------------------------------------------------------------------- -- GRAPHER OPERATIONS -- -------------------------------------------------------------------------------- execGrapher :: HostOnlyFuns -> Grapher a -> (Graph, Sources, Sinks) execGrapher hof m = let s = R.runReader (execStateT m st) env in (stateGraph s, stateSources s, stateSinks s) where env = Env { envHostOnlyFuns = hof, envMeta = Meta { metaForkDepth = 0, metaBodyDepth = 0, metaGraphId = Nothing } } st = State { stateGraph = MG.empty, stateGraphedScalars = IS.empty, stateSources = ([], []), stateSinks = [], stateUpdateAccs = IM.empty, stateCopyableMemory = IM.empty, stateStats = mempty } -- | Execute a computation in a modified environment. local :: (Env -> Env) -> Grapher a -> Grapher a local f = mapStateT (R.local f) -- | Fetch the value of the environment. ask :: Grapher Env ask = lift R.ask -- | Retrieve a function of the current environment. asks :: (Env -> a) -> Grapher a asks = lift . R.asks -- | Register that the body contains a host-only statement. This means its -- parent statement and any parent bodies themselves are host-only. A host-only -- statement should not be migrated, either because it cannot run on device or -- because it would be inefficient to do so. tellHostOnly :: Grapher () tellHostOnly = modify $ \st -> st {stateStats = (stateStats st) {bodyHostOnly = True}} -- | Register that the body contains a GPUBody kernel. tellGPUBody :: Grapher () tellGPUBody = modify $ \st -> st {stateStats = (stateStats st) {bodyHasGPUBody = True}} -- | Register that the current body contains a statement that reads device -- memory. tellRead :: Grapher () tellRead = modify $ \st -> st {stateStats = (stateStats st) {bodyReads = True}} -- | Register that these variables are used as operands within the current body. tellOperands :: IdSet -> Grapher () tellOperands is = modify $ \st -> let stats = stateStats st operands = bodyOperands stats in st {stateStats = stats {bodyOperands = operands <> is}} -- | Register that the current statement with a body at the given body depth is -- host-only. tellHostOnlyParent :: BodyDepth -> Grapher () tellHostOnlyParent body_depth = modify $ \st -> let stats = stateStats st parents = bodyHostOnlyParents stats parents' = IS.insert body_depth parents in st {stateStats = stats {bodyHostOnlyParents = parents'}} -- | Get the graph under construction. 
getGraph :: Grapher Graph getGraph = gets stateGraph -- | All scalar variables with a vertex representation in the graph. getGraphedScalars :: Grapher IdSet getGraphedScalars = gets stateGraphedScalars -- | Every known array that is backed by a memory segment that may be copied, -- mapped to the outermost known body depth where an array is backed by a -- superset of that segment. -- -- A body where all returned arrays are backed by such memory and are written by -- its own statements will retain its asymptotic cost if migrated as a whole. getCopyableMemory :: Grapher CopyableMemoryMap getCopyableMemory = gets stateCopyableMemory -- | The outermost known body depth for an array backed by the same copyable -- memory as the array with this name. outermostCopyableArray :: VName -> Grapher (Maybe BodyDepth) outermostCopyableArray n = IM.lookup (nameToId n) <$> getCopyableMemory -- | Reduces the variables to just the 'Id's of those that are scalars and which -- have a vertex representation in the graph, excluding those that have been -- connected to sinks. onlyGraphedScalars :: Foldable t => t VName -> Grapher IdSet onlyGraphedScalars vs = do let is = foldl' (\s n -> IS.insert (nameToId n) s) IS.empty vs IS.intersection is <$> getGraphedScalars -- | Like 'onlyGraphedScalars' but for a single 'VName'. onlyGraphedScalar :: VName -> Grapher IdSet onlyGraphedScalar n = do let i = nameToId n gss <- getGraphedScalars if IS.member i gss then pure (IS.singleton i) else pure IS.empty -- | Like 'onlyGraphedScalars' but for a single t'SubExp'. onlyGraphedScalarSubExp :: SubExp -> Grapher IdSet onlyGraphedScalarSubExp (Constant _) = pure IS.empty onlyGraphedScalarSubExp (Var n) = onlyGraphedScalar n -- | Update the graph under construction. modifyGraph :: (Graph -> Graph) -> Grapher () modifyGraph f = modify $ \st -> st {stateGraph = f (stateGraph st)} -- | Update the contents of the graphed scalar set. modifyGraphedScalars :: (IdSet -> IdSet) -> Grapher () modifyGraphedScalars f = modify $ \st -> st {stateGraphedScalars = f (stateGraphedScalars st)} -- | Update the contents of the copyable memory map. modifyCopyableMemory :: (CopyableMemoryMap -> CopyableMemoryMap) -> Grapher () modifyCopyableMemory f = modify $ \st -> st {stateCopyableMemory = f (stateCopyableMemory st)} -- | Update the set of source connected vertices. modifySources :: (Sources -> Sources) -> Grapher () modifySources f = modify $ \st -> st {stateSources = f (stateSources st)} -- | Record that this variable binds an array that is backed by copyable -- memory shared by an array at this outermost body depth. recordCopyableMemory :: Id -> BodyDepth -> Grapher () recordCopyableMemory i bd = modifyCopyableMemory (IM.insert i bd) -- | Increment the fork depth for variables graphed by this action. incForkDepthFor :: Grapher a -> Grapher a incForkDepthFor = local $ \env -> let meta = envMeta env fork_depth = metaForkDepth meta in env {envMeta = meta {metaForkDepth = fork_depth + 1}} -- | Increment the body depth for variables graphed by this action. incBodyDepthFor :: Grapher a -> Grapher a incBodyDepthFor = local $ \env -> let meta = envMeta env body_depth = metaBodyDepth meta in env {envMeta = meta {metaBodyDepth = body_depth + 1}} -- | Change the graph id for variables graphed by this action. graphIdFor :: Id -> Grapher a -> Grapher a graphIdFor i = local $ \env -> let meta = envMeta env in env {envMeta = meta {metaGraphId = Just i}} -- | Capture body stats produced by the given action. 
captureBodyStats :: Grapher a -> Grapher BodyStats
captureBodyStats m = do
  stats <- gets stateStats
  modify $ \st -> st {stateStats = mempty}
  _ <- m
  stats' <- gets stateStats
  modify $ \st -> st {stateStats = stats <> stats'}
  pure stats'

-- | Can applications of this function be moved to device?
isHostOnlyFun :: Name -> Grapher Bool
isHostOnlyFun fn = asks $ S.member fn . envHostOnlyFuns

-- | Get the 'Meta' corresponding to the current body.
getMeta :: Grapher Meta
getMeta = asks envMeta

-- | Get the body depth of the current body (its nesting level).
getBodyDepth :: Grapher BodyDepth
getBodyDepth = asks (metaBodyDepth . envMeta)
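-- The sketch below is illustrative only and is not part of the analysis. It
-- summarises, using nothing beyond what this module already exports, how the
-- three query functions relate to 'statusOf': 'shouldMove' holds for both
-- 'MoveToDevice' and 'UsedOnHost', while 'usedOnHost' holds for both
-- 'StayOnHost' and 'UsedOnHost'. The helper 'describe' is hypothetical and
-- exists only for demonstration.
--
-- > describe :: VName -> MigrationTable -> (MigrationStatus, Bool, Bool)
-- > describe n mt = (statusOf n mt, shouldMove n mt, usedOnHost n mt)
-- >
-- > -- statusOf n mt == StayOnHost   ==> describe n mt == (StayOnHost,   False, True)
-- > -- statusOf n mt == UsedOnHost   ==> describe n mt == (UsedOnHost,   True,  True)
-- > -- statusOf n mt == MoveToDevice ==> describe n mt == (MoveToDevice, True,  False)
--
-- Names absent from the table default to 'StayOnHost', so the queries are safe
-- to use on variables the analysis never graphed.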