module Futhark.Optimise.ReduceDeviceSyncs.MigrationTable
(
analyseProg,
MigrationTable,
MigrationStatus (..),
shouldMoveStm,
shouldMove,
usedOnHost,
statusOf,
)
where
import Control.Monad
import Control.Monad.Trans.Class
import qualified Control.Monad.Trans.Reader as R
import Control.Monad.Trans.State.Strict ()
import Control.Monad.Trans.State.Strict hiding (State)
import Control.Parallel.Strategies (parMap, rpar)
import Data.Bifunctor (first, second)
import Data.Foldable
import qualified Data.IntMap.Strict as IM
import qualified Data.IntSet as IS
import qualified Data.Map.Strict as M
import Data.Maybe (fromJust, fromMaybe, isJust, isNothing)
import qualified Data.Sequence as SQ
import Data.Set (Set, (\\))
import qualified Data.Set as S
import Futhark.Error
import Futhark.IR.GPU
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph
( EdgeType (..),
Edges (..),
Id,
IdSet,
Result (..),
Routing (..),
Vertex (..),
)
import qualified Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph as MG
-- | Where the value bound by a name should be computed.
--
-- The derived 'Ord' instance orders the constructors as declared:
-- @MoveToDevice < UsedOnHost < StayOnHost@.
data MigrationStatus
  = -- | The statement computing the value should be moved to device, and the
    -- value is not used on host ('usedOnHost' is 'False').
    MoveToDevice
  | -- | The value is computed on device ('shouldMove' is 'True') but is also
    -- used on host ('usedOnHost' is 'True').
    UsedOnHost
  | -- | The statement computing the value stays on host. This is also the
    -- status of any name missing from a 'MigrationTable'.
    StayOnHost
  deriving (Eq, Ord, Show)
-- | Associates variables, keyed by the 'baseTag' of their name, with a
-- 'MigrationStatus'. 'statusOf' treats names absent from the table as
-- 'StayOnHost'.
newtype MigrationTable
  = MigrationTable (IM.IntMap MigrationStatus)
-- | Looks up the migration status of a name, keyed by its 'baseTag'.
-- Names that the table does not mention default to 'StayOnHost'.
statusOf :: VName -> MigrationTable -> MigrationStatus
statusOf n (MigrationTable mt) =
  fromMaybe StayOnHost $ IM.lookup (baseTag n) mt
-- | Should this whole statement be moved from host to device?
--
-- The answer depends on the status of a single representative name:
--
--   * 'Index': the first bound name is 'MoveToDevice', or any variable
--     used in the slice is 'MoveToDevice';
--   * any other 'BasicOp', and 'Apply': the first bound name is not
--     'StayOnHost';
--   * 'If': the branch-condition variable is 'MoveToDevice';
--   * 'DoLoop': the trip-count variable of a 'ForLoop', or the condition
--     variable of a 'WhileLoop', is 'MoveToDevice'.
--
-- Every other statement is not moved. This includes statements whose
-- condition or trip count is a constant, and basic operations or
-- applications whose pattern binds no names.
shouldMoveStm :: Stm GPU -> MigrationTable -> Bool
shouldMoveStm (Let (Pat (PatElem n _ : _)) _ (BasicOp (Index _ slice))) mt =
  statusOf n mt == MoveToDevice || any movedOperand slice
  where
    movedOperand (Var op) = statusOf op mt == MoveToDevice
    movedOperand _ = False
shouldMoveStm (Let (Pat (PatElem n _ : _)) _ (BasicOp _)) mt =
  statusOf n mt /= StayOnHost
shouldMoveStm (Let (Pat (PatElem n _ : _)) _ Apply {}) mt =
  statusOf n mt /= StayOnHost
shouldMoveStm (Let _ _ (If (Var n) _ _ _)) mt =
  statusOf n mt == MoveToDevice
shouldMoveStm (Let _ _ (DoLoop _ (ForLoop _ _ (Var n) _) _)) mt =
  statusOf n mt == MoveToDevice
shouldMoveStm (Let _ _ (DoLoop _ (WhileLoop n) _)) mt =
  statusOf n mt == MoveToDevice
shouldMoveStm _ _ = False
-- | Should the value bound by this name be computed on device?
-- This holds for every status except 'StayOnHost'.
shouldMove :: VName -> MigrationTable -> Bool
shouldMove n mt = statusOf n mt /= StayOnHost
-- | Will the value bound by this name be used on host?
-- This holds for every status except 'MoveToDevice'.
usedOnHost :: VName -> MigrationTable -> Bool
usedOnHost n mt = statusOf n mt /= MoveToDevice
-- | Combines two migration tables. 'IM.union' is left-biased, so when both
-- tables contain the same key, the entry from the first table is kept.
merge :: MigrationTable -> MigrationTable -> MigrationTable
merge (MigrationTable a) (MigrationTable b) = MigrationTable (a `IM.union` b)
-- | Names of top-level functions that 'hostOnlyFunDefs' has determined
-- cannot be called from device code.
type HostOnlyFuns = S.Set Name
-- | Returns the names of all top-level functions that cannot be called from
-- device code.
--
-- A function is host-only when 'checkFunDef' rejects it outright, or when it
-- calls a host-only function, directly or transitively. Transitive callers
-- are found by repeatedly discarding the host-only entries and marking every
-- remaining caller of a discarded function as host-only, until a round
-- discards nothing. Each round removes at least one entry, so the process
-- terminates.
hostOnlyFunDefs :: [FunDef GPU] -> HostOnlyFuns
hostOnlyFunDefs funs =
  let names = map funDefName funs
      call_map = M.fromList $ zip names (map checkFunDef funs)
   in S.fromList names \\ M.keysSet (removeHostOnly call_map)
  where
    -- 'Nothing' marks a host-only function; 'Just calls' lists the functions
    -- applied by a device-compatible one.
    removeHostOnly cm =
      let (host_only, cm') = M.partition isNothing cm
       in if M.null host_only
            then cm'
            else removeHostOnly $ M.map (checkCalls $ M.keysSet host_only) cm'

    -- A function that calls a host-only function is itself host-only.
    checkCalls host_only_funs (Just calls)
      | host_only_funs `S.disjoint` calls = Just calls
    checkCalls _ _ = Nothing
-- | Checks whether a function definition could be executed on device.
--
-- Returns 'Nothing' (host-only) when any of the following holds:
--
--   * a parameter or return type is an array;
--   * a statement binds an array;
--   * a loop has an array-typed merge parameter;
--   * a 'ForLoop' has a non-empty list of (parameter, name) bindings;
--   * the body contains an 'Index', a 'WithAcc', or an 'Op'.
--
-- Otherwise it returns the names of all functions applied anywhere in the
-- body. 'hostOnlyFunDefs' decides whether those callees are host-only.
checkFunDef :: FunDef GPU -> Maybe (Set Name)
checkFunDef fun = do
  checkFParams (funDefParams fun)
  checkRetTypes (funDefRetType fun)
  checkBody (funDefBody fun)
  where
    hostOnly :: Maybe a
    hostOnly = Nothing

    ok :: Maybe ()
    ok = Just ()

    -- Host-only as soon as one element satisfies the array predicate.
    check :: Foldable t => (a -> Bool) -> t a -> Maybe ()
    check isArr xs = if any isArr xs then hostOnly else ok

    checkFParams = check isArray
    checkLParams = check (isArray . fst)
    checkRetTypes = check isArrayType
    checkPats = check isArray

    checkLoopForm (ForLoop _ _ _ (_ : _)) = hostOnly
    checkLoopForm _ = ok

    checkBody = checkStms . bodyStms

    -- Collects the applied function names of all statements.
    checkStms stms = S.unions <$> mapM checkStm stms

    checkStm (Let (Pat pats) _ e) = checkPats pats >> checkExp e

    checkExp (BasicOp (Index _ _)) = hostOnly
    checkExp (WithAcc _ _) = hostOnly
    checkExp (Op _) = hostOnly
    checkExp (Apply fn _ _ _) = Just (S.singleton fn)
    checkExp (If _ tbranch fbranch _) = do
      calls1 <- checkBody tbranch
      calls2 <- checkBody fbranch
      pure (calls1 <> calls2)
    checkExp (DoLoop params lform body) = do
      checkLParams params
      checkLoopForm lform
      checkBody body
    checkExp _ = Just S.empty
-- | Identifiers of variables that may be used on host after a sequence of
-- statements. 'buildGraph' connects each of them to the sink of the graph.
type HostUsage = [MG.Id]
-- | The graph vertex identifier of a variable, which is its 'baseTag'.
nameToId :: VName -> Id
nameToId = baseTag
-- | Analyses a program and returns a migration table that covers its
-- top-level constants and every function definition.
analyseProg :: Prog GPU -> MigrationTable
analyseProg (Prog consts funs) =
  let hof = hostOnlyFunDefs funs
      consts_table = analyseConsts hof funs consts
      -- Function definitions are analysed independently; 'parMap rpar'
      -- sparks each analysis for parallel evaluation.
      fun_tables = parMap rpar (analyseFunDef hof) funs
   in foldl' merge consts_table fun_tables
-- | Analyses the top-level constants of a program.
--
-- Every scalar constant that is free in some function definition is passed
-- to 'analyseStms' as host usage.
analyseConsts :: HostOnlyFuns -> [FunDef GPU] -> Stms GPU -> MigrationTable
analyseConsts hof funs consts =
  -- Strict fold: the accumulator is an ordinary list, so there is no reason
  -- to build a chain of thunks over the constant scope.
  let usage = M.foldlWithKey' (f $ freeIn funs) [] (scopeOf consts)
   in analyseStms hof usage consts
  where
    f free usage n t
      | isScalar t,
        n `nameIn` free =
          nameToId n : usage
      | otherwise =
          usage
-- | Analyses a top-level function definition.
--
-- Every result variable whose return type is scalar is passed to
-- 'analyseStms' as host usage.
analyseFunDef :: HostOnlyFuns -> FunDef GPU -> MigrationTable
analyseFunDef hof fd =
  let body = funDefBody fd
      usage = foldl' f [] $ zip (bodyResult body) (funDefRetType fd)
      stms = bodyStms body
   in analyseStms hof usage stms
  where
    f usage (SubExpRes _ (Var n), t) | isScalarType t = nameToId n : usage
    f usage _ = usage
-- | Analyses a sequence of statements and returns a migration table for the
-- vertices reached in their data flow graph.
--
-- @usage@ lists the variables that may be used on host afterwards.
-- 'buildGraph' connects each of them to the sink of the graph.
analyseStms :: HostOnlyFuns -> HostUsage -> Stms GPU -> MigrationTable
analyseStms hof usage stms =
  let (g, srcs, _) = buildGraph hof usage stms
      (routed, unrouted) = srcs
      -- 'MG.routeMany' is applied to the unrouted sources. Only the updated
      -- graph is kept; the first component of its result is discarded.
      (_, g') = MG.routeMany unrouted g
      f st' = MG.fold g' visit st' Normal
      -- Traverse from the previously unrouted sources first, then from the
      -- routed ones, threading the same accumulator and visited set.
      st = foldl' f (initial, MG.none) unrouted
      (vr, vn, tn) = fst $ foldl' f st routed
   in -- 'IM.unions' is left-biased, so 'MoveToDevice' takes precedence over
      -- 'UsedOnHost' for a vertex that appears in several sets.
      MigrationTable $
        IM.unions
          [ IM.fromSet (const MoveToDevice) vr,
            IM.fromSet (const MoveToDevice) vn,
            IM.fromSet (const UsedOnHost) tn
          ]
  where
    -- Accumulator components:
    --   vr: vertices visited via a reversed edge;
    --   vn: vertices visited via a normal edge whose routing is 'NoRoute';
    --   tn: all other vertices visited via a normal edge.
    initial = (IS.empty, IS.empty, IS.empty)

    visit (vr, vn, tn) Reversed v =
      (IS.insert (vertexId v) vr, vn, tn)
    visit (vr, vn, tn) Normal v@Vertex {vertexRouting = NoRoute} =
      (vr, IS.insert (vertexId v) vn, tn)
    visit (vr, vn, tn) Normal v =
      (vr, vn, IS.insert (vertexId v) tn)
-- | Does this value have a primitive type other than 'Unit'?
isScalar :: Typed t => t -> Bool
isScalar = isScalarType . typeOf
-- | Is this a primitive type other than 'Unit'? Arrays and all other
-- non-primitive types are not scalar.
isScalarType :: TypeBase shape u -> Bool
isScalarType (Prim Unit) = False
isScalarType (Prim _) = True
isScalarType _ = False
-- | Does this value have an array type, i.e. a type of rank greater than zero?
isArray :: Typed t => t -> Bool
isArray = isArrayType . typeOf
-- | Is this an array type, i.e. a type of rank greater than zero?
isArrayType :: ArrayShape shape => TypeBase shape u -> Bool
isArrayType t = arrayRank t > 0
-- | Builds the data flow graph for a sequence of statements, then connects
-- every variable listed in the host usage to the sink of the graph.
-- The sources and sinks collected while graphing are returned unchanged.
buildGraph :: HostOnlyFuns -> HostUsage -> Stms GPU -> (Graph, Sources, Sinks)
buildGraph hof usage stms =
  let (g, srcs, sinks) = execGrapher hof (graphStms stms)
      g' = foldl' (flip MG.connectToSink) g usage
   in (g', srcs, sinks)
-- | Graphs the statements of a body one nesting level deeper, then reports
-- the free variables of the body result as operands.
--
-- If the statistics captured for the body list the body's own depth in
-- 'bodyHostOnlyParents', the current statistics are marked 'bodyHostOnly'.
-- In either case, that depth is removed from the current
-- 'bodyHostOnlyParents'.
graphBody :: Body GPU -> Grapher ()
graphBody body = do
  let res_ops = namesIntSet $ freeIn (bodyResult body)
  body_stats <-
    captureBodyStats $
      incBodyDepthFor (graphStms (bodyStms body) >> tellOperands res_ops)

  -- The depth at which the body's own statements were graphed.
  body_depth <- (1 +) <$> getBodyDepth
  let host_only = IS.member body_depth (bodyHostOnlyParents body_stats)
  modify $ \st ->
    let stats = stateStats st
        hops' = IS.delete body_depth (bodyHostOnlyParents stats)
        stats' = if host_only then stats {bodyHostOnly = True} else stats
     in st {stateStats = stats' {bodyHostOnlyParents = hops'}}
-- | Graphs each statement in order with 'graphStm'.
graphStms :: Stms GPU -> Grapher ()
graphStms = mapM_ graphStm
-- | Graph a single statement.
--
-- How a statement is graphed depends on its expression:
--
--   * Cheap scalar computations ('SubExp', 'Opaque', 'UnOp', 'BinOp',
--     'CmpOp', 'ConvOp', 'Assert') are graphed with 'graphSimple'.
--     'SubExp' and 'Opaque' additionally record that their single binding
--     reuses the operand.
--   * A scalar-typed 'ArrayLit' with at least one variable element is made a
--     source with 'graphAutoMove'.
--   * An 'Index' whose slice consists only of indices ('sliceIndices'
--     succeeds) is graphed as a read with 'graphRead'.
--   * Any remaining basic operation that binds exactly one array whose
--     dimensions are all the constant 1 is graphed with 'graphSimple'.
--   * The slicing/reshaping operations 'Index', 'Update', 'FlatIndex',
--     'FlatUpdate', 'Scratch', 'Reshape', 'Rearrange' and 'Rotate' mark the
--     variables among their listed result dimensions as 'requiredOnHost' and
--     connect their graphed scalar operands to the sink. All except 'Scratch'
--     also record that the binding reuses the source array.
--   * The remaining array constructions ('ArrayLit', 'Concat', 'Copy',
--     'Manifest', 'Iota', 'Replicate') and every non-'GPUBody' 'Op' are
--     graphed with 'graphHostOnly'.
--   * 'UpdateAcc', 'Apply', 'If', 'DoLoop' and 'WithAcc' are delegated to their
--     dedicated graphing functions, and a 'GPUBody' is reported with
--     'tellGPUBody'.
graphStm :: Stm GPU -> Grapher ()
graphStm stm =
  case e of
    BasicOp op -> graphBasicOp op
    Apply fn _ _ _ -> graphApply fn bs e
    If cond tbody fbody _ -> graphIf bs cond tbody fbody
    DoLoop params lform body -> graphLoop bs params lform body
    WithAcc inputs f -> graphWithAcc bs inputs f
    Op GPUBody {} -> tellGPUBody
    Op _ -> graphHostOnly e
  where
    bs :: [Binding]
    bs = boundBy stm

    e :: Exp GPU
    e = stmExp stm

    -- The sole binding of a statement that must bind exactly one value.
    single :: Binding
    single = case bs of
      [x] -> x
      _ -> compilerBugS "Type error: unexpected number of pattern elements."

    -- Equations are tried top to bottom. A failing guard falls through to the
    -- next equation, so the order below is significant.
    graphBasicOp :: BasicOp -> Grapher ()
    graphBasicOp (SubExp se) = simpleReusing se
    graphBasicOp (Opaque _ se) = simpleReusing se
    graphBasicOp (ArrayLit arr t)
      | isScalar t,
        any (isJust . subExpVar) arr =
          graphAutoMove single
    graphBasicOp UnOp {} = graphSimple bs e
    graphBasicOp BinOp {} = graphSimple bs e
    graphBasicOp CmpOp {} = graphSimple bs e
    graphBasicOp ConvOp {} = graphSimple bs e
    graphBasicOp Assert {} = graphSimple bs e
    graphBasicOp (Index _ slice)
      | isJust (sliceIndices slice) = graphRead single
    -- A single binding of an array with only unit dimensions is treated like
    -- a scalar. This is presumably because such an array is as cheap to compute
    -- and copy as a scalar. It takes precedence over the cases below.
    graphBasicOp _
      | [(_, t)] <- bs,
        dims <- arrayDims t,
        not (null dims),
        all (== intConst Int64 1) dims =
          graphSimple bs e
    graphBasicOp (Index arr s) = inefficientReusing (sliceDims s) arr
    graphBasicOp (Update _ arr _ _) = inefficientReusing [] arr
    graphBasicOp (FlatIndex arr s) = inefficientReusing (flatSliceDims s) arr
    graphBasicOp (FlatUpdate arr _ _) = inefficientReusing [] arr
    graphBasicOp (Scratch _ s) = inefficientReturn s
    graphBasicOp (Reshape s arr) = inefficientReusing (newDims s) arr
    graphBasicOp (Rearrange _ arr) = inefficientReusing [] arr
    graphBasicOp (Rotate _ arr) = inefficientReusing [] arr
    graphBasicOp ArrayLit {} = graphHostOnly e
    graphBasicOp Concat {} = graphHostOnly e
    graphBasicOp Copy {} = graphHostOnly e
    graphBasicOp Manifest {} = graphHostOnly e
    graphBasicOp Iota {} = graphHostOnly e
    graphBasicOp Replicate {} = graphHostOnly e
    graphBasicOp UpdateAcc {} = graphUpdateAcc single e

    -- Graph as a simple computation whose binding reuses the given operand.
    simpleReusing :: SubExp -> Grapher ()
    simpleReusing se = do
      graphSimple bs e
      single `reusesSubExp` se

    -- Require the variable result dimensions on host, then block migration by
    -- connecting every graphed scalar operand to the sink.
    inefficientReturn :: [SubExp] -> Grapher ()
    inefficientReturn new_dims = do
      mapM_ hostSize new_dims
      graphedScalarOperands e >>= addEdges ToSink

    -- As 'inefficientReturn', and the binding reuses the given array.
    inefficientReusing :: [SubExp] -> VName -> Grapher ()
    inefficientReusing new_dims arr = do
      inefficientReturn new_dims
      single `reuses` arr

    hostSize :: SubExp -> Grapher ()
    hostSize (Var n) = requiredOnHost (nameToId n)
    hostSize _ = pure ()
-- | The variables bound by a statement, as (id, type) pairs in the order of
-- the statement's pattern elements.
boundBy :: Stm GPU -> [Binding]
boundBy stm = [(nameToId n, t) | PatElem n t <- patElems (stmPat stm)]
-- | Graph a statement whose expression is a cheap computation.
--
-- The bound variables are added as vertices only when the expression has at
-- least one graphed scalar operand. In that case the operands are connected to
-- the bound variables with 'addEdges', using edges declared by
-- 'MG.declareEdges' over the bound ids.
graphSimple :: [Binding] -> Exp GPU -> Grapher ()
graphSimple bs e = do
  ops <- graphedScalarOperands e
  unless (IS.null ops) $ do
    mapM_ addVertex bs
    addEdges (MG.declareEdges (map fst bs)) ops
-- | Graph a scalar read. The binding becomes a source vertex via 'addSource',
-- and 'tellRead' records that a read took place.
graphRead :: Binding -> Grapher ()
graphRead b = addSource b >> tellRead
-- | Graph a statement that should be migrated whenever possible by making its
-- binding a source vertex. Unlike 'graphRead', no read is reported.
graphAutoMove :: Binding -> Grapher ()
graphAutoMove = addSource
-- | Graph a statement that must stay on host.
--
-- Every graphed scalar operand of the expression is connected to the sink.
-- 'tellHostOnly' then records that a host-only statement was encountered.
graphHostOnly :: Exp GPU -> Grapher ()
graphHostOnly e =
  graphedScalarOperands e >>= addEdges ToSink >> tellHostOnly
-- | Record an 'UpdateAcc' statement for later processing.
--
-- The (binding, expression) pair is prepended to the list stored in
-- 'stateUpdateAccs' under the id of the updated accumulator. No vertices or
-- edges are added here. The binding must be of accumulator type; any other
-- type is a compiler bug.
graphUpdateAcc :: Binding -> Exp GPU -> Grapher ()
graphUpdateAcc b@(_, Acc acc _ _ _) e =
  modify $ \st ->
    st
      { stateUpdateAccs =
          -- insertWith calls (++) as (new ++ old), which prepends the pair.
          IM.insertWith (++) (nameToId acc) [(b, e)] (stateUpdateAccs st)
      }
graphUpdateAcc _ _ =
  compilerBugS
    "Type error: UpdateAcc did not produce accumulator typed value."
-- | Graph a function call.
--
-- Calls to functions that 'isHostOnlyFun' reports as host-only are graphed with
-- 'graphHostOnly'. All other calls are graphed like cheap computations with
-- 'graphSimple'.
graphApply :: Name -> [Binding] -> Exp GPU -> Grapher ()
graphApply fn bs e =
  isHostOnlyFun fn >>= \host_only ->
    if host_only
      then graphHostOnly e
      else graphSimple bs e
-- | Graph an @if@ statement.
--
-- Both branches are graphed inside 'incForkDepthFor', and their body
-- statistics are captured. The statement may migrate only if neither branch
-- is host-only and 'reusesBranches' accepts the branch results for the bound
-- variables.
--
-- When migration is ruled out, a variable condition is connected to the sink.
-- Otherwise, the condition's graphed ids (possibly none) are reported with
-- 'tellOperands'. Those ids are also included among the operands of every
-- bound variable. Each bound variable gets a node ('createNode') whose operands
-- are the condition ids plus the graphed scalar results of both branches at
-- the same position.
graphIf :: [Binding] -> SubExp -> Body GPU -> Body GPU -> Grapher ()
graphIf bs cond tbody fbody = do
  host_only <- incForkDepthFor $ do
    tstats <- captureBodyStats (graphBody tbody)
    fstats <- captureBodyStats (graphBody fbody)
    pure (bodyHostOnly tstats || bodyHostOnly fstats)
  may_copy <- reusesBranches bs (resultsOf tbody) (resultsOf fbody)
  cond_ops <- condOperands (not host_only && may_copy)
  tellOperands cond_ops
  node_ops <-
    zipWithM (branchOperands cond_ops) (bodyResult tbody) (bodyResult fbody)
  zipWithM_ createNode bs node_ops
  where
    resultsOf :: Body GPU -> [SubExp]
    resultsOf = map resSubExp . bodyResult

    -- Operands contributed by the branch condition. A variable condition of a
    -- statement that cannot migrate is instead forced onto host.
    condOperands :: Bool -> Grapher IdSet
    condOperands may_migrate = case cond of
      Var n
        | may_migrate -> onlyGraphedScalar n
        | otherwise -> connectToSink (nameToId n) >> pure IS.empty
      _ -> pure IS.empty

    branchOperands :: IdSet -> SubExpRes -> SubExpRes -> Grapher IdSet
    branchOperands cond_ops t f =
      (cond_ops <>) <$> onlyGraphedScalars (resVar t <> resVar f)

    resVar :: SubExpRes -> Set VName
    resVar (SubExpRes _ (Var n)) = S.singleton n
    resVar _ = S.empty
-- | Ids of loop-bound variables that a graph traversal reached. They are
-- accumulated by 'bindingReach' within 'graphLoop'.
type ReachableBindings = IdSet

-- | The visited-vertex state threaded through 'MG.reduce' when searching for
-- 'ReachableBindings'.
type ReachableBindingsCache = MG.Visited (MG.Result ReachableBindings)

-- | Vertex ids collected by 'findBindings' alongside the reachable bindings.
-- 'updateEdges' stores them, together with the reachable bindings, in the
-- second component of a 'ToNodes' edge set.
-- NOTE(review): the name suggests these are vertices whose search was not
-- exhausted; confirm against 'findBindings'.
type NonExhausted = [Id]

-- | A loop value as built by 'graphLoop'. Its components are:
--
--   1. the binding of the loop's result,
--   2. the id of the corresponding loop parameter,
--   3. the parameter's initial argument,
--   4. the loop body's result for that parameter.
type LoopValue = (Binding, Id, SubExp, SubExp)
graphLoop ::
[Binding] ->
[(FParam GPU, SubExp)] ->
LoopForm GPU ->
Body GPU ->
Grapher ()
graphLoop :: [Binding]
-> [(FParam GPU, SubExp)] -> LoopForm GPU -> Body GPU -> Grapher ()
graphLoop [] [(FParam GPU, SubExp)]
_ LoopForm GPU
_ Body GPU
_ =
String -> Grapher ()
forall a. String -> a
compilerBugS String
"Loop statement bound no variable; should have been eliminated."
graphLoop (Binding
b : [Binding]
bs) [(FParam GPU, SubExp)]
params LoopForm GPU
lform Body GPU
body = do
Graph
g0 <- Grapher Graph
getGraph
BodyStats
stats <- Grapher () -> Grapher BodyStats
forall a. Grapher a -> Grapher BodyStats
captureBodyStats (Int
subgraphId Int -> Grapher () -> Grapher ()
forall a. Int -> Grapher a -> Grapher a
`graphIdFor` Grapher ()
graphTheLoop)
let args :: [SubExp]
args = ((Param DeclType, SubExp) -> SubExp)
-> [(Param DeclType, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(Param DeclType, SubExp)]
[(FParam GPU, SubExp)]
params
let results :: [SubExp]
results = (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp (Body GPU -> [SubExpRes]
forall rep. Body rep -> [SubExpRes]
bodyResult Body GPU
body)
Bool
may_copy_results <- [Binding] -> [SubExp] -> [SubExp] -> Grapher Bool
reusesBranches (Binding
b Binding -> [Binding] -> [Binding]
forall a. a -> [a] -> [a]
: [Binding]
bs) [SubExp]
args [SubExp]
results
let may_migrate :: Bool
may_migrate = Bool -> Bool
not (BodyStats -> Bool
bodyHostOnly BodyStats
stats) Bool -> Bool -> Bool
&& Bool
may_copy_results
Bool -> Grapher () -> Grapher ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless Bool
may_migrate (Grapher () -> Grapher ()) -> Grapher () -> Grapher ()
forall a b. (a -> b) -> a -> b
$ case LoopForm GPU
lform of
ForLoop VName
_ IntType
_ (Var VName
n) [(LParam GPU, VName)]
_ -> Int -> Grapher ()
connectToSink (VName -> Int
nameToId VName
n)
WhileLoop VName
n -> Int -> Grapher ()
connectToSink (VName -> Int
nameToId VName
n)
LoopForm GPU
_ -> () -> Grapher ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
Graph
g1 <- Grapher Graph
getGraph
(LoopValue -> Grapher ()) -> [LoopValue] -> Grapher ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Graph -> LoopValue -> Grapher ()
mergeLoopParam Graph
g1) [LoopValue]
loopValues
[Int]
srcs <- Int -> Grapher [Int]
routeSubgraph Int
subgraphId
[LoopValue] -> (LoopValue -> Grapher ()) -> Grapher ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [LoopValue]
loopValues ((LoopValue -> Grapher ()) -> Grapher ())
-> (LoopValue -> Grapher ()) -> Grapher ()
forall a b. (a -> b) -> a -> b
$ \(Binding
bnd, Int
p, SubExp
_, SubExp
_) -> Binding -> IntSet -> Grapher ()
createNode Binding
bnd (Int -> IntSet
IS.singleton Int
p)
Graph
g2 <- Grapher Graph
getGraph
let (IntSet
dbs, ReachableBindingsCache
rbc) = ((IntSet, ReachableBindingsCache)
-> Int -> (IntSet, ReachableBindingsCache))
-> (IntSet, ReachableBindingsCache)
-> [Int]
-> (IntSet, ReachableBindingsCache)
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' (Graph
-> (IntSet, ReachableBindingsCache)
-> Int
-> (IntSet, ReachableBindingsCache)
deviceBindings Graph
g2) (IntSet
IS.empty, ReachableBindingsCache
forall a. Visited a
MG.none) [Int]
srcs
(Sources -> Sources) -> Grapher ()
modifySources ((Sources -> Sources) -> Grapher ())
-> (Sources -> Sources) -> Grapher ()
forall a b. (a -> b) -> a -> b
$ ([Int] -> [Int]) -> Sources -> Sources
forall (p :: * -> * -> *) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second (IntSet -> [Int]
IS.toList IntSet
dbs [Int] -> [Int] -> [Int]
forall a. Semigroup a => a -> a -> a
<>)
let ops :: IntSet
ops = (Int -> Bool) -> IntSet -> IntSet
IS.filter (Int -> Graph -> Bool
forall m. Int -> Graph m -> Bool
`MG.member` Graph
g0) (BodyStats -> IntSet
bodyOperands BodyStats
stats)
(ReachableBindingsCache
-> Int -> StateT State (Reader Env) ReachableBindingsCache)
-> ReachableBindingsCache -> [Int] -> Grapher ()
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m ()
foldM_ ReachableBindingsCache
-> Int -> StateT State (Reader Env) ReachableBindingsCache
connectOperand ReachableBindingsCache
rbc (IntSet -> [Int]
IS.elems IntSet
ops)
Bool -> Grapher () -> Grapher ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
may_migrate (Grapher () -> Grapher ()) -> Grapher () -> Grapher ()
forall a b. (a -> b) -> a -> b
$ case LoopForm GPU
lform of
ForLoop VName
_ IntType
_ SubExp
n [(LParam GPU, VName)]
_ ->
SubExp -> Grapher IntSet
onlyGraphedScalarSubExp SubExp
n Grapher IntSet -> (IntSet -> Grapher ()) -> Grapher ()
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Edges -> IntSet -> Grapher ()
addEdges (IntSet -> Maybe IntSet -> Edges
ToNodes IntSet
bindings Maybe IntSet
forall a. Maybe a
Nothing)
WhileLoop VName
n
| (Binding
_, Int
_, SubExp
arg, SubExp
_) <- VName -> LoopValue
loopValueFor VName
n ->
SubExp -> Grapher IntSet
onlyGraphedScalarSubExp SubExp
arg Grapher IntSet -> (IntSet -> Grapher ()) -> Grapher ()
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Edges -> IntSet -> Grapher ()
addEdges (IntSet -> Maybe IntSet -> Edges
ToNodes IntSet
bindings Maybe IntSet
forall a. Maybe a
Nothing)
where
subgraphId :: Id
subgraphId :: Int
subgraphId = Binding -> Int
forall a b. (a, b) -> a
fst Binding
b
loopValues :: [LoopValue]
loopValues :: [LoopValue]
loopValues =
let tmp :: [(Binding, (Param DeclType, SubExp), SubExpRes)]
tmp = [Binding]
-> [(Param DeclType, SubExp)]
-> [SubExpRes]
-> [(Binding, (Param DeclType, SubExp), SubExpRes)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 (Binding
b Binding -> [Binding] -> [Binding]
forall a. a -> [a] -> [a]
: [Binding]
bs) [(Param DeclType, SubExp)]
[(FParam GPU, SubExp)]
params (Body GPU -> [SubExpRes]
forall rep. Body rep -> [SubExpRes]
bodyResult Body GPU
body)
tmp' :: [LoopValue]
tmp' = (((Binding, (Param DeclType, SubExp), SubExpRes) -> LoopValue)
-> [(Binding, (Param DeclType, SubExp), SubExpRes)] -> [LoopValue])
-> [(Binding, (Param DeclType, SubExp), SubExpRes)]
-> ((Binding, (Param DeclType, SubExp), SubExpRes) -> LoopValue)
-> [LoopValue]
forall a b c. (a -> b -> c) -> b -> a -> c
flip ((Binding, (Param DeclType, SubExp), SubExpRes) -> LoopValue)
-> [(Binding, (Param DeclType, SubExp), SubExpRes)] -> [LoopValue]
forall a b. (a -> b) -> [a] -> [b]
map [(Binding, (Param DeclType, SubExp), SubExpRes)]
tmp (((Binding, (Param DeclType, SubExp), SubExpRes) -> LoopValue)
-> [LoopValue])
-> ((Binding, (Param DeclType, SubExp), SubExpRes) -> LoopValue)
-> [LoopValue]
forall a b. (a -> b) -> a -> b
$
\(Binding
bnd, (Param DeclType
p, SubExp
arg), SubExpRes
res) ->
let i :: Int
i = VName -> Int
nameToId (Param DeclType -> VName
forall dec. Param dec -> VName
paramName Param DeclType
p)
in (Binding
bnd, Int
i, SubExp
arg, SubExpRes -> SubExp
resSubExp SubExpRes
res)
in (LoopValue -> Bool) -> [LoopValue] -> [LoopValue]
forall a. (a -> Bool) -> [a] -> [a]
filter (\((Int
_, Type
t), Int
_, SubExp
_, SubExp
_) -> Type -> Bool
forall t. Typed t => t -> Bool
isScalar Type
t) [LoopValue]
tmp'
bindings :: IdSet
bindings :: IntSet
bindings = [Int] -> IntSet
IS.fromList ([Int] -> IntSet) -> [Int] -> IntSet
forall a b. (a -> b) -> a -> b
$ (LoopValue -> Int) -> [LoopValue] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map (\((Int
i, Type
_), Int
_, SubExp
_, SubExp
_) -> Int
i) [LoopValue]
loopValues
loopValueFor :: VName -> LoopValue
loopValueFor :: VName -> LoopValue
loopValueFor VName
n =
Maybe LoopValue -> LoopValue
forall a. HasCallStack => Maybe a -> a
fromJust (Maybe LoopValue -> LoopValue) -> Maybe LoopValue -> LoopValue
forall a b. (a -> b) -> a -> b
$ (LoopValue -> Bool) -> [LoopValue] -> Maybe LoopValue
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find (\(Binding
_, Int
p, SubExp
_, SubExp
_) -> Int
p Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== VName -> Int
nameToId VName
n) [LoopValue]
loopValues
graphTheLoop :: Grapher ()
graphTheLoop :: Grapher ()
graphTheLoop = do
(LoopValue -> Grapher ()) -> [LoopValue] -> Grapher ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ LoopValue -> Grapher ()
forall a d. ((a, Type), Int, SubExp, d) -> Grapher ()
graphParam [LoopValue]
loopValues
case LoopForm GPU
lform of
ForLoop VName
_ IntType
_ SubExp
n [(LParam GPU, VName)]
elems -> do
SubExp -> Grapher IntSet
onlyGraphedScalarSubExp SubExp
n Grapher IntSet -> (IntSet -> Grapher ()) -> Grapher ()
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= IntSet -> Grapher ()
tellOperands
((Param Type, VName) -> Grapher ())
-> [(Param Type, VName)] -> Grapher ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Param Type, VName) -> Grapher ()
forall dec. Typed dec => (Param dec, VName) -> Grapher ()
graphForInElem [(Param Type, VName)]
[(LParam GPU, VName)]
elems
WhileLoop VName
_ -> () -> Grapher ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
Body GPU -> Grapher ()
graphBody Body GPU
body
where
graphForInElem :: (Param dec, VName) -> Grapher ()
graphForInElem (Param dec
p, VName
arr) = do
Bool -> Grapher () -> Grapher ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Param dec -> Bool
forall t. Typed t => t -> Bool
isScalar Param dec
p) (Grapher () -> Grapher ()) -> Grapher () -> Grapher ()
forall a b. (a -> b) -> a -> b
$ Binding -> Grapher ()
addSource (VName -> Int
nameToId (VName -> Int) -> VName -> Int
forall a b. (a -> b) -> a -> b
$ Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
p, Param dec -> Type
forall t. Typed t => t -> Type
typeOf Param dec
p)
Bool -> Grapher () -> Grapher ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Param dec -> Bool
forall t. Typed t => t -> Bool
isArray Param dec
p) (Grapher () -> Grapher ()) -> Grapher () -> Grapher ()
forall a b. (a -> b) -> a -> b
$ (VName -> Int
nameToId (Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
p), Param dec -> Type
forall t. Typed t => t -> Type
typeOf Param dec
p) Binding -> VName -> Grapher ()
`reuses` VName
arr
graphParam :: ((a, Type), Int, SubExp, d) -> Grapher ()
graphParam ((a
_, Type
t), Int
p, SubExp
arg, d
_) =
do
Binding -> Grapher ()
addVertex (Int
p, Type
t)
IntSet
ops <- SubExp -> Grapher IntSet
onlyGraphedScalarSubExp SubExp
arg
Edges -> IntSet -> Grapher ()
addEdges (Int -> Edges
MG.oneEdge Int
p) IntSet
ops
mergeLoopParam :: Graph -> LoopValue -> Grapher ()
mergeLoopParam :: Graph -> LoopValue -> Grapher ()
mergeLoopParam Graph
g (Binding
_, Int
p, SubExp
_, SubExp
res)
| Var VName
n <- SubExp
res,
Int
ret <- VName -> Int
nameToId VName
n,
Int
ret Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= Int
p =
if Int -> Graph -> Bool
forall m. Int -> Graph m -> Bool
MG.isSinkConnected Int
p Graph
g
then Int -> Grapher ()
connectToSink Int
ret
else Edges -> IntSet -> Grapher ()
addEdges (Int -> Edges
MG.oneEdge Int
p) (Int -> IntSet
IS.singleton Int
ret)
| Bool
otherwise =
() -> Grapher ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
deviceBindings ::
Graph ->
(ReachableBindings, ReachableBindingsCache) ->
Id ->
(ReachableBindings, ReachableBindingsCache)
deviceBindings :: Graph
-> (IntSet, ReachableBindingsCache)
-> Int
-> (IntSet, ReachableBindingsCache)
deviceBindings Graph
g (IntSet
rb, ReachableBindingsCache
rbc) Int
i =
let (Result IntSet
r, ReachableBindingsCache
rbc') = Graph
-> (IntSet -> EdgeType -> Vertex Meta -> IntSet)
-> ReachableBindingsCache
-> EdgeType
-> Int
-> (Result IntSet, ReachableBindingsCache)
forall a m.
Monoid a =>
Graph m
-> (a -> EdgeType -> Vertex m -> a)
-> Visited (Result a)
-> EdgeType
-> Int
-> (Result a, Visited (Result a))
MG.reduce Graph
g IntSet -> EdgeType -> Vertex Meta -> IntSet
bindingReach ReachableBindingsCache
rbc EdgeType
Normal Int
i
in case Result IntSet
r of
Produced IntSet
rb' -> (IntSet
rb IntSet -> IntSet -> IntSet
forall a. Semigroup a => a -> a -> a
<> IntSet
rb', ReachableBindingsCache
rbc')
Result IntSet
_ ->
String -> (IntSet, ReachableBindingsCache)
forall a. String -> a
compilerBugS
String
"Migration graph sink could be reached from source after it\
\ had been attempted routed."
bindingReach ::
ReachableBindings ->
EdgeType ->
Vertex Meta ->
ReachableBindings
bindingReach :: IntSet -> EdgeType -> Vertex Meta -> IntSet
bindingReach IntSet
rb EdgeType
_ Vertex Meta
v
| Int
i <- Vertex Meta -> Int
forall m. Vertex m -> Int
vertexId Vertex Meta
v,
Int -> IntSet -> Bool
IS.member Int
i IntSet
bindings =
Int -> IntSet -> IntSet
IS.insert Int
i IntSet
rb
| Bool
otherwise =
IntSet
rb
connectOperand ::
ReachableBindingsCache ->
Id ->
Grapher ReachableBindingsCache
connectOperand :: ReachableBindingsCache
-> Int -> StateT State (Reader Env) ReachableBindingsCache
connectOperand ReachableBindingsCache
cache Int
op = do
Graph
g <- Grapher Graph
getGraph
case Int -> Graph -> Maybe (Vertex Meta)
forall m. Int -> Graph m -> Maybe (Vertex m)
MG.lookup Int
op Graph
g of
Maybe (Vertex Meta)
Nothing -> ReachableBindingsCache
-> StateT State (Reader Env) ReachableBindingsCache
forall (f :: * -> *) a. Applicative f => a -> f a
pure ReachableBindingsCache
cache
Just Vertex Meta
v ->
case Vertex Meta -> Edges
forall m. Vertex m -> Edges
vertexEdges Vertex Meta
v of
Edges
ToSink -> ReachableBindingsCache
-> StateT State (Reader Env) ReachableBindingsCache
forall (f :: * -> *) a. Applicative f => a -> f a
pure ReachableBindingsCache
cache
ToNodes IntSet
es Maybe IntSet
Nothing -> Graph
-> ReachableBindingsCache
-> Int
-> IntSet
-> StateT State (Reader Env) ReachableBindingsCache
connectOp Graph
g ReachableBindingsCache
cache Int
op IntSet
es
ToNodes IntSet
_ (Just IntSet
nx) -> Graph
-> ReachableBindingsCache
-> Int
-> IntSet
-> StateT State (Reader Env) ReachableBindingsCache
connectOp Graph
g ReachableBindingsCache
cache Int
op IntSet
nx
where
connectOp ::
Graph ->
ReachableBindingsCache ->
Id ->
IdSet ->
Grapher ReachableBindingsCache
connectOp :: Graph
-> ReachableBindingsCache
-> Int
-> IntSet
-> StateT State (Reader Env) ReachableBindingsCache
connectOp Graph
g ReachableBindingsCache
rbc Int
i IntSet
es = do
let (Result IntSet
res, [Int]
nx, ReachableBindingsCache
rbc') = Graph
-> (IntSet, [Int], ReachableBindingsCache)
-> [Int]
-> (Result IntSet, [Int], ReachableBindingsCache)
findBindings Graph
g (IntSet
IS.empty, [], ReachableBindingsCache
rbc) (IntSet -> [Int]
IS.elems IntSet
es)
case Result IntSet
res of
Result IntSet
FoundSink -> Int -> Grapher ()
connectToSink Int
i
Produced IntSet
rb -> (Graph -> Graph) -> Grapher ()
modifyGraph ((Graph -> Graph) -> Grapher ()) -> (Graph -> Graph) -> Grapher ()
forall a b. (a -> b) -> a -> b
$ (Vertex Meta -> Vertex Meta) -> Int -> Graph -> Graph
forall m. (Vertex m -> Vertex m) -> Int -> Graph m -> Graph m
MG.adjust ([Int] -> IntSet -> Vertex Meta -> Vertex Meta
updateEdges [Int]
nx IntSet
rb) Int
i
ReachableBindingsCache
-> StateT State (Reader Env) ReachableBindingsCache
forall (f :: * -> *) a. Applicative f => a -> f a
pure ReachableBindingsCache
rbc'
updateEdges ::
NonExhausted ->
ReachableBindings ->
Vertex Meta ->
Vertex Meta
updateEdges :: [Int] -> IntSet -> Vertex Meta -> Vertex Meta
updateEdges [Int]
nx IntSet
rb Vertex Meta
v
| ToNodes IntSet
es Maybe IntSet
_ <- Vertex Meta -> Edges
forall m. Vertex m -> Edges
vertexEdges Vertex Meta
v =
let nx' :: IntSet
nx' = [Int] -> IntSet
IS.fromList [Int]
nx
es' :: Edges
es' = IntSet -> Maybe IntSet -> Edges
ToNodes (IntSet
rb IntSet -> IntSet -> IntSet
forall a. Semigroup a => a -> a -> a
<> IntSet
es) (Maybe IntSet -> Edges) -> Maybe IntSet -> Edges
forall a b. (a -> b) -> a -> b
$ IntSet -> Maybe IntSet
forall a. a -> Maybe a
Just (IntSet
rb IntSet -> IntSet -> IntSet
forall a. Semigroup a => a -> a -> a
<> IntSet
nx')
in Vertex Meta
v {vertexEdges :: Edges
vertexEdges = Edges
es'}
| Bool
otherwise = Vertex Meta
v
findBindings ::
Graph ->
(ReachableBindings, NonExhausted, ReachableBindingsCache) ->
[Id] ->
(MG.Result ReachableBindings, NonExhausted, ReachableBindingsCache)
findBindings :: Graph
-> (IntSet, [Int], ReachableBindingsCache)
-> [Int]
-> (Result IntSet, [Int], ReachableBindingsCache)
findBindings Graph
_ (IntSet
rb, [Int]
nx, ReachableBindingsCache
rbc) [] =
(IntSet -> Result IntSet
forall a. a -> Result a
Produced IntSet
rb, [Int]
nx, ReachableBindingsCache
rbc)
findBindings Graph
g (IntSet
rb, [Int]
nx, ReachableBindingsCache
rbc) (Int
i : [Int]
is)
| Just Vertex Meta
v <- Int -> Graph -> Maybe (Vertex Meta)
forall m. Int -> Graph m -> Maybe (Vertex m)
MG.lookup Int
i Graph
g,
Just Int
gid <- Meta -> Maybe Int
metaGraphId (Vertex Meta -> Meta
forall m. Vertex m -> m
vertexMeta Vertex Meta
v),
Int
gid Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
subgraphId
=
let (Result IntSet
res, ReachableBindingsCache
rbc') = Graph
-> (IntSet -> EdgeType -> Vertex Meta -> IntSet)
-> ReachableBindingsCache
-> EdgeType
-> Int
-> (Result IntSet, ReachableBindingsCache)
forall a m.
Monoid a =>
Graph m
-> (a -> EdgeType -> Vertex m -> a)
-> Visited (Result a)
-> EdgeType
-> Int
-> (Result a, Visited (Result a))
MG.reduce Graph
g IntSet -> EdgeType -> Vertex Meta -> IntSet
bindingReach ReachableBindingsCache
rbc EdgeType
Normal Int
i
in case Result IntSet
res of
Result IntSet
FoundSink -> (Result IntSet
forall a. Result a
FoundSink, [], ReachableBindingsCache
rbc')
Produced IntSet
rb' -> Graph
-> (IntSet, [Int], ReachableBindingsCache)
-> [Int]
-> (Result IntSet, [Int], ReachableBindingsCache)
findBindings Graph
g (IntSet
rb IntSet -> IntSet -> IntSet
forall a. Semigroup a => a -> a -> a
<> IntSet
rb', [Int]
nx, ReachableBindingsCache
rbc') [Int]
is
| Bool
otherwise =
Graph
-> (IntSet, [Int], ReachableBindingsCache)
-> [Int]
-> (Result IntSet, [Int], ReachableBindingsCache)
findBindings Graph
g (IntSet
rb, Int
i Int -> [Int] -> [Int]
forall a. a -> [a] -> [a]
: [Int]
nx, ReachableBindingsCache
rbc) [Int]
is
-- | Graph a 'WithAcc' statement.
--
-- Steps, in order:
--
--   1. The lambda body is graphed.
--   2. Each accumulator returned by the lambda is graphed together with the
--      'Delayed' entries stored for it in 'stateUpdateAccs', and that entry
--      is removed from the state. Presumably the entries are filled in while
--      graphing the body — TODO confirm.
--   3. 'reusesReturn' is consulted with the input arrays followed by the
--      extra (non-accumulator) lambda results.
--   4. Every extra result gets a node for the binding it is returned into.
graphWithAcc ::
  [Binding] ->
  [WithAccInput GPU] ->
  Lambda GPU ->
  Grapher ()
graphWithAcc bs inputs f = do
  graphBody body

  -- The first return types of the lambda are the accumulators, one per
  -- input; 'zip' pairs each with the input it stems from.
  mapM_ graphAccInput (zip (lambdaReturnType f) inputs)

  let input_arrays = [Var arr | (_, arrs, _) <- inputs, arr <- arrs]
      extra_res = drop (length inputs) (bodyResult body)
  _ <- reusesReturn bs (input_arrays ++ map resSubExp extra_res)

  -- Assumes the first bindings in 'bs' correspond to the input arrays and
  -- the remaining ones to 'extra_res' (as the original alignment implies).
  operands <- mapM (onlyGraphedScalarSubExp . resSubExp) extra_res
  zipWithM_ createNode (drop (length input_arrays) bs) operands
  where
    body = lambdaBody f

    -- Graph one accumulator together with the updates recorded for it.
    graphAccInput (Acc acc _ elem_types _, (_, _, comb)) = do
      let acc_id = nameToId acc
      pending <- gets (fromMaybe [] . IM.lookup acc_id . stateUpdateAccs)
      modify $ \st ->
        st {stateUpdateAccs = IM.delete acc_id (stateUpdateAccs st)}
      graphAcc acc_id elem_types (fst <$> comb) pending
      -- The neutral elements of the combining operator (if any) are
      -- connected to the sink.
      mapM_ connectSubExpToSink (maybe [] snd comb)
    graphAccInput _ =
      compilerBugS "Type error: WithAcc expression did not return accumulator."
-- | Graph an accumulator, its optional combining operator, and the delayed
-- statements recorded for it.
--
-- Arguments: the accumulator's 'Id', the accumulator element types, the
-- combining operator if any, and the delayed @(binding, expression)@ pairs.
--
-- * With no delayed statements the accumulator is simply added as a source.
-- * Otherwise the operator body is analysed with 'captureBodyStats':
--
--     - if it is host-only or contains a 'GPUBody', the delayed expressions
--       are graphed as host-only and the operator's graphed scalar operands
--       are connected to the sink;
--     - else, if it reads (or any element type is scalar), the delayed
--       bindings are graphed with 'graphAutoMove' and the accumulator
--       becomes a source;
--     - else the accumulator gets a node fed by the operator's operands, and
--       every delayed binding gets a node fed by its own operands plus the
--       accumulator.
graphAcc :: Id -> [Type] -> Maybe (Lambda GPU) -> [Delayed] -> Grapher ()
graphAcc i _ _ [] =
  addSource (i, Prim Unit)
graphAcc i types op delayed = do
  env <- ask
  st <- get
  let no_op = Lambda [] (Body () SQ.empty []) []
      lambda = fromMaybe no_op op
      -- The statistics come from a separate run of the grapher, starting
      -- from the current state; whatever that run changes is discarded.
      stats =
        R.runReader
          (evalStateT (captureBodyStats (graphBody (lambdaBody lambda))) st)
          env
      host_only = bodyHostOnly stats || bodyHasGPUBody stats
      does_read = bodyReads stats || any isScalar types
  ops <- graphedScalarOperands (WithAcc [] lambda)
  if host_only
    then do
      mapM_ (graphHostOnly . snd) delayed
      addEdges ToSink ops
    else
      if does_read
        then do
          mapM_ (graphAutoMove . fst) delayed
          addSource (i, Prim Unit)
        else do
          createNode (i, Prim Unit) ops
          for_ delayed $ \(b, e) ->
            createNode b . IS.insert i =<< graphedScalarOperands e
-- | The operands of an expression that are scalars already present in the
-- graph.
--
-- The expression is walked and every variable it uses is collected. The
-- result is the intersection of those variables with 'getGraphedScalars'.
-- Some parts are deliberately not walked:
--
-- * 'Update' contributes only the free variables of its slice, and
--   'Replicate' only those of its shape.
-- * 'SegOp' kernel bodies and operator lambdas are skipped; only the
--   level, space, neutral elements and histogram race factors are collected.
-- * 'GPUBody' contributes nothing.
-- * The combining operator of a 'WithAcc' input is walked only if an
--   'UpdateAcc' on that accumulator was seen in the lambda body.
graphedScalarOperands :: Exp GPU -> Grapher Operands
graphedScalarOperands e = do
  graphed <- getGraphedScalars
  pure (IS.intersection captured graphed)
  where
    -- Collector state: (ids of used variables, names of updated accumulators).
    captured = fst $ execState (collect e) (IS.empty, S.empty)

    capture n = modify $ first $ IS.insert (nameToId n)
    markAcc a = modify $ second $ S.insert a
    captureFree x = mapM_ capture (namesToList $ freeIn x)

    captureSubExp (Var n) = capture n
    captureSubExp _ = pure ()

    collect expr = case expr of
      BasicOp (Update _ _ slice _) -> captureFree slice
      BasicOp (Replicate shape _) -> captureFree shape
      BasicOp _ ->
        walkExpM (identityWalker {walkOnSubExp = captureSubExp}) expr
      Apply _ args _ _ ->
        mapM_ (captureSubExp . fst) args
      If c tbranch fbranch _ -> do
        captureSubExp c
        collectBody tbranch
        collectBody fbranch
      DoLoop merge form loop_body -> do
        mapM_ (captureSubExp . snd) merge
        collectLoopForm form
        collectBody loop_body
      WithAcc accs lam ->
        collectWithAcc accs lam
      Op op ->
        collectHostOp op

    collectBody b = do
      mapM_ collectStm (bodyStms b)
      captureFree (bodyResult b)

    -- A single-element 'UpdateAcc' pattern of accumulator type marks that
    -- accumulator as updated before its operands are collected.
    collectStm (Let pat _ expr)
      | BasicOp UpdateAcc {} <- expr,
        Pat [pe] <- pat,
        Acc acc _ _ _ <- typeOf pe =
          markAcc acc >> collect expr
      | otherwise =
          collect expr

    collectLoopForm (ForLoop _ _ bound _) = captureSubExp bound
    collectLoopForm (WhileLoop _) = pure ()

    collectWithAcc inputs lam = do
      collectBody (lambdaBody lam)
      updated <- gets snd
      let acc_types = take (length inputs) (lambdaReturnType lam)
          isUpdated (Acc acc _ _ _) = S.member acc updated
      zipWithM_ collectAccOp (map isUpdated acc_types) inputs

    collectAccOp _ (_, _, Nothing) = pure ()
    collectAccOp is_used (_, _, Just (op, nes)) = do
      mapM_ captureSubExp nes
      when is_used $ collectBody (lambdaBody op)

    collectHostOp (SegOp segop) = case segop of
      SegMap lvl space _ _ ->
        collectLevelSpace lvl space
      SegRed lvl space ops _ _ ->
        collectLevelSpace lvl space >> mapM_ collectSegBinOp ops
      SegScan lvl space ops _ _ ->
        collectLevelSpace lvl space >> mapM_ collectSegBinOp ops
      SegHist lvl space ops _ _ ->
        collectLevelSpace lvl space >> mapM_ collectHistOp ops
    collectHostOp (SizeOp op) = captureFree op
    collectHostOp (OtherOp op) = captureFree op
    collectHostOp GPUBody {} = pure ()

    collectLevelSpace lvl space = do
      collectSegLevel lvl
      mapM_ captureSubExp (segSpaceDims space)

    collectSegLevel (SegThread (Count num) (Count size) _) =
      captureLaunch num size
    collectSegLevel (SegGroup (Count num) (Count size) _) =
      captureLaunch num size

    captureLaunch num size = captureSubExp num >> captureSubExp size

    collectSegBinOp (SegBinOp _ _ nes _) = mapM_ captureSubExp nes

    collectHistOp (HistOp _ rf _ nes _ _) = do
      captureSubExp rf
      mapM_ captureSubExp nes
-- | Create a vertex for a binding and connect the given operands to it with
-- @addEdges (MG.oneEdge id)@.
--
-- When the operand set is empty, nothing is added to the graph.
createNode :: Binding -> Operands -> Grapher ()
createNode b ops
  | IS.null ops = pure ()
  | otherwise = do
      addVertex b
      addEdges (MG.oneEdge (fst b)) ops
-- | Insert a vertex for a binding into the graph, tagged with the current
-- 'Meta'.
--
-- Scalar bindings are also added to the graphed-scalar set. Array bindings
-- are reported to 'recordCopyableMemory' together with the current body
-- depth.
addVertex :: Binding -> Grapher ()
addVertex (i, t) = do
  meta <- getMeta
  when (isScalar t) $
    modifyGraphedScalars (IS.insert i)
  when (isArray t) $
    recordCopyableMemory i (metaBodyDepth meta)
  modifyGraph $ MG.insert (MG.vertex i meta)
-- | Add a vertex for the binding and register its id as an unrouted source
-- (prepended to the second component of the sources).
addSource :: Binding -> Grapher ()
addSource b = do
  addVertex b
  let source_id = fst b
  modifySources $ second (source_id :)
-- | Apply the edges @es@ to every vertex in @is@, in ascending id order.
--
-- For 'ToSink', each vertex is connected to the sink and removed from the
-- graphed scalars. These vertices are not reported as operands. For any other
-- 'Edges' value, @is@ is additionally reported via 'tellOperands'.
addEdges :: Edges -> IdSet -> Grapher ()
addEdges ToSink is = do
  modifyGraph $ \g -> IS.foldl' (\acc i -> MG.connectToSink i acc) g is
  modifyGraphedScalars (\gss -> gss `IS.difference` is)
addEdges es is = do
  modifyGraph $ \g -> IS.foldl' (\acc i -> MG.addEdges es i acc) g is
  tellOperands is
-- | Mark vertex @i@ as required on host, if it has been graphed.
--
-- The vertex is connected to the sink, and the body depth stored in its
-- 'Meta' is reported through 'tellHostOnlyParent'. Ids without a vertex are
-- ignored.
requiredOnHost :: Id -> Grapher ()
requiredOnHost i = do
  g <- getGraph
  forM_ (MG.lookup i g) $ \v -> do
    connectToSink i
    tellHostOnlyParent (metaBodyDepth (vertexMeta v))
-- | Connect vertex @i@ to the sink and drop it from the graphed scalars.
connectToSink :: Id -> Grapher ()
connectToSink i =
  modifyGraph (MG.connectToSink i)
    >> modifyGraphedScalars (IS.delete i)
-- | Connect the variable named by a 'SubExp' to the sink. Constants are
-- ignored.
connectSubExpToSink :: SubExp -> Grapher ()
connectSubExpToSink se = case se of
  Var n -> connectToSink (nameToId n)
  Constant _ -> pure ()
-- | Route the sources of the subgraph with id @si@ and return them.
--
-- Only the leading run of unrouted sources whose vertices carry graph id
-- @si@ is selected. These are routed with 'MG.routeMany' and moved to the
-- front of the routed list. The sink ids it returns are prepended to the
-- recorded sinks.
routeSubgraph :: Id -> Grapher [Id]
routeSubgraph si = state route
  where
    route st =
      let g = stateGraph st
          (routed, unrouted) = stateSources st
          (subgraph_srcs, still_unrouted) = span (inSubGraph si g) unrouted
          (new_sinks, g') = MG.routeMany subgraph_srcs g
       in ( subgraph_srcs,
            st
              { stateGraph = g',
                stateSources = (subgraph_srcs ++ routed, still_unrouted),
                stateSinks = new_sinks ++ stateSinks st
              }
          )
-- | Whether vertex @i@ exists in @g@ and its 'Meta' carries the graph id
-- @si@. Missing vertices and vertices without a graph id yield 'False'.
inSubGraph :: Id -> Graph -> Id -> Bool
inSubGraph si g i =
  Just si == (MG.lookup i g >>= metaGraphId . vertexMeta)
-- | If @t@ is an array type and @n@ has copyable memory recorded at some body
-- depth, record the same depth for @i@. Otherwise do nothing.
reuses :: Binding -> VName -> Grapher ()
reuses (i, t) n =
  when (isArray t) $
    outermostCopyableArray n >>= mapM_ (recordCopyableMemory i)
-- | 'reuses' lifted to 'SubExp'. Constants carry no memory and are ignored.
reusesSubExp :: Binding -> SubExp -> Grapher ()
reusesSubExp b se = case se of
  Var n -> reuses b n
  Constant _ -> pure ()
-- | Propagate copyable memory from the results @res@ to the bindings @bs@
-- (paired positionally). Returns whether none of the returned arrays is a
-- free variable of the current body.
--
-- Only array bindings with at least one dimension other than the constant 1,
-- whose result is a variable, are examined:
--
-- * If the variable has copyable memory recorded at depth @inner@, the
--   binding is recorded at @min body_depth inner@. The answer becomes 'False'
--   when @inner@ is not greater than the current body depth.
--
-- * If the variable has no recorded copyable memory, the answer is 'False'.
--
-- All other bindings leave the answer unchanged.
reusesReturn :: [Binding] -> [SubExp] -> Grapher Bool
reusesReturn bs res = do
  body_depth <- metaBodyDepth <$> getMeta
  foldM (step body_depth) True (zip bs res)
  where
    step :: BodyDepth -> Bool -> (Binding, SubExp) -> Grapher Bool
    step body_depth acc ((i, t), se)
      | isArray t,
        any (/= intConst Int64 1) (arrayDims t),
        Var n <- se = do
          found <- outermostCopyableArray n
          case found of
            Nothing -> pure False
            Just inner -> do
              recordCopyableMemory i (min body_depth inner)
              -- Memory recorded at this depth or shallower means the body
              -- returns a free variable.
              pure (acc && inner > body_depth)
      | otherwise = pure acc
-- | Propagate copyable memory to the bindings @bs@ from two parallel result
-- lists @b1@ and @b2@ (zipped positionally). Returns whether none of the
-- returned arrays is a free variable of the current body.
--
-- Only array bindings with at least one dimension other than the constant 1,
-- whose two results are both variables, are examined:
--
-- * If both variables have recorded copyable memory, let @inner@ be the
--   smaller of their depths. The binding is recorded at
--   @min body_depth inner@, and the answer becomes 'False' when @inner@ is
--   not greater than the current body depth.
--
-- * If either variable has no recorded copyable memory, the answer is
--   'False'.
--
-- All other bindings leave the answer unchanged.
reusesBranches :: [Binding] -> [SubExp] -> [SubExp] -> Grapher Bool
reusesBranches bs b1 b2 = do
  body_depth <- metaBodyDepth <$> getMeta
  foldM (step body_depth) True (zip3 bs b1 b2)
  where
    step :: BodyDepth -> Bool -> (Binding, SubExp, SubExp) -> Grapher Bool
    step body_depth acc ((i, t), se1, se2)
      | isArray t,
        any (/= intConst Int64 1) (arrayDims t),
        Var n1 <- se1,
        Var n2 <- se2 = do
          depth_1 <- outermostCopyableArray n1
          depth_2 <- outermostCopyableArray n2
          case min <$> depth_1 <*> depth_2 of
            Nothing -> pure False
            Just inner -> do
              recordCopyableMemory i (min body_depth inner)
              -- Memory recorded at this depth or shallower means a branch
              -- returns a free variable.
              pure (acc && inner > body_depth)
      | otherwise = pure acc
-- | The monad in which the graph is built: a mutable 'State' layered over a
-- read-only 'Env'.
type Grapher = StateT State (R.Reader Env)
-- | Read-only environment of the 'Grapher' monad.
data Env = Env
  { -- | Names of the functions considered host-only; queried by
    -- 'isHostOnlyFun'.
    envHostOnlyFuns :: HostOnlyFuns,
    -- | Metadata given to vertices created in the current scope (see
    -- 'addVertex'); adjusted via 'local'.
    envMeta :: Meta
  }
-- | A body nesting depth, as tracked by 'metaBodyDepth' and incremented by
-- 'incBodyDepthFor'.
type BodyDepth = Int
-- | Metadata attached to each vertex when 'addVertex' creates it, taken from
-- the 'Env' in effect at that point.
data Meta = Meta
  { -- | Incremented by 'incForkDepthFor'.
    metaForkDepth :: Int,
    -- | Incremented by 'incBodyDepthFor'; the depth at which array memory is
    -- recorded as copyable.
    metaBodyDepth :: BodyDepth,
    -- | Set by 'graphIdFor'. 'routeSubgraph' compares it (via 'inSubGraph')
    -- to pick out the sources of a subgraph.
    metaGraphId :: Maybe Id
  }
-- | Ids of vertices that act as operands; see 'createNode' and
-- 'tellOperands'.
type Operands = IdSet
-- | Statistics gathered while graphing a body; see 'captureBodyStats'.
--
-- The fields are strict. They are updated incrementally by the @tell*@
-- functions and combined with '(<>)'. Lazy fields would otherwise
-- accumulate chains of unevaluated @||@ and set unions across a long body.
data BodyStats = BodyStats
  { -- | Set by 'tellHostOnly'.
    bodyHostOnly :: !Bool,
    -- | Set by 'tellGPUBody'.
    bodyHasGPUBody :: !Bool,
    -- | Set by 'tellRead'.
    bodyReads :: !Bool,
    -- | Ids reported through 'tellOperands', which 'addEdges' calls for every
    -- edge kind except 'ToSink'.
    bodyOperands :: !Operands,
    -- | Body depths reported through 'tellHostOnlyParent'. 'requiredOnHost'
    -- reports the depth stored in the vertex's 'Meta'.
    bodyHostOnlyParents :: !IS.IntSet
  }
-- | Field-wise combination: the flags are or-ed together and the id sets are
-- unioned.
instance Semigroup BodyStats where
  (BodyStats host_only_a gpu_body_a reads_a operands_a parents_a)
    <> (BodyStats host_only_b gpu_body_b reads_b operands_b parents_b) =
      BodyStats
        (host_only_a || host_only_b)
        (gpu_body_a || gpu_body_b)
        (reads_a || reads_b)
        (operands_a `IS.union` operands_b)
        (parents_a `IS.union` parents_b)
-- | No flags set and no ids recorded; the identity of '(<>)'.
instance Monoid BodyStats where
  mempty = BodyStats False False False IS.empty IS.empty
-- | The graph built by the 'Grapher'; each vertex carries a 'Meta'.
type Graph = MG.Graph Meta
-- | Source vertex ids as a pair of (routed, unrouted) lists; see 'addSource'
-- and 'routeSubgraph'.
type Sources = ([Id], [Id])
-- | Sink vertex ids, as collected from 'MG.routeMany' by 'routeSubgraph'.
type Sinks = [Id]
-- | A binding paired with a GPU expression; stored in 'stateUpdateAccs'.
-- NOTE(review): how these entries are consumed is not visible here.
type Delayed = (Binding, Exp GPU)
-- | A vertex id paired with the type of the bound value.
type Binding = (Id, Type)
-- | The body depth recorded for each id whose memory is tracked as copyable;
-- see 'recordCopyableMemory'.
type CopyableMemoryMap = IM.IntMap BodyDepth
-- | Mutable state of the 'Grapher' monad.
data State = State
  { -- | The graph built so far.
    stateGraph :: Graph,
    -- | Ids of scalar vertices added by 'addVertex' that have not since been
    -- connected to the sink.
    stateGraphedScalars :: IdSet,
    -- | Source ids as (routed, unrouted). 'addSource' prepends to the
    -- unrouted list and 'routeSubgraph' moves entries to the routed list.
    stateSources :: Sources,
    -- | Sink ids accumulated by 'routeSubgraph'.
    stateSinks :: Sinks,
    -- | Delayed bindings keyed by id. NOTE(review): not read or written in
    -- this part of the module; semantics to be confirmed.
    stateUpdateAccs :: IM.IntMap [Delayed],
    -- | Body depth recorded for each id with copyable memory.
    stateCopyableMemory :: CopyableMemoryMap,
    -- | Statistics for the body currently being graphed; see
    -- 'captureBodyStats'.
    stateStats :: BodyStats
  }
-- | Run a 'Grapher' computation and return the final graph, sources and
-- sinks. The computation's own result is discarded.
--
-- The run starts from an empty state and an environment holding the given
-- host-only functions, with fork and body depths 0 and no graph id.
execGrapher :: HostOnlyFuns -> Grapher a -> (Graph, Sources, Sinks)
execGrapher hof m =
  (stateGraph final, stateSources final, stateSinks final)
  where
    final = R.runReader (execStateT m initial) env0
    env0 =
      Env
        { envHostOnlyFuns = hof,
          envMeta =
            Meta
              { metaForkDepth = 0,
                metaBodyDepth = 0,
                metaGraphId = Nothing
              }
        }
    initial =
      State
        { stateGraph = MG.empty,
          stateGraphedScalars = IS.empty,
          stateSources = ([], []),
          stateSinks = [],
          stateUpdateAccs = IM.empty,
          stateCopyableMemory = IM.empty,
          stateStats = mempty
        }
-- | Run @m@ with the environment transformed by @f@. Only the environment
-- seen by @m@ changes; state updates made by @m@ are kept.
local :: (Env -> Env) -> Grapher a -> Grapher a
local f m = mapStateT (R.local f) m
-- | Retrieve the whole environment.
ask :: Grapher Env
ask = asks id
-- | Retrieve a projection of the environment.
asks :: (Env -> a) -> Grapher a
asks f = lift (R.asks f)
-- | Set 'bodyHostOnly' in the statistics of the current body.
tellHostOnly :: Grapher ()
tellHostOnly = modify setFlag
  where
    setFlag st =
      let stats = stateStats st
       in st {stateStats = stats {bodyHostOnly = True}}
-- | Set 'bodyHasGPUBody' in the statistics of the current body.
tellGPUBody :: Grapher ()
tellGPUBody = modify setFlag
  where
    setFlag st =
      let stats = stateStats st
       in st {stateStats = stats {bodyHasGPUBody = True}}
-- | Set 'bodyReads' in the statistics of the current body.
tellRead :: Grapher ()
tellRead = modify setFlag
  where
    setFlag st =
      let stats = stateStats st
       in st {stateStats = stats {bodyReads = True}}
-- | Add @is@ to the operands ('bodyOperands') of the current body.
tellOperands :: IdSet -> Grapher ()
tellOperands is = modify addOperands
  where
    addOperands st =
      let stats = stateStats st
          operands' = bodyOperands stats `IS.union` is
       in st {stateStats = stats {bodyOperands = operands'}}
-- | Record @body_depth@ in 'bodyHostOnlyParents' of the current body.
tellHostOnlyParent :: BodyDepth -> Grapher ()
tellHostOnlyParent body_depth = modify $ \st ->
  let stats = stateStats st
      parents' = IS.insert body_depth (bodyHostOnlyParents stats)
   in st {stateStats = stats {bodyHostOnlyParents = parents'}}
-- | The graph built so far.
getGraph :: Grapher Graph
getGraph = stateGraph <$> get
-- | Ids of scalar vertices not yet connected to the sink.
getGraphedScalars :: Grapher IdSet
getGraphedScalars = stateGraphedScalars <$> get
-- | The map of ids with copyable memory to their recorded body depths.
getCopyableMemory :: Grapher CopyableMemoryMap
getCopyableMemory = stateCopyableMemory <$> get
-- | The body depth recorded for the copyable memory of @n@, if any.
outermostCopyableArray :: VName -> Grapher (Maybe BodyDepth)
outermostCopyableArray n = do
  copyable <- getCopyableMemory
  pure (IM.lookup (nameToId n) copyable)
-- | The ids of those names in @vs@ that are currently graphed scalars.
onlyGraphedScalars :: Foldable t => t VName -> Grapher IdSet
onlyGraphedScalars vs = do
  graphed <- getGraphedScalars
  let ids = IS.fromList (map nameToId (toList vs))
  pure (ids `IS.intersection` graphed)
-- | A singleton set with the id of @n@ if it is a graphed scalar, otherwise
-- the empty set.
onlyGraphedScalar :: VName -> Grapher IdSet
onlyGraphedScalar n = do
  let i = nameToId n
  is_graphed <- IS.member i <$> getGraphedScalars
  pure $ if is_graphed then IS.singleton i else IS.empty
-- | 'onlyGraphedScalar' lifted to 'SubExp'. Constants yield the empty set.
onlyGraphedScalarSubExp :: SubExp -> Grapher IdSet
onlyGraphedScalarSubExp se = case se of
  Var n -> onlyGraphedScalar n
  Constant _ -> pure IS.empty
-- | Apply @f@ to the graph held in the state.
modifyGraph :: (Graph -> Graph) -> Grapher ()
modifyGraph f = modify update
  where
    update st = st {stateGraph = f (stateGraph st)}
-- | Apply @f@ to the set of graphed scalars.
modifyGraphedScalars :: (IdSet -> IdSet) -> Grapher ()
modifyGraphedScalars f = modify update
  where
    update st = st {stateGraphedScalars = f (stateGraphedScalars st)}
-- | Apply @f@ to the copyable-memory map.
modifyCopyableMemory :: (CopyableMemoryMap -> CopyableMemoryMap) -> Grapher ()
modifyCopyableMemory f = modify update
  where
    update st = st {stateCopyableMemory = f (stateCopyableMemory st)}
-- | Apply @f@ to the (routed, unrouted) source lists.
modifySources :: (Sources -> Sources) -> Grapher ()
modifySources f = modify update
  where
    update st = st {stateSources = f (stateSources st)}
-- | Record (or overwrite) body depth @bd@ as the copyable-memory depth of @i@.
recordCopyableMemory :: Id -> BodyDepth -> Grapher ()
recordCopyableMemory i bd =
  modifyCopyableMemory (\copyable -> IM.insert i bd copyable)
-- | Run a computation with 'metaForkDepth' one higher than the current one.
incForkDepthFor :: Grapher a -> Grapher a
incForkDepthFor = local bump
  where
    bump env =
      let meta = envMeta env
       in env {envMeta = meta {metaForkDepth = metaForkDepth meta + 1}}
-- | Run a computation with 'metaBodyDepth' one higher than the current one.
incBodyDepthFor :: Grapher a -> Grapher a
incBodyDepthFor = local bump
  where
    bump env =
      let meta = envMeta env
       in env {envMeta = meta {metaBodyDepth = metaBodyDepth meta + 1}}
-- | Run a computation with 'metaGraphId' set to @i@, so that vertices it
-- creates are tagged as belonging to subgraph @i@.
graphIdFor :: Id -> Grapher a -> Grapher a
graphIdFor i = local setGraphId
  where
    setGraphId env =
      let meta = envMeta env
       in env {envMeta = meta {metaGraphId = Just i}}
-- | Run @m@ starting from empty body statistics and return the statistics it
-- produced. The result of @m@ is discarded.
--
-- Afterwards the captured statistics are merged ('<>') into those that were
-- in effect before @m@ ran, so the enclosing body still observes them.
captureBodyStats :: Grapher a -> Grapher BodyStats
captureBodyStats m = do
  outer <- gets stateStats
  setStats mempty
  _ <- m
  inner <- gets stateStats
  setStats (outer <> inner)
  pure inner
  where
    setStats :: BodyStats -> Grapher ()
    setStats stats = modify $ \st -> st {stateStats = stats}
-- | Whether the function named @fn@ is among the environment's host-only
-- functions.
isHostOnlyFun :: Name -> Grapher Bool
isHostOnlyFun fn = S.member fn <$> asks envHostOnlyFuns
-- | The 'Meta' currently in effect.
getMeta :: Grapher Meta
getMeta = envMeta <$> ask
-- | The body depth currently in effect.
getBodyDepth :: Grapher BodyDepth
getBodyDepth = metaBodyDepth <$> getMeta