-- |
-- This module implements program analysis to determine which program statements
-- the "Futhark.Optimise.ReduceDeviceSyncs" pass should move into 'GPUBody' kernels
-- to reduce blocking memory transfers between host and device. The results of
-- the analysis are encoded into a 'MigrationTable', which can be queried.
--
-- To reduce blocking scalar reads the module constructs a data flow
-- dependency graph of program variables (see
-- "Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph") in which
-- it finds a minimum vertex cut that separates array reads of scalars
-- from transitive usage that cannot or should not be migrated to
-- device.
--
-- The variables of each partition are assigned a 'MigrationStatus' that states
-- whether the computation of those variables should be moved to device or
-- remain on host. Due to how the graph is built and the vertex cut is found,
-- all variables bound by a single statement will belong to the same partition.
--
-- The vertex cut contains all variables that will reside in device memory but
-- are required by host operations. These variables must be read from device
-- memory and cannot be reduced further in number merely by migrating
-- statements (subject to the accuracy of the graph model). The model is built
-- to reduce the worst-case number of scalar reads; an optimal migration of
-- statements depends on runtime data.
--
-- Blocking scalar writes are reduced by either turning such writes into
-- asynchronous kernels, as is done with scalar array literals and accumulator
-- updates, or by transforming host-device writing into device-device copying.
--
-- For details on how the graph is constructed and how the vertex cut is found,
-- see the master thesis "Reducing Synchronous GPU Memory Transfers" by Philip
-- Børgesen (2022).
module Futhark.Optimise.ReduceDeviceSyncs.MigrationTable
  ( -- * Analysis
    analyseProg,

    -- * Types
    MigrationTable,
    MigrationStatus (..),

    -- * Query

    -- | These functions all assume that no parent statement should be migrated.
    -- That is, @shouldMoveStm stm mt@ should return @False@ for every statement
    -- @stm@ with a body that a queried 'VName' or 'Stm' is nested within,
    -- otherwise the query result may be invalid.
    shouldMoveStm,
    shouldMove,
    usedOnHost,
    statusOf,
  )
where

import Control.Monad
import Control.Monad.Trans.Class
import qualified Control.Monad.Trans.Reader as R
import Control.Monad.Trans.State.Strict ()
import Control.Monad.Trans.State.Strict hiding (State)
import Control.Parallel.Strategies (parMap, rpar)
import Data.Bifunctor (first, second)
import Data.Foldable
import qualified Data.IntMap.Strict as IM
import qualified Data.IntSet as IS
import qualified Data.Map.Strict as M
import Data.Maybe (fromJust, fromMaybe, isJust, isNothing)
import qualified Data.Sequence as SQ
import Data.Set (Set, (\\))
import qualified Data.Set as S
import Futhark.Error
import Futhark.IR.GPU
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph
  ( EdgeType (..),
    Edges (..),
    Id,
    IdSet,
    Result (..),
    Routing (..),
    Vertex (..),
  )
import qualified Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph as MG

--------------------------------------------------------------------------------
--                              MIGRATION TABLES                              --
--------------------------------------------------------------------------------

-- | Where the value bound by a name should be computed.
--
-- The constructor order fixes the derived 'Ord' instance
-- (@MoveToDevice < UsedOnHost < StayOnHost@); keep it stable.
data MigrationStatus
  = -- | The statement that computes the value should be moved to device.
    -- No host usage of the value will be left after the migration.
    MoveToDevice
  | -- | As 'MoveToDevice' but host usage of the value will remain after
    -- migration.
    UsedOnHost
  | -- | The statement that computes the value should remain on host.
    StayOnHost
  deriving (Eq, Ord, Show)

-- | Identifies
--
--     (1) which statements should be moved from host to device to reduce the
--         worst case number of blocking memory transfers.
--
--     (2) which migrated variables will still be used on the host after
--         all such statements have been moved.
--
-- The map is keyed by the 'baseTag' of a variable name. A name with no entry
-- is treated as 'StayOnHost' by 'statusOf'.
newtype MigrationTable = MigrationTable (IM.IntMap MigrationStatus)

-- | Where should the value bound by this name be computed?
--
-- Names that the table has no entry for default to 'StayOnHost'.
statusOf :: VName -> MigrationTable -> MigrationStatus
statusOf n (MigrationTable mt) =
  IM.findWithDefault StayOnHost (baseTag n) mt

-- | Should this whole statement be moved from host to device?
--
-- The decision is made from the status of the first bound variable, or from
-- the variable that controls the statement (branch condition, loop bound or
-- loop condition).
shouldMoveStm :: Stm GPU -> MigrationTable -> Bool
shouldMoveStm (Let (Pat ((PatElem n _) : _)) _ (BasicOp (Index _ slice))) mt =
  -- An array read is moved when its result lives only on device, or when any
  -- of its index operands will be computed on device.
  statusOf n mt == MoveToDevice || any movedOperand slice
  where
    movedOperand (Var op) = statusOf op mt == MoveToDevice
    movedOperand _ = False
shouldMoveStm (Let (Pat ((PatElem n _) : _)) _ (BasicOp _)) mt =
  statusOf n mt /= StayOnHost
shouldMoveStm (Let (Pat ((PatElem n _) : _)) _ Apply {}) mt =
  statusOf n mt /= StayOnHost
shouldMoveStm (Let _ _ (If (Var n) _ _ _)) mt =
  statusOf n mt == MoveToDevice
shouldMoveStm (Let _ _ (DoLoop _ (ForLoop _ _ (Var n) _) _)) mt =
  statusOf n mt == MoveToDevice
shouldMoveStm (Let _ _ (DoLoop _ (WhileLoop n) _)) mt =
  statusOf n mt == MoveToDevice
-- BasicOp and Apply statements might not bind any variables (shouldn't happen).
-- If statements might use a constant branch condition.
-- For loop statements might use a constant number of iterations.
-- HostOp statements cannot execute on device.
-- WithAcc statements are never moved in their entirety.
shouldMoveStm _ _ = False

-- | Should the value bound by this name be computed on device?
--
-- True for both 'MoveToDevice' and 'UsedOnHost'.
shouldMove :: VName -> MigrationTable -> Bool
shouldMove n mt = statusOf n mt /= StayOnHost

-- | Will the value bound by this name be used on host?
--
-- True for both 'UsedOnHost' and 'StayOnHost'.
usedOnHost :: VName -> MigrationTable -> Bool
usedOnHost n mt = statusOf n mt /= MoveToDevice

-- | Merges two migration tables that are assumed to be disjoint.
--
-- If they are not disjoint, the entry from the first table wins
-- ('IM.union' is left-biased).
merge :: MigrationTable -> MigrationTable -> MigrationTable
merge (MigrationTable a) (MigrationTable b) = MigrationTable (a `IM.union` b)

--------------------------------------------------------------------------------
--                         HOST-ONLY FUNCTION ANALYSIS                        --
--------------------------------------------------------------------------------

-- | Identifies top-level function definitions that cannot be run on the
-- device. The application of any such function is host-only.
type HostOnlyFuns = Set Name

-- | Returns the names of all top-level functions that cannot be called from the
-- device. The evaluation of such a function is host-only.
hostOnlyFunDefs :: [FunDef GPU] -> HostOnlyFuns
hostOnlyFunDefs funs =
  let names = map funDefName funs
      -- 'Nothing' marks a function as host-only by itself; 'Just calls' lists
      -- the functions it applies.
      call_map = M.fromList $ zip names (map checkFunDef funs)
   in S.fromList names \\ M.keysSet (removeHostOnly call_map)
  where
    -- Repeatedly strips host-only entries from the call map. Every remaining
    -- function that calls a stripped one becomes host-only too. This repeats
    -- until no new host-only functions appear; the surviving keys are the
    -- functions that may be called from device.
    removeHostOnly ::
      M.Map Name (Maybe HostOnlyFuns) -> M.Map Name (Maybe HostOnlyFuns)
    removeHostOnly cm =
      let (host_only, cm') = M.partition isNothing cm
       in if M.null host_only
            then cm'
            else removeHostOnly $ M.map (checkCalls $ M.keysSet host_only) cm'

    -- A function that calls a host-only function is itself host-only.
    checkCalls :: HostOnlyFuns -> Maybe HostOnlyFuns -> Maybe HostOnlyFuns
    checkCalls hostOnlyFuns (Just calls)
      | hostOnlyFuns `S.disjoint` calls = Just calls
    checkCalls _ _ = Nothing

-- | 'checkFunDef' returns 'Nothing' if this function definition is host-only
-- by itself. That is the case if it:
--
--   * takes, returns or binds arrays,
--   * loops over arrays or has array loop parameters,
--   * indexes arrays, or
--   * contains 'WithAcc' or 'Op' (HostOp) expressions.
--
-- Otherwise it returns the names of all applied functions, which may include
-- user defined functions that could turn out to be host-only.
checkFunDef :: FunDef GPU -> Maybe (Set Name)
checkFunDef fun = do
  checkFParams (funDefParams fun)
  checkRetTypes (funDefRetType fun)
  checkBody (funDefBody fun)
  where
    hostOnly :: Maybe a
    hostOnly = Nothing

    ok :: Maybe ()
    ok = Just ()

    -- Host-only as soon as any element satisfies the array predicate.
    check :: Foldable t => (a -> Bool) -> t a -> Maybe ()
    check isArr as = if any isArr as then hostOnly else ok

    checkFParams = check isArray

    checkLParams = check (isArray . fst)

    checkRetTypes = check isArrayType

    checkPats = check isArray

    -- A for-loop with loop variables iterates over arrays.
    checkLoopForm (ForLoop _ _ _ (_ : _)) = hostOnly
    checkLoopForm _ = ok

    checkBody :: Body GPU -> Maybe (Set Name)
    checkBody = checkStms . bodyStms

    checkStms :: Stms GPU -> Maybe (Set Name)
    checkStms stms = S.unions <$> mapM checkStm stms

    checkStm :: Stm GPU -> Maybe (Set Name)
    checkStm (Let (Pat pats) _ e) = checkPats pats >> checkExp e

    -- Any expression that produces an array is caught by checkPats
    checkExp :: Exp GPU -> Maybe (Set Name)
    checkExp (BasicOp (Index _ _)) = hostOnly
    checkExp (WithAcc _ _) = hostOnly
    checkExp (Op _) = hostOnly
    checkExp (Apply fn _ _ _) = Just (S.singleton fn)
    checkExp (If _ tbranch fbranch _) =
      (<>) <$> checkBody tbranch <*> checkBody fbranch
    checkExp (DoLoop params lform body) = do
      checkLParams params
      checkLoopForm lform
      checkBody body
    checkExp _ = Just S.empty

--------------------------------------------------------------------------------
--                             MIGRATION ANALYSIS                             --
--------------------------------------------------------------------------------

-- | HostUsage identifies scalar variables that are used on host.
type HostUsage = [Id]

-- | The graph vertex identifier of a variable, which is its 'baseTag'.
nameToId :: VName -> Id
nameToId = baseTag

-- | Analyses a program to return a migration table that covers all its
-- statements and variables.
--
-- Top-level constants and each function definition are analysed separately.
-- The function definitions are analysed in parallel ('parMap' 'rpar'). The
-- resulting tables are then merged, with the constants' table first.
analyseProg :: Prog GPU -> MigrationTable
analyseProg (Prog consts funs) =
  let hof = hostOnlyFunDefs funs
      mt = analyseConsts hof funs consts
      mts = parMap rpar (analyseFunDef hof) funs
   in foldl' merge mt mts

-- | Analyses top-level constants.
--
-- A scalar constant that occurs free in any function definition is treated
-- as used on host.
analyseConsts :: HostOnlyFuns -> [FunDef GPU] -> Stms GPU -> MigrationTable
analyseConsts hof funs consts =
  -- Strict fold: the lazy 'M.foldlWithKey' would build a chain of thunks,
  -- one per bound name, before the usage list is first demanded.
  let usage = M.foldlWithKey' (f $ freeIn funs) [] (scopeOf consts)
   in analyseStms hof usage consts
  where
    f free usage n t
      | isScalar t,
        n `nameIn` free =
          nameToId n : usage
      | otherwise =
          usage

-- | Analyses a top-level function definition.
--
-- Scalar variables returned from the function body are recorded as host usage.
analyseFunDef :: HostOnlyFuns -> FunDef GPU -> MigrationTable
analyseFunDef hof fd =
  let body = funDefBody fd
      usage = foldl' f [] $ zip (bodyResult body) (funDefRetType fd)
      stms = bodyStms body
   in analyseStms hof usage stms
  where
    -- Only variable results of scalar return type contribute usage; constant
    -- results and array results are ignored.
    f usage (SubExpRes _ (Var n), t) | isScalarType t = nameToId n : usage
    f usage _ = usage

-- | Analyses statements. The 'HostUsage' list identifies which bound scalar
-- variables may subsequently be used on host. All free variables such as
-- constants and function parameters are assumed to reside on host.
analyseStms :: HostOnlyFuns -> HostUsage -> Stms GPU -> MigrationTable
analyseStms hof usage stms =
  let (g, srcs, _) = buildGraph hof usage stms
      (routed, unrouted) = srcs
      (_, g') = MG.routeMany unrouted g -- hereby routed
      -- Traverse the routed graph from a source vertex, sharing the visited
      -- set and the accumulated partition across all sources.
      f st' = MG.fold g' visit st' Normal
      st = foldl' f (initial, MG.none) unrouted
      (vr, vn, tn) = fst $ foldl' f st routed
   in -- TODO: Delay reads into (deeper) branches
      -- 'IM.unions' is left-biased, so a vertex found in both 'vr'/'vn' and
      -- 'tn' ends up as 'MoveToDevice'.
      MigrationTable $
        IM.unions
          [ IM.fromSet (const MoveToDevice) vr,
            IM.fromSet (const MoveToDevice) vn,
            -- Read by host if not reached by a reversed edge
            IM.fromSet (const UsedOnHost) tn
          ]
  where
    -- 1) Visited by reversed edge.
    -- 2) Visited by normal edge, no route.
    -- 3) Visited by normal edge, had route; will potentially be read by host.
    initial :: (IS.IntSet, IS.IntSet, IS.IntSet)
    initial = (IS.empty, IS.empty, IS.empty)

    visit ::
      (IS.IntSet, IS.IntSet, IS.IntSet) ->
      EdgeType ->
      Vertex m ->
      (IS.IntSet, IS.IntSet, IS.IntSet)
    visit (vr, vn, tn) Reversed v =
      (IS.insert (vertexId v) vr, vn, tn)
    visit (vr, vn, tn) Normal v@Vertex {vertexRouting = NoRoute} =
      (vr, IS.insert (vertexId v) vn, tn)
    visit (vr, vn, tn) Normal v =
      (vr, vn, IS.insert (vertexId v) tn)

--------------------------------------------------------------------------------
--                                TYPE HELPERS                                --
--------------------------------------------------------------------------------

-- | Whether a typed value is a non-unit primitive (see 'isScalarType').
isScalar :: Typed t => t -> Bool
isScalar x = isScalarType (typeOf x)

-- | A type is scalar when it is a primitive type other than 'Unit'.
-- Arrays and all other non-primitive types are not scalar.
isScalarType :: TypeBase shape u -> Bool
isScalarType ty =
  case ty of
    Prim pt -> pt /= Unit
    _ -> False

-- | Whether a typed value has an array type (see 'isArrayType').
isArray :: Typed t => t -> Bool
isArray x = isArrayType (typeOf x)

-- | A type is an array type when its rank is strictly positive.
isArrayType :: ArrayShape shape => TypeBase shape u -> Bool
isArrayType ty = arrayRank ty > 0

--------------------------------------------------------------------------------
--                               GRAPH BUILDING                               --
--------------------------------------------------------------------------------

-- | Build the migration graph for a sequence of statements.
--
-- The statements are graphed with 'execGrapher'. Afterwards every id in the
-- host usage list is connected to a sink in the resulting graph. The sources
-- and sinks collected while graphing are returned unchanged.
buildGraph :: HostOnlyFuns -> HostUsage -> Stms GPU -> (Graph, Sources, Sinks)
buildGraph hof usage stms = (connectUsage graph, srcs, sinks)
  where
    (graph, srcs, sinks) = execGrapher hof (graphStms stms)

    -- Fold the host usage ids into the graph, left to right.
    connectUsage g0 = foldl' (\acc i -> MG.connectToSink i acc) g0 usage

-- | Graph a body.
--
-- The body statements are graphed under 'incBodyDepthFor', and the variables
-- free in the body result are reported via 'tellOperands'. The statistics
-- captured for the body are then checked: if the body's depth (one more than
-- the current depth) is among its host-only parents, the enclosing statistics
-- are marked 'bodyHostOnly'. In every case that depth is removed from the
-- enclosing 'bodyHostOnlyParents'.
graphBody :: Body GPU -> Grapher ()
graphBody body = do
  let result_ops = namesIntSet (freeIn (bodyResult body))
  inner_stats <-
    captureBodyStats . incBodyDepthFor $ do
      graphStms (bodyStms body)
      tellOperands result_ops

  depth <- (+ 1) <$> getBodyDepth
  let host_only = depth `IS.member` bodyHostOnlyParents inner_stats
      -- If the body contains a variable that is required on host, the parent
      -- statement that contains this body cannot be migrated as a whole.
      mark stats
        | host_only = stats {bodyHostOnly = True}
        | otherwise = stats
      clear stats =
        stats {bodyHostOnlyParents = IS.delete depth (bodyHostOnlyParents stats)}
  modify $ \st -> st {stateStats = clear (mark (stateStats st))}

-- | Graph multiple statements, one after another in sequence order.
graphStms :: Stms GPU -> Grapher ()
graphStms stms = mapM_ graphStm stms

-- | Graph a single statement.
--
-- 'BasicOp' expressions are handled by the local 'graphBasicOp'. Compound
-- expressions are delegated to their dedicated graphing functions. A
-- 'GPUBody' is only reported via 'tellGPUBody', and any other 'Op' is
-- treated as host-only.
graphStm :: Stm GPU -> Grapher ()
graphStm stm =
  -- IMPORTANT! It is generally assumed that all scalars within types and
  -- shapes are present on host. Any expression of a type wherein one of its
  -- scalar operands appears must therefore ensure that that scalar operand is
  -- marked as a size variable (see the 'hostSize' function).
  case e of
    BasicOp op -> graphBasicOp op
    Apply fn _ _ _ -> graphApply fn bs e
    If cond tbody fbody _ -> graphIf bs cond tbody fbody
    DoLoop params lform body -> graphLoop bs params lform body
    WithAcc inputs f -> graphWithAcc bs inputs f
    Op GPUBody {} ->
      -- A GPUBody can be migrated into a parent GPUBody by replacing it with
      -- its body statements and binding its return values inside 'ArrayLit's.
      tellGPUBody
    Op _ -> graphHostOnly e
  where
    bs = boundBy stm
    e = stmExp stm

    -- The sole binding of a statement expected to bind exactly one value.
    -- Evaluated lazily, so only the equations that use it can raise the error.
    sole_binding = one bs

    one [x] = x
    one _ = compilerBugS "Type error: unexpected number of pattern elements."

    -- Equations are tried from top to bottom. When a guard fails, evaluation
    -- falls through to the next equation, so their order is significant.
    graphBasicOp (SubExp se) = do
      graphSimple bs e
      sole_binding `reusesSubExp` se
    graphBasicOp (Opaque _ se) = do
      graphSimple bs e
      sole_binding `reusesSubExp` se
    graphBasicOp (ArrayLit arr t)
      | isScalar t,
        any (isJust . subExpVar) arr =
          -- Migrating an array literal with free variables saves a write for
          -- each scalar it contains. Some backends generate asynchronous
          -- writes for scalar constants, but otherwise every write is
          -- synchronous. When all scalars are constants the compiler instead
          -- generates more efficient code that copies static device memory.
          graphAutoMove sole_binding
    graphBasicOp UnOp {} = graphSimple bs e
    graphBasicOp BinOp {} = graphSimple bs e
    graphBasicOp CmpOp {} = graphSimple bs e
    graphBasicOp ConvOp {} = graphSimple bs e
    graphBasicOp Assert {} =
      -- == OpenCL =============================================================
      --
      -- The next read after executing a kernel that contains an assertion is
      -- made asynchronous. It is followed by an asynchronous read that checks
      -- whether any assertion failed. The runtime then blocks until all
      -- enqueued operations have finished.
      --
      -- An assertion binds only a certificate of unit type, so it cannot
      -- increase the number of (read) synchronizations. In this regard it is
      -- free to migrate. The synchronization that does occur is, however,
      -- (presumably) more expensive, because the pipeline of GPU work is
      -- flushed.
      --
      -- This cost is hard to quantify and to amortize over assertion migration
      -- candidates, since it depends on the ordering of kernels and reads. We
      -- therefore assume it is insignificant. This will likely hold on a system
      -- where multiple threads or processes schedule GPU work: system-wide
      -- throughput only decreases if GPU utilization decreases as a result.
      --
      -- == CUDA ===============================================================
      --
      -- Under the CUDA backend every read is synchronous and is followed by a
      -- full synchronization that blocks until all enqueued operations have
      -- finished. If any enqueued kernel contained an assertion, another
      -- synchronous read is then made to check whether an assertion failed.
      --
      -- Migrating an assertion to save a read may therefore introduce new
      -- reads. Depending on the ordering of reads and of kernels that perform
      -- assertions, the total number of reads may decrease, stay the same, or
      -- even increase.
      --
      -- Since the OpenCL failure checking scheme could be implemented with
      -- asynchronous reads (and doing so would be a good idea!), we consider
      -- this acceptable.
      --
      -- TODO: Implement the OpenCL failure checking scheme under CUDA. This
      --       should reduce the number of synchronizations per read to one.
      graphSimple bs e
    graphBasicOp (Index _ slice)
      | isJust (sliceIndices slice) =
          -- A slice made only of fixed indices is graphed as a scalar read.
          graphRead sole_binding
    graphBasicOp _
      | [(_, t)] <- bs,
        let dims = arrayDims t,
        not (null dims), -- i.e. produces an array
        all (== intConst Int64 1) dims =
          -- An expression that produces an array containing a single
          -- primitive value is as efficient to compute and copy as a scalar,
          -- and it introduces no size variables.
          --
          -- This is an exception to the inefficiency rules that follow.
          graphSimple bs e
    -- Expressions whose cost is sublinear in the size of their result arrays
    -- are risky to migrate. We cannot guarantee that their results are not
    -- returned from a GPUBody, which always copies its return values. That
    -- would make the effective asymptotic cost of such statements linear, so
    -- we block them from being migrated on their own.
    --
    -- The parent statement of an enclosing body may still be migrated as a
    -- whole, provided that each of its returned arrays either
    --   1) is backed by memory used by a migratable statement within its body,
    --      or
    --   2) contains just a single element.
    -- An array meeting either criterion is called "copyable memory": the
    -- asymptotic cost of copying it is at most that of the statement that
    -- produced it. This makes the parent of statements with sublinear cost
    -- safe to migrate.
    graphBasicOp (Index arr s) = do
      graphInefficientReturn (sliceDims s)
      sole_binding `reuses` arr
    graphBasicOp (Update _ arr _ _) = do
      graphInefficientReturn []
      sole_binding `reuses` arr
    graphBasicOp (FlatIndex arr s) = do
      -- Migrating a FlatIndex leads to a memory allocation error.
      --
      -- TODO: Fix FlatIndex memory allocation error.
      --
      -- Can be replaced with 'graphHostOnly e' to disable migration.
      -- A fix can be verified by enabling tests/migration/reuse2_flatindex.fut
      graphInefficientReturn (flatSliceDims s)
      sole_binding `reuses` arr
    graphBasicOp (FlatUpdate arr _ _) = do
      graphInefficientReturn []
      sole_binding `reuses` arr
    graphBasicOp (Scratch _ s) =
      -- Migrating a Scratch leads to a memory allocation error.
      --
      -- TODO: Fix Scratch memory allocation error.
      --
      -- Can be replaced with 'graphHostOnly e' to disable migration.
      -- A fix can be verified by enabling tests/migration/reuse4_scratch.fut
      graphInefficientReturn s
    graphBasicOp (Reshape s arr) = do
      graphInefficientReturn (newDims s)
      sole_binding `reuses` arr
    graphBasicOp (Rearrange _ arr) = do
      graphInefficientReturn []
      sole_binding `reuses` arr
    graphBasicOp (Rotate _ arr) = do
      -- Migrating a Rotate leads to a memory allocation error.
      --
      -- TODO: Fix Rotate memory allocation error.
      --
      -- Can be replaced with 'graphHostOnly e' to disable migration.
      -- A fix can be verified by enabling tests/migration/reuse7_rotate.fut
      graphInefficientReturn []
      sole_binding `reuses` arr
    -- Expressions whose cost is linear in the size of their result arrays are
    -- inefficient to migrate into GPUBody kernels, because such kernels are
    -- single-threaded. For sufficiently large arrays the cost may exceed what
    -- is saved by avoiding reads. We therefore block these from being
    -- migrated, and their parents as well.
    graphBasicOp ArrayLit {} =
      -- An array literal made purely of primitive constants can be hoisted out
      -- as a top-level constant, unless it is returned or consumed. Otherwise
      -- its runtime implementation copies a precomputed static array and thus
      -- behaves like a 'Copy'. Whether its rows are primitive constants or
      -- arrays, an ArrayLit without scalar variable operands cannot directly
      -- prevent a scalar read.
      graphHostOnly e
    graphBasicOp Concat {} =
      -- Unlikely to prevent a scalar read, since in practice its only SubExp
      -- operand is a computation of host-only size variables.
      graphHostOnly e
    graphBasicOp Copy {} =
      -- Takes only an array operand, so it cannot directly prevent a scalar
      -- read.
      graphHostOnly e
    graphBasicOp Manifest {} =
      -- Takes no scalar operands, so it cannot directly prevent a scalar read.
      -- It is introduced by the BlkRegTiling kernel optimization and is thus
      -- unlikely to block the migration of a parent that was not already
      -- blocked by some host-only operation.
      graphHostOnly e
    graphBasicOp Iota {} = graphHostOnly e
    graphBasicOp Replicate {} = graphHostOnly e
    graphBasicOp UpdateAcc {} = graphUpdateAcc sole_binding e

    -- new_dims may introduce new size variables, which must be present on
    -- host when this expression is evaluated.
    graphInefficientReturn new_dims = do
      mapM_ hostSize new_dims
      ops <- graphedScalarOperands e
      addEdges ToSink ops

    hostSize (Var n) = requiredOnHost (nameToId n)
    hostSize _ = pure ()

-- | Bindings for all pattern elements bound by a statement.
--
-- Each pattern element becomes a pair of its name's id (via 'nameToId') and
-- its type. The order of the statement pattern is kept.
boundBy :: Stm GPU -> [Binding]
boundBy stm = [(nameToId n, t) | PatElem n t <- patElems (stmPat stm)]

-- | Graph a statement which in itself neither reads scalars from device memory
-- nor forces such scalars to be available on host. Such a statement can be
-- moved to device to eliminate the host usage of its operands, which may
-- transitively depend on a scalar device read.
graphSimple :: [Binding] -> Exp GPU -> Grapher ()
graphSimple bs e = do
  -- Vertices are added only when the statement has a transitive dependency
  -- on an array read. Transitive dependencies through variables connected to
  -- sinks do not count.
  ops <- graphedScalarOperands e
  unless (IS.null ops) $ do
    mapM_ addVertex bs
    addEdges (MG.declareEdges (map fst bs)) ops

-- | Graph a statement that reads a scalar from device memory.
--
-- The binding is registered as a source and the read is reported via
-- 'tellRead'. Operands are irrelevant, because the source blocks any route
-- through the binding.
graphRead :: Binding -> Grapher ()
graphRead b = addSource b *> tellRead

-- | Graph a statement that should always be moved to device.
--
-- The binding is registered as a source. Its operands are ignored, because
-- the source blocks any route through the binding.
graphAutoMove :: Binding -> Grapher ()
graphAutoMove b = addSource b

-- | Graph a statement that is unfit for execution in a GPUBody and must
-- therefore be executed on host, which requires all of its operands to be
-- available there. Parent statements of enclosing bodies are also blocked
-- from being migrated.
graphHostOnly :: Exp GPU -> Grapher ()
graphHostOnly e =
  -- Connecting every graphed scalar operand to a sink marks it as required
  -- on host. Transitive reads those operands depend on cannot be delayed any
  -- further, and 'tellHostOnly' prevents any parent statement from migrating.
  (graphedScalarOperands e >>= addEdges ToSink) >> tellHostOnly

-- | Graph an 'UpdateAcc' statement.
--
-- Nothing is added to the graph here. Instead, the statement is prepended to
-- the list stored in 'stateUpdateAccs' under the id of the updated
-- accumulator. Calls 'compilerBugS' if the bound value is not of accumulator
-- type.
graphUpdateAcc :: Binding -> Exp GPU -> Grapher ()
graphUpdateAcc b e =
  case snd b of
    Acc a _ _ _ ->
      -- The actual graphing is delayed to the corresponding 'WithAcc' parent.
      modify $ \st ->
        st {stateUpdateAccs = IM.alter record (nameToId a) (stateUpdateAccs st)}
    _ ->
      compilerBugS
        "Type error: UpdateAcc did not produce accumulator typed value."
  where
    -- Start a new list, or prepend to the updates delayed so far.
    record = Just . maybe [(b, e)] ((b, e) :)

-- | Graph a function application.
--
-- Calls to functions reported as host-only by 'isHostOnlyFun' are graphed
-- with 'graphHostOnly'. All other calls are graphed with 'graphSimple'.
graphApply :: Name -> [Binding] -> Exp GPU -> Grapher ()
graphApply fn bs e = isHostOnlyFun fn >>= choose
  where
    choose True = graphHostOnly e
    choose False = graphSimple bs e

-- | Graph an if statement.
--
-- Both branches are graphed under 'incForkDepthFor'. The statement may
-- migrate as a whole only if neither branch is host-only and
-- 'reusesBranches' reports that the returned arrays may be copied. Otherwise
-- a variable condition is connected to a sink. Each bound variable becomes a
-- node whose operands are the graphed scalars returned by either branch,
-- plus, when migration is possible, those found for the condition.
graphIf :: [Binding] -> SubExp -> Body GPU -> Body GPU -> Grapher ()
graphIf bs cond tbody fbody = do
  branch_host_only <- incForkDepthFor $ do
    t_stats <- captureBodyStats (graphBody tbody)
    f_stats <- captureBodyStats (graphBody fbody)
    pure (bodyHostOnly t_stats || bodyHostOnly f_stats)

  -- Record aliases for copyable memory backing returned arrays.
  copyable <- reusesBranches bs (resultSubExps tbody) (resultSubExps fbody)
  let may_migrate = not branch_host_only && copyable

  cond_ops <- case cond of
    Var n
      | may_migrate -> onlyGraphedScalar n
      | otherwise -> do
          -- The migration status of the condition is what determines whether
          -- the statement may be migrated as a whole. See 'shouldMoveStm'.
          connectToSink (nameToId n)
          pure IS.empty
    _ -> pure IS.empty

  tellOperands cond_ops

  -- Connect branch results to the bound variables so that reads can be
  -- delayed out of the branches. It might also pay off to move the whole
  -- statement to device to avoid reading the branch condition. That must be
  -- balanced against the need to read the values bound by the if statement.
  --
  -- Because the branch condition is connected to every variable bound by the
  -- statement, the condition only stays on device if
  --
  --   (1) the if statement is not required on host, based on the statements
  --       within its branches, and
  --
  --   (2) migrating the whole statement would require no additional reads to
  --       use the variables it binds.
  --
  -- If the condition is migrated to device and stays there, then the if
  -- statement must necessarily execute on device.
  --
  -- The graph model generally migrates no more statements than needed to
  -- obtain a minimum vertex cut, but if statements are modelled imprecisely.
  -- Specifically, the model cannot capture that the branches are mutually
  -- exclusive, so it encodes that both branches are taken. This does not
  -- change the resulting number of host-device reads, but some reads may be
  -- needlessly delayed out of branches. The overhead measured on
  -- futhark-benchmarks appears to be negligible though.
  node_ops <-
    zipWithM (branchOperands cond_ops) (bodyResult tbody) (bodyResult fbody)
  sequence_ (zipWith createNode bs node_ops)
  where
    resultSubExps = map resSubExp . bodyResult

    -- Graphed scalars among the two corresponding branch results, joined
    -- with the given extra operand ids.
    branchOperands extra r1 r2 =
      (extra <>) <$> onlyGraphedScalars (varSet r1 <> varSet r2)

    varSet res = case resSubExp res of
      Var n -> S.singleton n
      _ -> S.empty

-----------------------------------------------------
-- These type aliases are only used by 'graphLoop' --
-----------------------------------------------------
-- | Ids of loop-bound variables (see @bindings@ in 'graphLoop') that were
-- found to be reachable from the vertex a search started at.
type ReachableBindings = IdSet

-- | Traversal state threaded through successive 'MG.reduce' calls made by
-- 'graphLoop', so that later searches can reuse the results of earlier ones.
type ReachableBindingsCache = MG.Visited (MG.Result ReachableBindings)

-- | Edge targets of an operand that were not searched ("exhausted") because
-- they are not graphed vertices of the loop subgraph.
type NonExhausted = [Id]

-- | A scalar loop parameter, as
-- @(binding of the corresponding loop result, parameter id, initial argument,
-- value returned by the loop body for the next iteration)@.
type LoopValue = (Binding, Id, SubExp, SubExp)

-----------------------------------------------------
-----------------------------------------------------

-- | Graph a loop statement.
--
-- The loop parameters and body are graphed as an isolated subgraph, and the
-- sources within it are routed before the variables bound by the statement
-- are graphed. Bound variables that can be reached from a device read within
-- the loop are then added to the sources. Each operand of the body is either
-- connected to a sink or redirected to the loop-bound variables it can reach.
--
-- Arguments:
--
-- * the variables bound by the statement (at least one is expected);
-- * the loop parameters paired with their initial arguments;
-- * the loop form;
-- * the loop body.
graphLoop ::
  [Binding] ->
  [(FParam GPU, SubExp)] ->
  LoopForm GPU ->
  Body GPU ->
  Grapher ()
graphLoop [] _ _ _ =
  -- We expect each loop to bind a value or be eliminated.
  compilerBugS "Loop statement bound no variable; should have been eliminated."
graphLoop (b : bs) params lform body = do
  -- Graph loop params and body while capturing statistics.
  g0 <- getGraph
  stats <- captureBodyStats (subgraphId `graphIdFor` graphTheLoop)

  -- Record aliases for copyable memory backing returned arrays.
  -- Does the loop return any arrays which prevent it from being migrated?
  let args = map snd params
  let results = map resSubExp (bodyResult body)
  may_copy_results <- reusesBranches (b : bs) args results

  -- Connect loop condition to a sink if the loop cannot be migrated.
  -- The migration status of the condition is what determines whether the
  -- loop may be migrated as a whole or not. See 'shouldMoveStm'.
  let may_migrate = not (bodyHostOnly stats) && may_copy_results
  unless may_migrate $ case lform of
    ForLoop _ _ (Var n) _ -> connectToSink (nameToId n)
    WhileLoop n -> connectToSink (nameToId n)
    _ -> pure ()

  -- Connect graphed return values to their loop parameters.
  g1 <- getGraph
  mapM_ (mergeLoopParam g1) loopValues

  -- Route the sources within the loop body in isolation.
  -- The loop graph must not be altered after this point.
  srcs <- routeSubgraph subgraphId

  -- Graph the variables bound by the statement.
  forM_ loopValues $ \(bnd, p, _, _) -> createNode bnd (IS.singleton p)

  -- If a device read is delayed from one iteration to the next the
  -- corresponding variables bound by the statement must be treated as
  -- sources.
  g2 <- getGraph
  let (dbs, rbc) = foldl' (deviceBindings g2) (IS.empty, MG.none) srcs
  modifySources $ second (IS.toList dbs <>)

  -- Connect operands to sinks if they can reach a sink within the loop.
  -- Otherwise connect them to the loop bound variables that they can
  -- reach and exhaust their normal entry edges into the loop.
  -- This means a read can be delayed through a loop but not into it if
  -- that would increase the number of reads done by any given iteration.
  let ops = IS.filter (`MG.member` g0) (bodyOperands stats)
  foldM_ connectOperand rbc (IS.elems ops)

  -- It might be beneficial to move the whole loop to device, to avoid
  -- reading the (initial) loop condition value. This must be balanced
  -- against the need to read the values bound by the loop statement.
  --
  -- For more details see the similar description for if statements.
  when may_migrate $ case lform of
    ForLoop _ _ n _ ->
      onlyGraphedScalarSubExp n >>= addEdges (ToNodes bindings Nothing)
    WhileLoop n ->
      -- The initial value of the condition is the argument of the merge
      -- parameter that the condition names.
      let (_, _, arg, _) = loopValueFor n
       in onlyGraphedScalarSubExp arg >>= addEdges (ToNodes bindings Nothing)
  where
    -- Id of the subgraph that the loop parameters and body are graphed into;
    -- it is the id of the first variable bound by the statement.
    subgraphId :: Id
    subgraphId = fst b

    -- The scalar loop parameters, each paired with the statement binding
    -- of the same position, its initial argument, and its body result.
    loopValues :: [LoopValue]
    loopValues =
      let triples = zip3 (b : bs) params (bodyResult body)
          toLoopValue (bnd, (p, arg), res) =
            (bnd, nameToId (paramName p), arg, resSubExp res)
       in filter (\((_, t), _, _, _) -> isScalar t) (map toLoopValue triples)

    -- Ids of the scalar variables bound by the statement.
    bindings :: IdSet
    bindings = IS.fromList $ map (\((i, _), _, _, _) -> i) loopValues

    -- The scalar loop value whose parameter is the given variable.
    -- Fails with a compiler bug rather than an anonymous 'fromJust' error
    -- if no such parameter exists.
    loopValueFor :: VName -> LoopValue
    loopValueFor n =
      fromMaybe
        (compilerBugS "Loop condition is not a scalar loop parameter.")
        (find (\(_, p, _, _) -> p == nameToId n) loopValues)

    -- Graph the loop parameters, the loop form and the loop body.
    graphTheLoop :: Grapher ()
    graphTheLoop = do
      mapM_ graphParam loopValues

      -- For simplicity we do not currently track memory reuse through merge
      -- parameters. A parameter does not simply reuse the memory of its
      -- argument; it must also consider the iteration return value, which in
      -- turn may depend on other merge parameters.
      --
      -- Situations that would benefit from this tracking is unlikely to occur
      -- at the time of writing, and if it occurs current compiler limitations
      -- will prevent successful compilation.
      -- Specifically it requires the merge parameter argument to reuse memory
      -- from an array literal, and both it and the loop must occur within an
      -- if statement branch. Array literals are generally hoisted out of if
      -- statements however, and when they are not, a memory allocation error
      -- occurs.
      --
      -- TODO: Track memory reuse through merge parameters.

      case lform of
        ForLoop _ _ n elems -> do
          onlyGraphedScalarSubExp n >>= tellOperands
          mapM_ graphForInElem elems
        WhileLoop _ -> pure ()
      graphBody body
      where
        -- A scalar for-in element is a source; an array element reuses the
        -- memory of the array it is drawn from.
        graphForInElem (p, arr) = do
          when (isScalar p) $ addSource (nameToId $ paramName p, typeOf p)
          when (isArray p) $ (nameToId (paramName p), typeOf p) `reuses` arr

        graphParam :: LoopValue -> Grapher ()
        graphParam ((_, t), p, arg, _) =
          do
            -- It is unknown whether a read can be delayed via the parameter
            -- from one iteration to the next, so we have to create a vertex
            -- even if the initial value never depends on a read.
            addVertex (p, t)
            ops <- onlyGraphedScalarSubExp arg
            addEdges (MG.oneEdge p) ops

    -- Connect the variable returned by an iteration to its loop parameter,
    -- or to a sink if the parameter already is sink-connected in the given
    -- graph. Constants and parameters returned unchanged are skipped.
    mergeLoopParam :: Graph -> LoopValue -> Grapher ()
    mergeLoopParam g (_, p, _, res)
      | Var n <- res,
        ret <- nameToId n,
        ret /= p =
          if MG.isSinkConnected p g
            then connectToSink ret
            else addEdges (MG.oneEdge p) (IS.singleton ret)
      | otherwise =
          pure ()

    -- Accumulate the loop-bound variables reachable from a routed source.
    deviceBindings ::
      Graph ->
      (ReachableBindings, ReachableBindingsCache) ->
      Id ->
      (ReachableBindings, ReachableBindingsCache)
    deviceBindings g (rb, rbc) i =
      let (r, rbc') = MG.reduce g bindingReach rbc Normal i
       in case r of
            Produced rb' -> (rb <> rb', rbc')
            _ ->
              compilerBugS
                "Migration graph sink could be reached from source after it\
                \ had been attempted routed."

    -- Reduction step that records every visited loop-bound variable.
    bindingReach ::
      ReachableBindings ->
      EdgeType ->
      Vertex Meta ->
      ReachableBindings
    bindingReach rb _ v
      | i <- vertexId v,
        IS.member i bindings =
          IS.insert i rb
      | otherwise =
          rb

    -- Connect a graphed body operand either to a sink (if it can reach one
    -- within the loop subgraph) or to the loop-bound variables it reaches.
    connectOperand ::
      ReachableBindingsCache ->
      Id ->
      Grapher ReachableBindingsCache
    connectOperand cache op = do
      g <- getGraph
      case MG.lookup op g of
        Nothing -> pure cache
        Just v ->
          case vertexEdges v of
            ToSink -> pure cache
            ToNodes es Nothing -> connectOp g cache op es
            ToNodes _ (Just nx) -> connectOp g cache op nx
      where
        connectOp ::
          Graph ->
          ReachableBindingsCache ->
          Id -> -- operand id
          IdSet -> -- its edges
          Grapher ReachableBindingsCache
        connectOp g rbc i es = do
          let (res, nx, rbc') = findBindings g (IS.empty, [], rbc) (IS.elems es)
          case res of
            FoundSink -> connectToSink i
            Produced rb -> modifyGraph $ MG.adjust (updateEdges nx rb) i
          pure rbc'

        -- Add edges to the reached bindings and record which edges remain
        -- non-exhausted.
        updateEdges ::
          NonExhausted ->
          ReachableBindings ->
          Vertex Meta ->
          Vertex Meta
        updateEdges nx rb v
          | ToNodes es _ <- vertexEdges v =
              let nx' = IS.fromList nx
                  es' = ToNodes (rb <> es) $ Just (rb <> nx')
               in v {vertexEdges = es'}
          | otherwise = v

        findBindings ::
          Graph ->
          (ReachableBindings, NonExhausted, ReachableBindingsCache) ->
          [Id] -> -- current non-exhausted edges
          (MG.Result ReachableBindings, NonExhausted, ReachableBindingsCache)
        findBindings _ (rb, nx, rbc) [] =
          (Produced rb, nx, rbc)
        findBindings g (rb, nx, rbc) (i : is)
          | Just v <- MG.lookup i g,
            Just gid <- metaGraphId (vertexMeta v),
            gid == subgraphId -- only search the subgraph
            =
              let (res, rbc') = MG.reduce g bindingReach rbc Normal i
               in case res of
                    FoundSink -> (FoundSink, [], rbc')
                    Produced rb' -> findBindings g (rb <> rb', nx, rbc') is
          | otherwise =
              -- don't exhaust
              findBindings g (rb, i : nx, rbc) is

-- | Graph a 'WithAcc' statement.
--
-- The lambda body is graphed first, which collects the 'UpdateAcc'
-- statements of each accumulator for delayed graphing. Each accumulator is
-- then graphed together with its operator and its collected updates. The
-- statement's bindings are then recorded as reusing the input arrays and
-- extra lambda results. Finally, the bindings for the extra (non-accumulator)
-- lambda results are graphed.
--
-- Arguments are the variables bound by the statement, the accumulator
-- inputs, and the lambda that receives the accumulators.
graphWithAcc ::
  [Binding] ->
  [WithAccInput GPU] ->
  Lambda GPU ->
  Grapher ()
graphWithAcc bs inputs f = do
  -- Graph the body, capturing 'UpdateAcc' statements for delayed graphing.
  graphBody lamBody

  -- Graph each accumulator monoid and its associated 'UpdateAcc' statements.
  forM_ (zip (lambdaReturnType f) inputs) graphAccInput

  -- Record aliases for the backing memory of each returned array.
  -- 'WithAcc' statements are never migrated as a whole and always return
  -- arrays backed by memory allocated elsewhere.
  _ <- reusesReturn bs (inputArrays ++ extraSubExps)

  -- Connect return variables to bound values. No outgoing edge exists
  -- from an accumulator vertex so skip those. Note that accumulators do
  -- not map to returned arrays one-to-one but one-to-many.
  retOps <- traverse onlyGraphedScalarSubExp extraSubExps
  mapM_ (uncurry createNode) (zip extraBindings retOps)
  where
    lamBody = lambdaBody f

    -- Every array supplied to an accumulator, in input order.
    inputArrays = [Var arr | (_, arrs, _) <- inputs, arr <- arrs]

    -- Lambda results that follow the accumulators themselves.
    extraSubExps = map resSubExp $ drop (length inputs) (bodyResult lamBody)

    -- Statement bindings that follow those for the input arrays.
    extraBindings = drop (length inputArrays) bs

    -- Graph one accumulator together with its operator and the 'UpdateAcc'
    -- statements that were collected for it while graphing the body.
    graphAccInput :: (Type, WithAccInput GPU) -> Grapher ()
    graphAccInput (Acc accName _ elemTypes _, (_, _, comb)) = do
      let i = nameToId accName
      delayed <- gets (fromMaybe [] . IM.lookup i . stateUpdateAccs)
      modify $ \st -> st {stateUpdateAccs = IM.delete i (stateUpdateAccs st)}
      graphAcc i elemTypes (fst <$> comb) delayed

      -- Neutral elements must always be made available on host for 'WithAcc'
      -- to type check.
      mapM_ connectSubExpToSink (maybe [] snd comb)
    graphAccInput _ =
      compilerBugS "Type error: WithAcc expression did not return accumulator."

-- Graph the operator and all 'UpdateAcc' statements associated with an
-- accumulator.
--
-- The arguments are the 'Id' for the accumulator token, the element types of
-- the accumulator/operator, its combining function if any, and all associated
-- 'UpdateAcc' statements outside kernels.
graphAcc :: Id -> [Type] -> Maybe (Lambda GPU) -> [Delayed] -> Grapher ()
graphAcc :: Int -> [Type] -> Maybe (Lambda GPU) -> [Delayed] -> Grapher ()
graphAcc Int
i [Type]
_ Maybe (Lambda GPU)
_ [] = Binding -> Grapher ()
addSource (Int
i, PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Unit) -- Only used on device.
graphAcc Int
i [Type]
types Maybe (Lambda GPU)
op [Delayed]
delayed = do
  -- Accumulators are intended for use within SegOps but in principle the AST
  -- allows their 'UpdateAcc's to be used outside a kernel. This case handles
  -- that unlikely situation.

  Env
env <- Grapher Env
ask
  State
st <- StateT State (Reader Env) State
forall (m :: * -> *) s. Monad m => StateT s m s
get

  -- Collect statistics about the operator statements.
  let lambda :: Lambda GPU
lambda = Lambda GPU -> Maybe (Lambda GPU) -> Lambda GPU
forall a. a -> Maybe a -> a
fromMaybe ([LParam GPU] -> Body GPU -> [Type] -> Lambda GPU
forall rep. [LParam rep] -> Body rep -> [Type] -> Lambda rep
Lambda [] (BodyDec GPU -> Stms GPU -> [SubExpRes] -> Body GPU
forall rep. BodyDec rep -> Stms rep -> [SubExpRes] -> Body rep
Body () Stms GPU
forall a. Seq a
SQ.empty []) []) Maybe (Lambda GPU)
op
  let m :: Grapher ()
m = Body GPU -> Grapher ()
graphBody (Lambda GPU -> Body GPU
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPU
lambda)
  let stats :: BodyStats
stats = Reader Env BodyStats -> Env -> BodyStats
forall r a. Reader r a -> r -> a
R.runReader (Grapher BodyStats -> State -> Reader Env BodyStats
forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m a
evalStateT (Grapher () -> Grapher BodyStats
forall a. Grapher a -> Grapher BodyStats
captureBodyStats Grapher ()
m) State
st) Env
env
  -- We treat GPUBody kernels as host-only to not bother rewriting them inside
  -- operators and to simplify the analysis. They are unlikely to occur anyway.
  --
  -- NOTE: Performance may degrade if a GPUBody is replaced with its contents
  --       but the containing operator is used on host.
  let host_only :: Bool
host_only = BodyStats -> Bool
bodyHostOnly BodyStats
stats Bool -> Bool -> Bool
|| BodyStats -> Bool
bodyHasGPUBody BodyStats
stats

  -- op operands are read from arrays and written back so if any of the operands
  -- are scalar then a read can be avoided by moving the UpdateAcc usages to
  -- device. If the op itself performs scalar reads its UpdateAcc usages should
  -- also be moved.
  let does_read :: Bool
does_read = BodyStats -> Bool
bodyReads BodyStats
stats Bool -> Bool -> Bool
|| (Type -> Bool) -> [Type] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any Type -> Bool
forall t. Typed t => t -> Bool
isScalar [Type]
types

  -- Determine which external variables the operator depends upon.
  -- 'bodyOperands' cannot be used as it might exclude operands that were
  -- connected to sinks within the body, so instead we create an artifical
  -- expression to capture graphed operands from.
  IntSet
ops <- Exp GPU -> Grapher IntSet
graphedScalarOperands ([WithAccInput GPU] -> Lambda GPU -> Exp GPU
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [] Lambda GPU
lambda)

  case (Bool
host_only, Bool
does_read) of
    (Bool
True, Bool
_) -> do
      -- If the operator cannot run well in a GPUBody then all non-kernel
      -- UpdateAcc statements are host-only. The current analysis is ignorant
      -- of what happens inside kernels so we must assume that the operator
      -- is used within a kernel, meaning that we cannot migrate its statements.
      --
      -- TODO: Improve analysis if UpdateAcc ever is used outside kernels.
      (Delayed -> Grapher ()) -> [Delayed] -> Grapher ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Exp GPU -> Grapher ()
graphHostOnly (Exp GPU -> Grapher ())
-> (Delayed -> Exp GPU) -> Delayed -> Grapher ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Delayed -> Exp GPU
forall a b. (a, b) -> b
snd) [Delayed]
delayed
      Edges -> IntSet -> Grapher ()
addEdges Edges
ToSink IntSet
ops
    (Bool
_, Bool
True) -> do
      -- Migrate all accumulator usage to device to avoid reads and writes.
      (Delayed -> Grapher ()) -> [Delayed] -> Grapher ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Binding -> Grapher ()
graphAutoMove (Binding -> Grapher ())
-> (Delayed -> Binding) -> Delayed -> Grapher ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Delayed -> Binding
forall a b. (a, b) -> a
fst) [Delayed]
delayed
      Binding -> Grapher ()
addSource (Int
i, PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Unit)
    (Bool, Bool)
_ -> do
      -- Only migrate operator and UpdateAcc statements if it can allow their
      -- operands to be migrated.
      Binding -> IntSet -> Grapher ()
createNode (Int
i, PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Unit) IntSet
ops
      [Delayed] -> (Delayed -> Grapher ()) -> Grapher ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Delayed]
delayed ((Delayed -> Grapher ()) -> Grapher ())
-> (Delayed -> Grapher ()) -> Grapher ()
forall a b. (a -> b) -> a -> b
$
        \(Binding
b, Exp GPU
e) -> Exp GPU -> Grapher IntSet
graphedScalarOperands Exp GPU
e Grapher IntSet -> (IntSet -> Grapher ()) -> Grapher ()
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Binding -> IntSet -> Grapher ()
createNode Binding
b (IntSet -> Grapher ())
-> (IntSet -> IntSet) -> IntSet -> Grapher ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> IntSet -> IntSet
IS.insert Int
i

-- | Returns for an expression all scalar operands that must be made available
-- on host to execute the expression there.
--
-- Candidate operands are first collected syntactically from the expression.
-- The result is then restricted to the scalars that currently are graphed, as
-- reported by 'getGraphedScalars'.
graphedScalarOperands :: Exp GPU -> Grapher Operands
graphedScalarOperands e =
  let is = fst $ execState (collect e) initial
   in IS.intersection is <$> getGraphedScalars
  where
    initial = (IS.empty, S.empty) -- scalar operands, accumulator tokens

    -- Record a variable as an operand candidate.
    captureName n = modify $ first $ IS.insert (nameToId n)
    -- Record the token of an accumulator updated by a host UpdateAcc.
    captureAcc a = modify $ second $ S.insert a
    -- Record every free variable of a term as an operand candidate.
    collectFree x = mapM_ captureName (namesToList $ freeIn x)

    collect b@BasicOp {} =
      collectBasic b
    collect (Apply _ params _ _) =
      mapM_ (collectSubExp . fst) params
    collect (If cond tbranch fbranch _) =
      collectSubExp cond >> collectBody tbranch >> collectBody fbranch
    collect (DoLoop params lform body) = do
      mapM_ (collectSubExp . snd) params
      collectLForm lform
      collectBody body
    collect (WithAcc accs f) =
      collectWithAcc accs f
    collect (Op op) =
      collectHostOp op

    collectBasic (BasicOp (Update _ _ slice _)) =
      -- Writing a scalar to an array can be replaced with copying a single-
      -- element slice. If the scalar originates from device memory its read
      -- can thus be prevented without requiring the 'Update' to be migrated.
      collectFree slice
    collectBasic (BasicOp (Replicate shape _)) =
      -- The replicate of a scalar can be rewritten as a replicate of a single
      -- element array followed by a slice index.
      collectFree shape
    collectBasic e' =
      -- Note: Plain VName values only refer to arrays.
      walkExpM (identityWalker {walkOnSubExp = collectSubExp}) e'

    collectSubExp (Var n) = captureName n
    collectSubExp _ = pure ()

    collectBody body = do
      collectStms (bodyStms body)
      collectFree (bodyResult body)
    collectStms = mapM_ collectStm

    collectStm (Let pat _ ua)
      | BasicOp UpdateAcc {} <- ua,
        Pat [pe] <- pat,
        Acc a _ _ _ <- typeOf pe =
          -- Capture the tokens of accumulators used on host.
          captureAcc a >> collectBasic ua
    collectStm stm = collect (stmExp stm)

    collectLForm (ForLoop _ _ b _) = collectSubExp b
    -- WhileLoop condition is declared as a loop parameter.
    collectLForm (WhileLoop _) = pure ()

    -- The collective operands of an operator lambda body are only used on host
    -- if the associated accumulator is used in an UpdateAcc statement outside a
    -- kernel.
    collectWithAcc inputs f = do
      collectBody (lambdaBody f)
      used_accs <- gets snd
      let accs = take (length inputs) (lambdaReturnType f)
      let used = map (isUsedAcc used_accs) accs
      mapM_ collectAcc (zip used inputs)

    -- Whether the accumulator of the given type had its token captured by a
    -- host UpdateAcc. The leading results of a WithAcc lambda are expected to
    -- be accumulators. If some other type ever appears there, the operator is
    -- treated as used, so its operands are collected, rather than failing on
    -- an incomplete pattern match.
    isUsedAcc used_accs (Acc a _ _ _) = S.member a used_accs
    isUsedAcc _ _ = True

    collectAcc (_, (_, _, Nothing)) = pure ()
    collectAcc (used, (_, _, Just (op, nes))) = do
      mapM_ collectSubExp nes
      when used $ collectBody (lambdaBody op)

    -- Does not collect named operands in
    --
    --   * types and shapes; size variables are assumed available to the host.
    --
    --   * use by a kernel body.
    --
    -- All other operands are conservatively collected even if they generally
    -- appear to be size variables or results computed by a SizeOp.
    collectHostOp (SegOp (SegMap lvl sp _ _)) = do
      collectSegLevel lvl
      collectSegSpace sp
    collectHostOp (SegOp (SegRed lvl sp ops _ _)) = do
      collectSegLevel lvl
      collectSegSpace sp
      mapM_ collectSegBinOp ops
    collectHostOp (SegOp (SegScan lvl sp ops _ _)) = do
      collectSegLevel lvl
      collectSegSpace sp
      mapM_ collectSegBinOp ops
    collectHostOp (SegOp (SegHist lvl sp ops _ _)) = do
      collectSegLevel lvl
      collectSegSpace sp
      mapM_ collectHistOp ops
    collectHostOp (SizeOp op) = collectFree op
    collectHostOp (OtherOp op) = collectFree op
    collectHostOp GPUBody {} = pure ()

    collectSegLevel (SegThread (Count num) (Count size) _) =
      collectSubExp num >> collectSubExp size
    collectSegLevel (SegGroup (Count num) (Count size) _) =
      collectSubExp num >> collectSubExp size

    collectSegSpace space =
      mapM_ collectSubExp (segSpaceDims space)

    collectSegBinOp (SegBinOp _ _ nes _) =
      mapM_ collectSubExp nes

    collectHistOp (HistOp _ rf _ nes _ _) = do
      collectSubExp rf
      mapM_ collectSubExp nes

--------------------------------------------------------------------------------
--                        GRAPH BUILDING - PRIMITIVES                         --
--------------------------------------------------------------------------------

-- | Creates a vertex for the given binding, provided that the set of operands
-- is not empty. The vertex is then added as a one-way edge target of every
-- operand via 'addEdges'. With no operands nothing is recorded.
createNode :: Binding -> Operands -> Grapher ()
createNode b@(i, _) ops
  | IS.null ops = pure ()
  | otherwise = do
      addVertex b
      addEdges (MG.oneEdge i) ops

-- | Adds a vertex to the graph for the given binding.
--
-- The vertex carries the current graphing metadata. A scalar binding is also
-- recorded as a graphed scalar. An array binding has its memory recorded as
-- copyable at the current body depth.
addVertex :: Binding -> Grapher ()
addVertex (i, t) = do
  meta <- getMeta
  let depth = metaBodyDepth meta
      vertex = MG.vertex i meta
  when (isScalar t) $
    modifyGraphedScalars (IS.insert i)
  when (isArray t) $
    recordCopyableMemory i depth
  modifyGraph (MG.insert vertex)

-- | Adds a source connected vertex to the graph for the given binding.
-- The id is prepended to the list of unrouted sources.
addSource :: Binding -> Grapher ()
addSource b@(i, _) =
  addVertex b >> modifySources (second (i :))

-- | Adds the given edges to each vertex identified by the 'IdSet'. It is
-- assumed that all vertices reside within the body that currently is being
-- graphed.
--
-- For 'ToSink', every vertex is connected to a sink and removed from the set
-- of graphed scalars. For any other edges, they are added to every vertex and
-- the ids are reported as operands through 'tellOperands'.
addEdges :: Edges -> IdSet -> Grapher ()
addEdges edges is = case edges of
  ToSink -> do
    modifyGraph $ \g -> IS.foldl' (\acc i -> MG.connectToSink i acc) g is
    modifyGraphedScalars $ \scalars -> IS.difference scalars is
  _ -> do
    modifyGraph $ \g -> IS.foldl' (\acc i -> MG.addEdges edges i acc) g is
    tellOperands is

-- | Ensure that a variable (which is in scope) will be made available on host
-- before its first use.
--
-- If the variable has no vertex in the graph, this does nothing. Otherwise its
-- vertex is connected to a sink. The body depth stored in the vertex metadata
-- is then reported through 'tellHostOnlyParent'.
requiredOnHost :: Id -> Grapher ()
requiredOnHost i = do
  graph <- getGraph
  forM_ (MG.lookup i graph) $ \v -> do
    connectToSink i
    tellHostOnlyParent (metaBodyDepth (vertexMeta v))

-- | Connects the vertex of the given id to a sink, and removes the id from the
-- set of graphed scalars.
connectToSink :: Id -> Grapher ()
connectToSink i =
  modifyGraph (MG.connectToSink i) >> modifyGraphedScalars (IS.delete i)

-- | Like 'connectToSink' but vertex is given by a t'SubExp'. This is a no-op if
-- the t'SubExp' is a constant.
connectSubExpToSink :: SubExp -> Grapher ()
connectSubExpToSink se = case se of
  Var n -> connectToSink (nameToId n)
  _ -> pure ()

-- | Routes all possible routes within the subgraph identified by this id.
-- Returns the ids of the source connected vertices that were attempted routed.
--
-- Assumption: The subgraph with the given id has just been created and no path
-- exists from it to an external sink.
--
-- The attempted sources are moved to the routed list, and the sinks reached by
-- the routing are prepended to the recorded sinks.
routeSubgraph :: Id -> Grapher [Id]
routeSubgraph si = do
  st <- get
  let graph = stateGraph st
      (done, pending) = stateSources st
      -- 'addSource' prepends new sources. This therefore relies on the sources
      -- of the freshly created subgraph forming a prefix of the pending list.
      (attempted, remaining) = span (inSubGraph si graph) pending
      (new_sinks, routed_graph) = MG.routeMany attempted graph
  put
    st
      { stateGraph = routed_graph,
        stateSources = (attempted ++ done, remaining),
        stateSinks = new_sinks ++ stateSinks st
      }
  pure attempted

-- | @inSubGraph si g i@ returns whether @g@ contains a vertex with id @i@ that
-- is declared within the subgraph with id @si@. A missing vertex, or one whose
-- metadata carries no graph id, yields 'False'.
inSubGraph :: Id -> Graph -> Id -> Bool
inSubGraph si g i =
  (MG.lookup i g >>= metaGraphId . vertexMeta) == Just si

-- | @b `reuses` n@ records that @b@ binds an array backed by the same memory
-- as @n@. If @b@ is not array typed or the backing memory is not copyable then
-- this does nothing.
--
-- When recorded, the body depth reported by 'outermostCopyableArray' for @n@
-- is the depth registered for @b@.
reuses :: Binding -> VName -> Grapher ()
reuses (i, t) n =
  when (isArray t) $
    outermostCopyableArray n >>= mapM_ (recordCopyableMemory i)

-- | Like 'reuses' but the reused array is given by a t'SubExp'. This is a no-op
-- if the t'SubExp' is a constant.
reusesSubExp :: Binding -> SubExp -> Grapher ()
reusesSubExp b se = case se of
  Var n -> reuses b n
  _ -> pure ()

-- | @reusesReturn bs res@ records each array binding in @bs@ as reusing
-- copyable memory if the corresponding return value in @res@ is backed by
-- copyable memory.
--
-- If every array binding is registered as being backed by copyable memory then
-- the function returns @True@, otherwise it returns @False@.
--
-- A single body returning @res@ is handled exactly like two branches that both
-- return @res@. The copyable-memory lookups are read-only and @min bd bd == bd@.
-- This therefore delegates to 'reusesBranches' rather than duplicating its
-- per-binding logic.
reusesReturn :: [Binding] -> [SubExp] -> Grapher Bool
reusesReturn bs res = reusesBranches bs res res

-- | @reusesBranches bs b1 b2@ records each array binding in @bs@ as reusing
-- copyable memory if each corresponding return value in the lists @b1@ and
-- @b2@ is backed by copyable memory.
--
-- If every array binding is registered as being backed by copyable memory then
-- the function returns @True@, otherwise it returns @False@.
reusesBranches :: [Binding] -> [SubExp] -> [SubExp] -> Grapher Bool
reusesBranches bs b1 b2 = do
  body_depth <- getBodyDepth
  -- Deliberately not short-circuiting: every binding must still be recorded
  -- even after one has been found that is not copyable.
  foldM (reuse body_depth) True $ zip3 bs b1 b2
  where
    reuse :: BodyDepth -> Bool -> (Binding, SubExp, SubExp) -> Grapher Bool
    reuse body_depth onlyCopyable (b, se1, se2)
      | all (intConst Int64 1 ==) (arrayDims $ snd b) =
          -- Single element arrays are immediately recognizable as copyable so
          -- don't bother recording those. Note that this case also matches
          -- primitive return values.
          pure onlyCopyable
      | (i, t) <- b,
        isArray t,
        Var n1 <- se1,
        Var n2 <- se2 =
          do
            body_depth_1 <- outermostCopyableArray n1
            body_depth_2 <- outermostCopyableArray n2
            case (body_depth_1, body_depth_2) of
              (Just bd1, Just bd2) -> do
                let inner = min bd1 bd2
                -- The binding lives at body_depth but shares memory with
                -- arrays as far out as inner; record the outermost of the two.
                recordCopyableMemory i (min body_depth inner)
                -- Memory already known at this depth or further out was
                -- presumably not created inside the returning body, i.e. the
                -- body returns a free variable.
                let returns_free_var = inner <= body_depth
                pure (onlyCopyable && not returns_free_var)
              _ ->
                pure False
      | otherwise =
          pure onlyCopyable

--------------------------------------------------------------------------------
--                           GRAPH BUILDING - TYPES                           --
--------------------------------------------------------------------------------

-- | The graph-building monad: a mutable 'State' layered over a read-only 'Env'.
type Grapher = StateT State (R.Reader Env)

-- | The read-only environment of the 'Grapher' monad.
data Env = Env
  { -- | Functions that are host-only; see 'HostOnlyFuns'.
    envHostOnlyFuns :: HostOnlyFuns,
    -- | Metadata describing the body that is currently being graphed.
    envMeta :: Meta
  }

-- | The number of bodies that something is nested within.
type BodyDepth = Int

-- | Metadata on the environment that a variable is declared within.
data Meta = Meta
  { -- | How many if statement branch bodies the variable binding is nested
    -- within. If a route passes through the edge u->v and the fork depth
    --
    --   1) increases from u to v, then v is within a conditional branch that
    --      u is outside of.
    --
    --   2) decreases from u to v, then v binds the result of two or more
    --      branches.
    --
    -- After the graph has been built and routed, this can be used to delay
    -- reads into deeper branches to reduce their likelihood of manifesting.
    metaForkDepth :: Int,
    -- | How many bodies the variable is nested within.
    metaBodyDepth :: BodyDepth,
    -- | An id for the subgraph within which the variable exists, defined at
    -- the body level. A read may only be delayed to a point within its own
    -- subgraph. 'Nothing' when the variable is outside of any subgraph.
    metaGraphId :: Maybe Id
  }

-- | The ids of every variable that has been used as an operand.
type Operands = IdSet

-- | Statistics about the statements of a body and what they depend on.
data BodyStats = BodyStats
  { -- | Whether any statement in the body is host-only.
    bodyHostOnly :: Bool,
    -- | Whether the body contains any GPUBody kernel.
    bodyHasGPUBody :: Bool,
    -- | Whether any statement in the body reads from device memory.
    bodyReads :: Bool,
    -- | Every graphed scalar variable used either as a return value of the body
    -- or as an operand inside it, including variables the body itself
    -- defines. Variables whose vertices are connected to sinks may be left
    -- out.
    bodyOperands :: Operands,
    -- | Body depths of parent bodies that have variables required on host.
    -- Because those variables are needed on host, the statements owning these
    -- bodies cannot be moved to device as a whole; they are host-only.
    bodyHostOnlyParents :: IS.IntSet
  }

-- | Combines two sets of stats field by field. The flags are or-ed and the id
-- sets are unioned, so the result describes both bodies together.
instance Semigroup BodyStats where
  BodyStats ho1 gb1 r1 o1 hop1 <> BodyStats ho2 gb2 r2 o2 hop2 =
    BodyStats (ho1 || ho2) (gb1 || gb2) (r1 || r2) (o1 <> o2) (hop1 <> hop2)

-- | Stats for an empty body: no flag is set and no id has been recorded.
instance Monoid BodyStats where
  mempty = BodyStats False False False IS.empty IS.empty

-- | The data flow graph being built, with 'Meta' attached to every vertex.
type Graph = MG.Graph Meta

-- | Every vertex connected from a source, split into two lists: the vertices
-- for which routing has already been attempted, and those for which it has not.
type Sources = ([Id], [Id])

-- | The terminal vertices of routes.
type Sinks = [Id]

-- | A captured statement whose graphing has been postponed: the binding it
-- produces together with its expression.
type Delayed = (Binding, Exp GPU)

-- | The vertex id of a variable paired with the variable's type.
type Binding = (Id, Type)

-- | Arrays backed by memory segments that may be copied. Each array is mapped
-- to the outermost known body depth at which some array is declared that is
-- backed by a superset of those segments.
type CopyableMemoryMap = IM.IntMap BodyDepth

-- | The mutable state of the 'Grapher' monad.
data State = State
  { -- | The graph under construction.
    stateGraph :: Graph,
    -- | Every known scalar that has been given a vertex in the graph.
    stateGraphedScalars :: IdSet,
    -- | All variables that directly bind scalars read from device memory.
    stateSources :: Sources,
    -- | Graphed scalars that serve as operands to statements that cannot be
    -- migrated. No read can be delayed past these. If the statements binding
    -- them are moved to device, these variables must therefore be read from
    -- device memory.
    stateSinks :: Sinks,
    -- | 'UpdateAcc' host statements that have been observed but whose graphing
    -- is deferred until later.
    stateUpdateAccs :: IM.IntMap [Delayed],
    -- | Encountered arrays known to be backed by copyable memory. Trivial cases
    -- such as single element arrays are left out.
    stateCopyableMemory :: CopyableMemoryMap,
    -- | Statistics about the body currently being graphed.
    stateStats :: BodyStats
  }

--------------------------------------------------------------------------------
--                             GRAPHER OPERATIONS                             --
--------------------------------------------------------------------------------

-- | Run a graph-building computation and return the resulting graph together
-- with its sources and sinks.
--
-- The computation starts from an empty state, at fork depth and body depth
-- zero, and outside of any subgraph. Its own result value is discarded.
execGrapher :: HostOnlyFuns -> Grapher a -> (Graph, Sources, Sinks)
execGrapher hof m = (stateGraph final, stateSources final, stateSinks final)
  where
    final :: State
    final = R.runReader (execStateT m initialState) initialEnv

    initialEnv :: Env
    initialEnv =
      Env
        { envHostOnlyFuns = hof,
          envMeta =
            Meta
              { metaForkDepth = 0,
                metaBodyDepth = 0,
                metaGraphId = Nothing
              }
        }

    initialState :: State
    initialState =
      State
        { stateGraph = MG.empty,
          stateGraphedScalars = IS.empty,
          stateSources = ([], []),
          stateSinks = [],
          stateUpdateAccs = IM.empty,
          stateCopyableMemory = IM.empty,
          stateStats = mempty
        }

-- | Run a computation under an environment transformed by the given function.
-- The state is threaded through unchanged; only the reader layer is affected.
local :: (Env -> Env) -> Grapher a -> Grapher a
local = mapStateT . R.local

-- | Retrieve the entire current environment.
ask :: Grapher Env
ask = asks id

-- | Project a value out of the current environment.
asks :: (Env -> a) -> Grapher a
asks f = lift (R.asks f)

-- | Mark the current body as containing a host-only statement. As a result,
-- its parent statement and any enclosing bodies are host-only as well. A
-- host-only statement should stay on host, either because it cannot run on
-- device or because running it there would be inefficient.
tellHostOnly :: Grapher ()
tellHostOnly = modify markHostOnly
  where
    markHostOnly st = st {stateStats = (stateStats st) {bodyHostOnly = True}}

-- | Mark the current body as containing a GPUBody kernel.
tellGPUBody :: Grapher ()
tellGPUBody = modify markGPUBody
  where
    markGPUBody st = st {stateStats = (stateStats st) {bodyHasGPUBody = True}}

-- | Mark the current body as containing a statement that reads from device
-- memory.
tellRead :: Grapher ()
tellRead = modify markRead
  where
    markRead st = st {stateStats = (stateStats st) {bodyReads = True}}

-- | Add the given variable ids to the operands recorded for the current body.
tellOperands :: IdSet -> Grapher ()
tellOperands is = modify addOperands
  where
    addOperands st =
      let current = stateStats st
          extended = bodyOperands current <> is
       in st {stateStats = current {bodyOperands = extended}}

-- | Record that the statement owning a body at the given body depth is
-- host-only.
tellHostOnlyParent :: BodyDepth -> Grapher ()
tellHostOnlyParent body_depth = modify addParent
  where
    addParent st =
      let current = stateStats st
          grown = IS.insert body_depth (bodyHostOnlyParents current)
       in st {stateStats = current {bodyHostOnlyParents = grown}}

-- | The graph as it has been built so far.
getGraph :: Grapher Graph
getGraph = gets stateGraph

-- | The ids of every scalar variable that has a vertex in the graph.
getGraphedScalars :: Grapher IdSet
getGraphedScalars = gets stateGraphedScalars

-- | Every known array backed by a memory segment that may be copied. Each array
-- is mapped to the outermost known body depth at which some array is backed by
-- a superset of that segment.
--
-- Consider a body whose returned arrays are all backed by such memory and are
-- all written by the body's own statements. Migrating that body as a whole
-- keeps its asymptotic cost.
getCopyableMemory :: Grapher CopyableMemoryMap
getCopyableMemory = gets stateCopyableMemory

-- | Look up the outermost known body depth of an array that is backed by the
-- same copyable memory as the named array. Returns 'Nothing' if that array is
-- not known to be backed by copyable memory.
outermostCopyableArray :: VName -> Grapher (Maybe BodyDepth)
outermostCopyableArray n = do
  copyable <- getCopyableMemory
  pure $ IM.lookup (nameToId n) copyable

-- | Narrow the given variables down to the 'Id's of those that are scalars
-- with a vertex in the graph, i.e. those present in the graphed scalar set.
-- Variables that have been connected to sinks are excluded as well.
onlyGraphedScalars :: Foldable t => t VName -> Grapher IdSet
onlyGraphedScalars vs = do
  graphed <- getGraphedScalars
  pure $ IS.intersection (foldl' addId IS.empty vs) graphed
  where
    addId acc v = IS.insert (nameToId v) acc

-- | Like 'onlyGraphedScalars' but for a single 'VName'. The result is either
-- empty or a singleton set holding the variable's id.
onlyGraphedScalar :: VName -> Grapher IdSet
onlyGraphedScalar n = onlyGraphedScalars [n]

-- | Like 'onlyGraphedScalars' but for a single t'SubExp'. Constants never have
-- vertices, so they always yield the empty set.
onlyGraphedScalarSubExp :: SubExp -> Grapher IdSet
onlyGraphedScalarSubExp se =
  case se of
    Var n -> onlyGraphedScalar n
    Constant _ -> pure IS.empty

-- | Apply a transformation to the graph under construction.
modifyGraph :: (Graph -> Graph) -> Grapher ()
modifyGraph f = modify update
  where
    update st@State {stateGraph = g} = st {stateGraph = f g}

-- | Apply a transformation to the set of graphed scalars.
modifyGraphedScalars :: (IdSet -> IdSet) -> Grapher ()
modifyGraphedScalars f = modify update
  where
    update st@State {stateGraphedScalars = gss} = st {stateGraphedScalars = f gss}

-- | Apply a transformation to the copyable memory map.
modifyCopyableMemory :: (CopyableMemoryMap -> CopyableMemoryMap) -> Grapher ()
modifyCopyableMemory f = modify update
  where
    update st@State {stateCopyableMemory = cm} = st {stateCopyableMemory = f cm}

-- | Apply a transformation to the vertices connected from sources.
modifySources :: (Sources -> Sources) -> Grapher ()
modifySources f = modify update
  where
    update st@State {stateSources = srcs} = st {stateSources = f srcs}

-- | Record that the variable with this id binds an array backed by copyable
-- memory, which is shared with some array at the given outermost body depth.
-- Any earlier entry for the same id is overwritten.
recordCopyableMemory :: Id -> BodyDepth -> Grapher ()
recordCopyableMemory i bd = modifyCopyableMemory $ IM.insert i bd

-- | Run an action with the fork depth raised by one, so that every variable it
-- graphs is considered one conditional branch deeper.
incForkDepthFor :: Grapher a -> Grapher a
incForkDepthFor = local deepen
  where
    deepen env@Env {envMeta = meta} =
      env {envMeta = meta {metaForkDepth = metaForkDepth meta + 1}}

-- | Run an action with the body depth raised by one, so that every variable it
-- graphs is considered nested one body deeper.
incBodyDepthFor :: Grapher a -> Grapher a
incBodyDepthFor = local nest
  where
    nest env@Env {envMeta = meta} =
      env {envMeta = meta {metaBodyDepth = metaBodyDepth meta + 1}}

-- | Run an action with every variable it graphs assigned to the subgraph with
-- the given id.
graphIdFor :: Id -> Grapher a -> Grapher a
graphIdFor i = local enter
  where
    enter env@Env {envMeta = meta} =
      env {envMeta = meta {metaGraphId = Just i}}

-- | Run an action against a fresh, empty set of body stats and return the stats
-- it produced. Afterwards, those stats are merged into the stats that were
-- current before the action, so the enclosing body still observes them.
captureBodyStats :: Grapher a -> Grapher BodyStats
captureBodyStats m = do
  outer <- gets stateStats
  setStats mempty
  _ <- m
  inner <- gets stateStats
  setStats (outer <> inner)
  pure inner
  where
    setStats s = modify $ \st -> st {stateStats = s}

-- | Is the function with this name host-only? 'True' means applications of it
-- cannot be moved to device; 'False' means they may be.
isHostOnlyFun :: Name -> Grapher Bool
isHostOnlyFun fn = asks (S.member fn . envHostOnlyFuns)

-- | The 'Meta' describing the body currently being graphed.
getMeta :: Grapher Meta
getMeta = asks envMeta

-- | The nesting level of the body currently being graphed.
getBodyDepth :: Grapher BodyDepth
getBodyDepth = metaBodyDepth <$> getMeta