module Futhark.Optimise.ReduceDeviceSyncs (reduceDeviceSyncs) where
import Control.Monad
import Control.Monad.Trans.Class
import qualified Control.Monad.Trans.Reader as R
import Control.Monad.Trans.State.Strict hiding (State)
import Control.Parallel.Strategies (parMap, rpar)
import Data.Foldable
import qualified Data.IntMap.Strict as IM
import Data.List (unzip4, zip4)
import qualified Data.Map.Strict as M
import Data.Sequence ((<|), (><), (|>))
import qualified Data.Text as T
import Futhark.Construct (fullSlice, sliceDim)
import Futhark.Error
import qualified Futhark.FreshNames as FN
import Futhark.IR.GPU
import Futhark.MonadFreshNames (VNameSource, getNameSource, putNameSource)
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable
import Futhark.Pass
import Futhark.Transform.Substitute
-- | The "reduce device synchronizations" pass.  Per its own description it
-- moves host statements to the device in order to reduce blocking memory
-- operations.
--
-- The pass first builds a 'MigrationTable' for the whole program with
-- 'analyseProg', then rewrites the program with 'optimizeProgram'.  The
-- rewrite runs in a state monad (seeded from the current name source via
-- 'initialState') layered over a reader that exposes the migration table.
reduceDeviceSyncs :: Pass GPU GPU
reduceDeviceSyncs =
  Pass
    "reduce device synchronizations"
    "Move host statements to device to reduce blocking memory operations."
    run
  where
    run :: Prog GPU -> PassM (Prog GPU)
    run prog = do
      ns <- getNameSource
      let mt = analyseProg prog
          st = initialState ns
          (prog', st') = R.runReader (runStateT (optimizeProgram prog) st) mt
      -- Write back the name source held by the final rewrite state so the
      -- names it handed out are accounted for by the enclosing monad.
      putNameSource (stateNameSource st')
      pure prog'
-- | Optimize a whole program: first its constants (via 'optimizeStms'),
-- then every function definition (via 'optimizeFunDef').
--
-- NOTE(review): @parMap rpar@ is applied to the monadic /actions/, which
-- are subsequently run in list order by 'sequence' with the state threaded
-- from one to the next.  Whether this yields any real parallelism should be
-- confirmed; behaviour is kept exactly as in the original.
optimizeProgram :: Prog GPU -> ReduceM (Prog GPU)
optimizeProgram (Prog consts funs) = do
  consts' <- optimizeStms consts
  funs' <- sequence $ parMap rpar optimizeFunDef funs
  pure (Prog consts' funs')
-- | Optimize a function definition by running 'optimizeStms' over the
-- statements of its body.  Only the body statements are replaced; every
-- other field of the definition, and the body result, is kept as it was
-- (unlike 'optimizeBody', no 'resolveResult' is applied here).
optimizeFunDef :: FunDef GPU -> ReduceM (FunDef GPU)
optimizeFunDef fd = do
  let body = funDefBody fd
  stms' <- optimizeStms (bodyStms body)
  pure $ fd {funDefBody = body {bodyStms = stms'}}
-- | Optimize a body: its statements are rewritten with 'optimizeStms' and
-- its result is then passed through 'resolveResult'.  The result is
-- resolved /after/ the statements have been processed, so it observes any
-- state changes made while rewriting them.  The incoming body decoration is
-- discarded and the new body is built with @()@.
optimizeBody :: Body GPU -> ReduceM (Body GPU)
optimizeBody (Body _ stms res) = do
  stms' <- optimizeStms stms
  res' <- resolveResult res
  pure (Body () stms' res')
-- | Optimize a sequence of statements.  The statements are folded left to
-- right with 'optimizeStm', starting from an empty output sequence; each
-- step receives the statements emitted so far and returns the extended
-- output.
optimizeStms :: Stms GPU -> ReduceM (Stms GPU)
optimizeStms = foldM optimizeStm mempty
optimizeStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
optimizeStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
optimizeStm Stms GPU
out Stm GPU
stm = do
Bool
move <- (MigrationTable -> Bool) -> ReduceM Bool
forall a. (MigrationTable -> a) -> ReduceM a
asks (Stm GPU -> MigrationTable -> Bool
shouldMoveStm Stm GPU
stm)
if Bool
move
then Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
moveStm Stms GPU
out Stm GPU
stm
else case Stm GPU -> Exp GPU
forall rep. Stm rep -> Exp rep
stmExp Stm GPU
stm of
BasicOp (Update Safety
safety VName
arr Slice SubExp
slice (Var VName
v))
| Just [SubExp]
_ <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
slice -> do
Maybe VName
dev <- SubExp -> ReduceM (Maybe VName)
storedScalar (VName -> SubExp
Var VName
v)
case Maybe VName
dev of
Maybe VName
Nothing -> Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm)
Just VName
dst -> do
let dims :: [DimIndex SubExp]
dims = Slice SubExp -> [DimIndex SubExp]
forall d. Slice d -> [DimIndex d]
unSlice Slice SubExp
slice
let ([DimIndex SubExp]
outer, [DimFix SubExp
i]) = Int -> [DimIndex SubExp] -> ([DimIndex SubExp], [DimIndex SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt ([DimIndex SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex SubExp]
dims Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1) [DimIndex SubExp]
dims
let one :: SubExp
one = IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1
let slice' :: Slice SubExp
slice' = [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp]
outer [DimIndex SubExp] -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. [a] -> [a] -> [a]
++ [SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice SubExp
i SubExp
one SubExp
one]
let e :: Exp rep
e = BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update Safety
safety VName
arr Slice SubExp
slice' (VName -> SubExp
Var VName
dst))
let stm' :: Stm GPU
stm' = Stm GPU
stm {stmExp :: Exp GPU
stmExp = Exp GPU
forall rep. Exp rep
e}
Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm')
BasicOp (Replicate (Shape [SubExp]
dims) (Var VName
v))
| Pat [PatElem VName
n LetDec GPU
arr_t] <- Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm -> do
VName
v' <- VName -> ReduceM VName
resolveName VName
v
let v_kept_on_device :: Bool
v_kept_on_device = VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
/= VName
v'
Bool
gpubody_ok <- (State -> Bool) -> ReduceM Bool
forall (m :: * -> *) s a. Monad m => (s -> a) -> StateT s m a
gets State -> Bool
stateGPUBodyOk
case Bool
v_kept_on_device of
Bool
False -> Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm)
Bool
True
| (SubExp -> Bool) -> [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) [SubExp]
dims,
Just TypeBase (ShapeBase SubExp) NoUniqueness
t' <- Int
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> Maybe (TypeBase (ShapeBase SubExp) NoUniqueness)
forall u.
Int
-> TypeBase (ShapeBase SubExp) u
-> Maybe (TypeBase (ShapeBase SubExp) u)
peelArray Int
1 TypeBase (ShapeBase SubExp) NoUniqueness
LetDec GPU
arr_t,
Bool
gpubody_ok -> do
let n' :: VName
n' = Name -> Int -> VName
VName (VName -> Name
baseName VName
n Name -> String -> Name
`withSuffix` String
"_inner") Int
0
let pat' :: Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
pat' = [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
forall dec. [PatElem dec] -> Pat dec
Pat [VName
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
forall dec. VName -> dec -> PatElem dec
PatElem VName
n' TypeBase (ShapeBase SubExp) NoUniqueness
t']
let e' :: Exp rep
e' = BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> SubExp -> BasicOp
Replicate ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> ShapeBase SubExp) -> [SubExp] -> ShapeBase SubExp
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp]
forall a. [a] -> [a]
tail [SubExp]
dims) (VName -> SubExp
Var VName
v)
let stm' :: Stm GPU
stm' = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
Pat (LetDec GPU)
pat' (Stm GPU -> StmAux (ExpDec GPU)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
forall rep. Exp rep
e'
Stm GPU
gpubody <- RewriteM (Stm GPU) -> ReduceM (Stm GPU)
inGPUBody (Stm GPU -> RewriteM (Stm GPU)
rewriteStm Stm GPU
stm')
Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
gpubody {stmPat :: Pat (LetDec GPU)
stmPat = Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm})
Bool
True
| [SubExp] -> SubExp
forall a. [a] -> a
last [SubExp]
dims SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1 ->
let e' :: Exp rep
e' = BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> SubExp -> BasicOp
Replicate ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> ShapeBase SubExp) -> [SubExp] -> ShapeBase SubExp
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp]
forall a. [a] -> [a]
init [SubExp]
dims) (VName -> SubExp
Var VName
v')
stm' :: Stm GPU
stm' = Stm GPU
stm {stmExp :: Exp GPU
stmExp = Exp GPU
forall rep. Exp rep
e'}
in Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm')
Bool
True -> do
VName
n' <- VName -> ReduceM VName
newName VName
n
let dims' :: [SubExp]
dims' = [SubExp]
dims [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ [IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1]
let arr_t' :: TypeBase (ShapeBase SubExp) NoUniqueness
arr_t' = PrimType
-> ShapeBase SubExp
-> NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array (TypeBase (ShapeBase SubExp) NoUniqueness -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType TypeBase (ShapeBase SubExp) NoUniqueness
LetDec GPU
arr_t) ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp]
dims') NoUniqueness
NoUniqueness
let pat' :: Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
pat' = [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
forall dec. [PatElem dec] -> Pat dec
Pat [VName
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
forall dec. VName -> dec -> PatElem dec
PatElem VName
n' TypeBase (ShapeBase SubExp) NoUniqueness
arr_t']
let e' :: Exp rep
e' = BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> SubExp -> BasicOp
Replicate ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp]
dims) (VName -> SubExp
Var VName
v')
let repl :: Stm GPU
repl = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
Pat (LetDec GPU)
pat' (Stm GPU -> StmAux (ExpDec GPU)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
forall rep. Exp rep
e'
let aux :: StmAux ()
aux = Certs -> Attrs -> () -> StmAux ()
forall dec. Certs -> Attrs -> dec -> StmAux dec
StmAux Certs
forall a. Monoid a => a
mempty Attrs
forall a. Monoid a => a
mempty ()
let slice :: [DimIndex SubExp]
slice = (SubExp -> DimIndex SubExp) -> [SubExp] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
sliceDim (TypeBase (ShapeBase SubExp) NoUniqueness -> [SubExp]
forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims TypeBase (ShapeBase SubExp) NoUniqueness
LetDec GPU
arr_t)
let slice' :: [DimIndex SubExp]
slice' = [DimIndex SubExp]
slice [DimIndex SubExp] -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. [a] -> [a] -> [a]
++ [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0]
let idx :: Exp rep
idx = BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
n' ([DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
slice')
let index :: Stm GPU
index = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let (Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm) StmAux ()
StmAux (ExpDec GPU)
aux Exp GPU
forall rep. Exp rep
idx
Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
repl Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
index)
BasicOp {} ->
Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm)
Apply {} ->
Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm)
If SubExp
cond (Body BodyDec GPU
_ Stms GPU
tstms0 Result
tres) (Body BodyDec GPU
_ Stms GPU
fstms0 Result
fres) (IfDec [BranchType GPU]
btypes IfSort
sort) ->
do
Stms GPU
tstms1 <- Stms GPU -> ReduceM (Stms GPU)
optimizeStms Stms GPU
tstms0
Stms GPU
fstms1 <- Stms GPU -> ReduceM (Stms GPU)
optimizeStms Stms GPU
fstms0
let bmerge :: ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)],
Stms GPU, Stms GPU)
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)
-> StateT
State
(Reader MigrationTable)
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)],
Stms GPU, Stms GPU)
bmerge ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)]
res, Stms GPU
tstms, Stms GPU
fstms) (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe, SubExpRes
tr, SubExpRes
fr, TypeBase ExtShape NoUniqueness
bt) =
do
let onHost :: SubExp -> ReduceM Bool
onHost (Var VName
v) = (VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
==) (VName -> Bool) -> ReduceM VName -> ReduceM Bool
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ReduceM VName
resolveName VName
v
onHost SubExp
_ = Bool -> ReduceM Bool
forall (f :: * -> *) a. Applicative f => a -> f a
pure Bool
True
Bool
tr_on_host <- SubExp -> ReduceM Bool
onHost (SubExpRes -> SubExp
resSubExp SubExpRes
tr)
Bool
fr_on_host <- SubExp -> ReduceM Bool
onHost (SubExpRes -> SubExp
resSubExp SubExpRes
fr)
if Bool
tr_on_host Bool -> Bool -> Bool
&& Bool
fr_on_host
then
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)],
Stms GPU, Stms GPU)
-> StateT
State
(Reader MigrationTable)
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)],
Stms GPU, Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe, SubExpRes
tr, SubExpRes
fr, TypeBase ExtShape NoUniqueness
bt) (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)]
forall a. a -> [a] -> [a]
: [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)]
res, Stms GPU
tstms, Stms GPU
fstms)
else
do
let t :: TypeBase (ShapeBase SubExp) NoUniqueness
t = PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
PatElem dec -> TypeBase (ShapeBase SubExp) NoUniqueness
patElemType PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe
(Stms GPU
tstms', VName
tarr) <- Stms GPU
-> SubExp
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> ReduceM (Stms GPU, VName)
storeScalar Stms GPU
tstms (SubExpRes -> SubExp
resSubExp SubExpRes
tr) TypeBase (ShapeBase SubExp) NoUniqueness
t
(Stms GPU
fstms', VName
farr) <- Stms GPU
-> SubExp
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> ReduceM (Stms GPU, VName)
storeScalar Stms GPU
fstms (SubExpRes -> SubExp
resSubExp SubExpRes
fr) TypeBase (ShapeBase SubExp) NoUniqueness
t
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe' <- PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
-> ReduceM (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
arrayizePatElem PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe
let bt' :: TypeBase ExtShape NoUniqueness
bt' = TypeBase (ShapeBase SubExp) NoUniqueness
-> TypeBase ExtShape NoUniqueness
forall u. TypeBase (ShapeBase SubExp) u -> TypeBase ExtShape u
staticShapes1 (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
PatElem dec -> TypeBase (ShapeBase SubExp) NoUniqueness
patElemType PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe')
let tr' :: SubExpRes
tr' = SubExpRes
tr {resSubExp :: SubExp
resSubExp = VName -> SubExp
Var VName
tarr}
let fr' :: SubExpRes
fr' = SubExpRes
fr {resSubExp :: SubExp
resSubExp = VName -> SubExp
Var VName
farr}
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)],
Stms GPU, Stms GPU)
-> StateT
State
(Reader MigrationTable)
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)],
Stms GPU, Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe', SubExpRes
tr', SubExpRes
fr', TypeBase ExtShape NoUniqueness
bt') (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)]
forall a. a -> [a] -> [a]
: [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)]
res, Stms GPU
tstms', Stms GPU
fstms')
let pes :: [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes = Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall dec. Pat dec -> [PatElem dec]
patElems (Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm)
let zipped :: [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)]
zipped = [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> Result
-> Result
-> [TypeBase ExtShape NoUniqueness]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes Result
tres Result
fres [TypeBase ExtShape NoUniqueness]
[BranchType GPU]
btypes
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)]
zipped', Stms GPU
tstms2, Stms GPU
fstms2) <- (([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)],
Stms GPU, Stms GPU)
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)
-> StateT
State
(Reader MigrationTable)
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)],
Stms GPU, Stms GPU))
-> ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
SubExpRes, SubExpRes, TypeBase ExtShape NoUniqueness)],
Stms GPU, Stms GPU)
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)]
-> StateT
State
(Reader MigrationTable)
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)],
Stms GPU, Stms GPU)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)],
Stms GPU, Stms GPU)
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)
-> StateT
State
(Reader MigrationTable)
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)],
Stms GPU, Stms GPU)
bmerge ([], Stms GPU
tstms1, Stms GPU
fstms1) [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)]
zipped
let ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes', Result
tres', Result
fres', [TypeBase ExtShape NoUniqueness]
btypes') = [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)]
-> ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)], Result,
Result, [TypeBase ExtShape NoUniqueness])
forall a b c d. [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)]
forall a. [a] -> [a]
reverse [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
SubExpRes, TypeBase ExtShape NoUniqueness)]
zipped')
let tbranch' :: Body GPU
tbranch' = BodyDec GPU -> Stms GPU -> Result -> Body GPU
forall rep. BodyDec rep -> Stms rep -> Result -> Body rep
Body () Stms GPU
tstms2 Result
tres'
let fbranch' :: Body GPU
fbranch' = BodyDec GPU -> Stms GPU -> Result -> Body GPU
forall rep. BodyDec rep -> Stms rep -> Result -> Body rep
Body () Stms GPU
fstms2 Result
fres'
let e' :: Exp GPU
e' = SubExp -> Body GPU -> Body GPU -> IfDec (BranchType GPU) -> Exp GPU
forall rep.
SubExp -> Body rep -> Body rep -> IfDec (BranchType rep) -> Exp rep
If SubExp
cond Body GPU
tbranch' Body GPU
fbranch' ([TypeBase ExtShape NoUniqueness]
-> IfSort -> IfDec (TypeBase ExtShape NoUniqueness)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [TypeBase ExtShape NoUniqueness]
btypes' IfSort
sort)
let stm' :: Stm GPU
stm' = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes') (Stm GPU -> StmAux (ExpDec GPU)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
e'
(Stms GPU
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> ReduceM (Stms GPU))
-> Stms GPU
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
-> ReduceM (Stms GPU)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM Stms GPU
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> ReduceM (Stms GPU)
forall dec.
Stms GPU
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
PatElem dec)
-> ReduceM (Stms GPU)
addRead (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm') ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes')
DoLoop [(FParam GPU, SubExp)]
ps LoopForm GPU
lf Body GPU
b -> do
([(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)]
params, LoopForm GPU
lform, Body GPU
body) <- ([(FParam GPU, SubExp)], LoopForm GPU, Body GPU)
-> ReduceM ([(FParam GPU, SubExp)], LoopForm GPU, Body GPU)
rewriteForIn ([(FParam GPU, SubExp)]
ps, LoopForm GPU
lf, Body GPU
b)
let lmerge :: ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
Stms GPU)
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp),
MigrationStatus)
-> StateT
State
(Reader MigrationTable)
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
Stms GPU)
lmerge ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
res, Stms GPU
stms) (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe, (Param Attrs
_ VName
pn TypeBase (ShapeBase SubExp) Uniqueness
pt, SubExp
pval), MigrationStatus
MoveToDevice) = do
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe' <- PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
-> ReduceM (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
arrayizePatElem PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe
(Stms GPU
stms', VName
arr) <- Stms GPU
-> SubExp
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> ReduceM (Stms GPU, VName)
storeScalar Stms GPU
stms SubExp
pval (TypeBase (ShapeBase SubExp) Uniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape.
TypeBase shape Uniqueness -> TypeBase shape NoUniqueness
fromDecl TypeBase (ShapeBase SubExp) Uniqueness
pt)
VName
pn' <- VName -> ReduceM VName
newName VName
pn
let pt' :: TypeBase (ShapeBase SubExp) Uniqueness
pt' = TypeBase (ShapeBase SubExp) NoUniqueness
-> Uniqueness -> TypeBase (ShapeBase SubExp) Uniqueness
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
PatElem dec -> TypeBase (ShapeBase SubExp) NoUniqueness
patElemType PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe') Uniqueness
Nonunique
let pval' :: SubExp
pval' = VName -> SubExp
Var VName
arr
let param' :: (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)
param' = (Attrs
-> VName
-> TypeBase (ShapeBase SubExp) Uniqueness
-> Param (TypeBase (ShapeBase SubExp) Uniqueness)
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
forall a. Monoid a => a
mempty VName
pn' TypeBase (ShapeBase SubExp) Uniqueness
pt', SubExp
pval')
VName -> TypeBase (ShapeBase SubExp) NoUniqueness -> Ident
Ident VName
pn (TypeBase (ShapeBase SubExp) Uniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape.
TypeBase shape Uniqueness -> TypeBase shape NoUniqueness
fromDecl TypeBase (ShapeBase SubExp) Uniqueness
pt) Ident -> VName -> ReduceM ()
`movedTo` VName
pn'
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
Stms GPU)
-> StateT
State
(Reader MigrationTable)
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe', (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)
param') (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
forall a. a -> [a] -> [a]
: [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
res, Stms GPU
stms')
lmerge ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
Stms GPU)
_ (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
_, (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)
_, MigrationStatus
UsedOnHost) =
String
-> StateT
State
(Reader MigrationTable)
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
Stms GPU)
forall a. String -> a
compilerBugS String
"optimizeStm: unhandled host usage of loop param"
lmerge ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
res, Stms GPU
stms) (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe, (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)
param, MigrationStatus
StayOnHost) =
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
Stms GPU)
-> StateT
State
(Reader MigrationTable)
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe, (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)
param) (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
forall a. a -> [a] -> [a]
: [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
res, Stms GPU
stms)
MigrationTable
mt <- ReduceM MigrationTable
ask
let pes :: [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes = Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall dec. Pat dec -> [PatElem dec]
patElems (Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm)
let mss :: [MigrationStatus]
mss = ((Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)
-> MigrationStatus)
-> [(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)]
-> [MigrationStatus]
forall a b. (a -> b) -> [a] -> [b]
map (\(Param Attrs
_ VName
n TypeBase (ShapeBase SubExp) Uniqueness
_, SubExp
_) -> VName -> MigrationTable -> MigrationStatus
statusOf VName
n MigrationTable
mt) [(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)]
params
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
zipped', Stms GPU
out') <- (([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
Stms GPU)
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp),
MigrationStatus)
-> StateT
State
(Reader MigrationTable)
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
Stms GPU))
-> ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
Stms GPU)
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp),
MigrationStatus)]
-> StateT
State
(Reader MigrationTable)
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
Stms GPU)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
Stms GPU)
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp),
MigrationStatus)
-> StateT
State
(Reader MigrationTable)
([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
Stms GPU)
lmerge ([], Stms GPU
out) ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)]
-> [MigrationStatus]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp),
MigrationStatus)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes [(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)]
params [MigrationStatus]
mss)
let ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes', [(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)]
params') = [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
-> ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)],
[(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
forall a. [a] -> [a]
reverse [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
zipped')
Body GPU
body' <- Body GPU -> ReduceM (Body GPU)
optimizeBody Body GPU
body
let e' :: Exp GPU
e' = [(FParam GPU, SubExp)] -> LoopForm GPU -> Body GPU -> Exp GPU
forall rep.
[(FParam rep, SubExp)] -> LoopForm rep -> Body rep -> Exp rep
DoLoop [(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)]
[(FParam GPU, SubExp)]
params' LoopForm GPU
lform Body GPU
body'
let stm' :: Stm GPU
stm' = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes') (Stm GPU -> StmAux (ExpDec GPU)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
e'
(Stms GPU
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> ReduceM (Stms GPU))
-> Stms GPU
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
-> ReduceM (Stms GPU)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM Stms GPU
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> ReduceM (Stms GPU)
forall dec.
Stms GPU
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
PatElem dec)
-> ReduceM (Stms GPU)
addRead (Stms GPU
out' Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm') ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes')
WithAcc [WithAccInput GPU]
inputs Lambda GPU
lmd -> do
let getAcc :: TypeBase shape u -> VName
getAcc (Acc VName
a ShapeBase SubExp
_ [TypeBase (ShapeBase SubExp) NoUniqueness]
_ u
_) = VName
a
getAcc TypeBase shape u
_ =
String -> VName
forall a. String -> a
compilerBugS
String
"Type error: WithAcc expression did not return accumulator."
let accs :: [(VName, WithAccInput GPU)]
accs = (TypeBase (ShapeBase SubExp) NoUniqueness
-> WithAccInput GPU -> (VName, WithAccInput GPU))
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [WithAccInput GPU]
-> [(VName, WithAccInput GPU)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\TypeBase (ShapeBase SubExp) NoUniqueness
t WithAccInput GPU
i -> (TypeBase (ShapeBase SubExp) NoUniqueness -> VName
forall shape u. TypeBase shape u -> VName
getAcc TypeBase (ShapeBase SubExp) NoUniqueness
t, WithAccInput GPU
i)) (Lambda GPU -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall rep.
Lambda rep -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda GPU
lmd) [WithAccInput GPU]
inputs
[WithAccInput GPU]
inputs' <- ((VName, WithAccInput GPU)
-> StateT State (Reader MigrationTable) (WithAccInput GPU))
-> [(VName, WithAccInput GPU)]
-> StateT State (Reader MigrationTable) [WithAccInput GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((VName
-> WithAccInput GPU
-> StateT State (Reader MigrationTable) (WithAccInput GPU))
-> (VName, WithAccInput GPU)
-> StateT State (Reader MigrationTable) (WithAccInput GPU)
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry VName
-> WithAccInput GPU
-> StateT State (Reader MigrationTable) (WithAccInput GPU)
optimizeWithAccInput) [(VName, WithAccInput GPU)]
accs
let body :: Body GPU
body = Lambda GPU -> Body GPU
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPU
lmd
Stms GPU
stms' <- Stms GPU -> ReduceM (Stms GPU)
optimizeStms (Body GPU -> Stms GPU
forall rep. Body rep -> Stms rep
bodyStms Body GPU
body)
let rewrite :: (SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> StateT
State
(Reader MigrationTable)
(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
rewrite (SubExpRes Certs
certs SubExp
se, TypeBase (ShapeBase SubExp) NoUniqueness
t, PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe) =
do
SubExp
se' <- SubExp -> ReduceM SubExp
resolveSubExp SubExp
se
if SubExp
se SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
se'
then (SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> StateT
State
(Reader MigrationTable)
(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Certs -> SubExp -> SubExpRes
SubExpRes Certs
certs SubExp
se, TypeBase (ShapeBase SubExp) NoUniqueness
t, PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe)
else do
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe' <- PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
-> ReduceM (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
arrayizePatElem PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe
let t' :: TypeBase (ShapeBase SubExp) NoUniqueness
t' = PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
PatElem dec -> TypeBase (ShapeBase SubExp) NoUniqueness
patElemType PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe'
(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> StateT
State
(Reader MigrationTable)
(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Certs -> SubExp -> SubExpRes
SubExpRes Certs
certs SubExp
se', TypeBase (ShapeBase SubExp) NoUniqueness
t', PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe')
let len :: Int
len = [WithAccInput GPU] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [WithAccInput GPU]
inputs
let (Result
res0, Result
res1) = Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitAt Int
len (Body GPU -> Result
forall rep. Body rep -> Result
bodyResult Body GPU
body)
let ([TypeBase (ShapeBase SubExp) NoUniqueness]
rts0, [TypeBase (ShapeBase SubExp) NoUniqueness]
rts1) = Int
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> ([TypeBase (ShapeBase SubExp) NoUniqueness],
[TypeBase (ShapeBase SubExp) NoUniqueness])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
len (Lambda GPU -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall rep.
Lambda rep -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda GPU
lmd)
let pes :: [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes = Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall dec. Pat dec -> [PatElem dec]
patElems (Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm)
let ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes0, [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes1) = Int
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)],
[PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes Int -> Int -> Int
forall a. Num a => a -> a -> a
- Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
res1) [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes
(Result
res1', [TypeBase (ShapeBase SubExp) NoUniqueness]
rts1', [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes1') <- [(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
-> (Result, [TypeBase (ShapeBase SubExp) NoUniqueness],
[PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
-> (Result, [TypeBase (ShapeBase SubExp) NoUniqueness],
[PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]))
-> StateT
State
(Reader MigrationTable)
[(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
-> StateT
State
(Reader MigrationTable)
(Result, [TypeBase (ShapeBase SubExp) NoUniqueness],
[PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> StateT
State
(Reader MigrationTable)
(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)))
-> [(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
-> StateT
State
(Reader MigrationTable)
[(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> StateT
State
(Reader MigrationTable)
(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
rewrite (Result
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Result
res1 [TypeBase (ShapeBase SubExp) NoUniqueness]
rts1 [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes1)
let res' :: Result
res' = Result
res0 Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
res1'
let rts' :: [TypeBase (ShapeBase SubExp) NoUniqueness]
rts' = [TypeBase (ShapeBase SubExp) NoUniqueness]
rts0 [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a. [a] -> [a] -> [a]
++ [TypeBase (ShapeBase SubExp) NoUniqueness]
rts1'
let pes' :: [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes' = [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes0 [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes1'
let body' :: Body GPU
body' = BodyDec GPU -> Stms GPU -> Result -> Body GPU
forall rep. BodyDec rep -> Stms rep -> Result -> Body rep
Body () Stms GPU
stms' Result
res'
let lmd' :: Lambda GPU
lmd' = Lambda GPU
lmd {lambdaBody :: Body GPU
lambdaBody = Body GPU
body', lambdaReturnType :: [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType = [TypeBase (ShapeBase SubExp) NoUniqueness]
rts'}
let e' :: Exp GPU
e' = [WithAccInput GPU] -> Lambda GPU -> Exp GPU
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [WithAccInput GPU]
inputs' Lambda GPU
lmd'
let stm' :: Stm GPU
stm' = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes') (Stm GPU -> StmAux (ExpDec GPU)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
e'
(Stms GPU
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> ReduceM (Stms GPU))
-> Stms GPU
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
-> ReduceM (Stms GPU)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM Stms GPU
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> ReduceM (Stms GPU)
forall dec.
Stms GPU
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
PatElem dec)
-> ReduceM (Stms GPU)
addRead (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm') ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes')
Op Op GPU
op -> do
HostOp GPU (SOAC GPU)
op' <- HostOp GPU (SOAC GPU) -> ReduceM (HostOp GPU (SOAC GPU))
forall op. HostOp GPU op -> ReduceM (HostOp GPU op)
optimizeHostOp Op GPU
HostOp GPU (SOAC GPU)
op
Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm {stmExp :: Exp GPU
stmExp = Op GPU -> Exp GPU
forall rep. Op rep -> Exp rep
Op Op GPU
HostOp GPU (SOAC GPU)
op'})
where
addRead :: Stms GPU
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
PatElem dec)
-> ReduceM (Stms GPU)
addRead Stms GPU
stms (pe :: PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe@(PatElem VName
n TypeBase (ShapeBase SubExp) NoUniqueness
_), PatElem VName
dev dec
_)
| VName
n VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
dev = Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure Stms GPU
stms
| Bool
otherwise = PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
-> (VName, Stms GPU) -> ReduceM (Stms GPU)
`migratedTo` (VName
dev, Stms GPU
stms)
-- | Rewrite the element parameters of a for-in loop.
--
-- A 'WhileLoop' is returned unchanged.  For a 'ForLoop', each
-- @(parameter, array)@ pair whose parameter name is not reported by
-- 'usedOnHost' in the migration table is dropped from the loop form; in its
-- place a statement built with @bind@ that indexes the array at the loop
-- counter is prepended to the loop body.  All other pairs are kept as they
-- are, in their original order.
rewriteForIn ::
  ([(FParam GPU, SubExp)], LoopForm GPU, Body GPU) ->
  ReduceM ([(FParam GPU, SubExp)], LoopForm GPU, Body GPU)
rewriteForIn (params, form, body) =
  case form of
    WhileLoop {} ->
      pure (params, form, body)
    ForLoop i t n elems -> do
      mt <- ask
      -- A right fold keeps both the retained pairs and the newly prepended
      -- statements in the order of the original element list.
      let (kept, stms) = foldr (step mt i) ([], bodyStms body) elems
      pure (params, ForLoop i t n kept, body {bodyStms = stms})
  where
    -- Decide the fate of one element parameter.
    step mt i (p, arr) (acc, stms)
      | usedOnHost name mt = ((p, arr) : acc, stms)
      | otherwise = (acc, readStm <| stms)
      where
        name = paramName p
        ty = typeOf p
        -- Fix the outermost dimension to the loop counter and take a full
        -- slice of every dimension of the parameter type.
        slice = Slice (DimFix (Var i) : map sliceDim (arrayDims ty))
        readStm = bind (PatElem name ty) (BasicOp (Index arr slice))
-- | Optimize a single input of a 'WithAcc' expression.
--
-- The 'VName' identifies the accumulator the input belongs to.  An input
-- without an operator is returned unchanged.  Otherwise the migration table
-- is queried with 'shouldMove': if it answers 'True' the operator lambda is
-- passed through 'addReadsToLambda'; if not, the statements of the lambda
-- body are run through 'optimizeStms' under 'noGPUBody'.  Shape, arrays and
-- neutral elements are never modified.
optimizeWithAccInput :: VName -> WithAccInput GPU -> ReduceM (WithAccInput GPU)
optimizeWithAccInput acc (shape, arrs, operator) =
  case operator of
    Nothing ->
      pure (shape, arrs, Nothing)
    Just (op, nes) -> do
      deviceOnly <- asks (shouldMove acc)
      op' <-
        if deviceOnly
          then addReadsToLambda op
          else do
            let body = lambdaBody op
            stms' <- noGPUBody (optimizeStms (bodyStms body))
            pure $ op {lambdaBody = body {bodyStms = stms'}}
      pure (shape, arrs, Just (op', nes))
-- | Optimize a host operation.
--
-- For the four 'SegOp' forms the kernel body is rewritten with
-- 'addReadsToKernelBody'; 'SegRed' and 'SegScan' additionally rewrite each
-- operator with 'addReadsToSegBinOp', and 'SegHist' each operator with
-- 'addReadsToHistOp'.  Operators are always processed before the kernel
-- body.  A 'SizeOp' is returned as is, the body of a 'GPUBody' is rewritten
-- with 'addReadsToBody', and an 'OtherOp' is reported as a compiler bug.
optimizeHostOp :: HostOp GPU op -> ReduceM (HostOp GPU op)
optimizeHostOp hostOp =
  case hostOp of
    SegOp (SegMap lvl space types kbody) -> do
      kbody' <- addReadsToKernelBody kbody
      pure $ SegOp (SegMap lvl space types kbody')
    SegOp (SegRed lvl space ops types kbody) -> do
      ops' <- traverse addReadsToSegBinOp ops
      kbody' <- addReadsToKernelBody kbody
      pure $ SegOp (SegRed lvl space ops' types kbody')
    SegOp (SegScan lvl space ops types kbody) -> do
      ops' <- traverse addReadsToSegBinOp ops
      kbody' <- addReadsToKernelBody kbody
      pure $ SegOp (SegScan lvl space ops' types kbody')
    SegOp (SegHist lvl space ops types kbody) -> do
      ops' <- traverse addReadsToHistOp ops
      kbody' <- addReadsToKernelBody kbody
      pure $ SegOp (SegHist lvl space ops' types kbody')
    SizeOp op ->
      pure (SizeOp op)
    OtherOp {} ->
      compilerBugS "optimizeHostOp: unhandled OtherOp"
    GPUBody types body -> do
      body' <- addReadsToBody body
      pure (GPUBody types body')
-- | Append a textual suffix to a 'Name'.
--
-- The name is converted to 'T.Text', the suffix is packed and appended, and
-- the result is converted back with 'nameFromText'.
withSuffix :: Name -> String -> Name
withSuffix name sfx = nameFromText (nameToText name <> T.pack sfx)
-- | The monad the helpers below run in: a 'State' threaded through the
-- computation, layered over read-only access to a 'MigrationTable'.
type ReduceM = StateT State (R.Reader MigrationTable)
-- | The mutable state carried by 'ReduceM'.
data State = State
  { -- | Source that fresh 'VName's are drawn from (see 'newName').
    stateNameSource :: VNameSource,
    -- | Variables that have been migrated, keyed by the 'baseTag' of the
    -- original variable. Each entry records, in order:
    --
    --   * the 'baseName' of the original variable,
    --   * the type of the original variable,
    --   * the name of the array that holds the migrated value,
    --   * whether the original variable is still usable by its own name
    --     ('True' when recorded via 'aliasedBy', 'False' via 'movedTo').
    stateMigrated :: IM.IntMap (Name, Type, VName, Bool),
    -- | Whether 'storeScalar' may wrap a value in a 'GPUBody' at the current
    -- location. Temporarily cleared by 'noGPUBody'.
    stateGPUBodyOk :: Bool
  }
-- | The starting 'State': the supplied name source, an empty migration
-- table, and 'GPUBody' creation permitted.
initialState :: VNameSource -> State
initialState ns =
  State
    { stateNameSource = ns,
      stateMigrated = IM.empty,
      stateGPUBodyOk = True
    }
-- | Apply a projection to the 'MigrationTable' of the underlying reader and
-- return the result inside 'ReduceM'.
asks :: (MigrationTable -> a) -> ReduceM a
asks f = lift (R.asks f)
-- | Fetch the 'MigrationTable' from the underlying reader.
ask :: ReduceM MigrationTable
ask =
  lift R.ask
-- | Run the given action with 'stateGPUBodyOk' set to 'False', then restore
-- the flag to the value it had before the action ran.
noGPUBody :: ReduceM a -> ReduceM a
noGPUBody m = do
  saved <- gets stateGPUBodyOk
  setGPUBodyOk False
  result <- m
  setGPUBodyOk saved
  pure result
  where
    -- Overwrite only the 'stateGPUBodyOk' field of the state.
    setGPUBodyOk ok = modify $ \st -> st {stateGPUBodyOk = ok}
-- | Draw a fresh 'VName' based on the given one from the name source held in
-- the state, and store the advanced name source back.
newName :: VName -> ReduceM VName
newName n = do
  st <- get
  let (fresh, advanced) = FN.newName (stateNameSource st) n
  put st {stateNameSource = advanced}
  pure fresh
-- | Create a pattern element for an array counterpart of the given one: it
-- gets a fresh name derived from the original base name with a @_dev@
-- suffix, and its type is the original type wrapped in an outer dimension of
-- size 1.
arrayizePatElem :: PatElem Type -> ReduceM (PatElem Type)
arrayizePatElem (PatElem n t) =
  flip PatElem devType <$> newName (VName devBase 0)
  where
    devBase = withSuffix (baseName n) "_dev"
    devType = arrayOfRow t (intConst Int64 1)
-- | @x `movedTo` arr@ records that @x@ has been migrated into the array
-- @arr@, with the entry's flag set to 'False' (the original name is not kept
-- usable; see 'useScalar' and 'resolveName').
movedTo :: Ident -> VName -> ReduceM ()
movedTo ident arr = recordMigration False ident arr
-- | @x `aliasedBy` arr@ records that the value of @x@ is also held by the
-- array @arr@, with the entry's flag set to 'True' (the original name stays
-- usable; see 'useScalar' and 'resolveName').
aliasedBy :: Ident -> VName -> ReduceM ()
aliasedBy ident arr = recordMigration True ident arr
-- | @recordMigration host x arr@ inserts an entry for @x@ into
-- 'stateMigrated', keyed by the 'baseTag' of @x@. The entry stores the base
-- name and type of @x@, the array name @arr@, and the @host@ flag. An
-- existing entry under the same key is replaced.
recordMigration :: Bool -> Ident -> VName -> ReduceM ()
recordMigration host (Ident x t) arr =
  modify $ \st ->
    st {stateMigrated = IM.insert (baseTag x) entry (stateMigrated st)}
  where
    entry = (baseName x, t, arr, host)
-- | @pe `migratedTo` (dev, stms)@ records that the variable bound by @pe@ is
-- held by the array @dev@ and returns the (possibly extended) statements.
--
-- If the migration table reports the variable as used on host, the
-- migration is recorded with 'aliasedBy' and a statement that binds @pe@ to
-- element 0 of @dev@ is appended to @stms@. Otherwise it is recorded with
-- 'movedTo' and @stms@ is returned unchanged.
migratedTo :: PatElem Type -> (VName, Stms GPU) -> ReduceM (Stms GPU)
migratedTo pe (dev, stms) = do
  used <- asks (usedOnHost (patElemName pe))
  let ident = patElemIdent pe
  if used
    then do
      ident `aliasedBy` dev
      pure (stms |> bind pe (eIndex dev))
    else do
      ident `movedTo` dev
      pure stms
-- | @useScalar stms n@ makes the value of @n@ available as a scalar.
--
-- If @n@ has a migration entry whose flag is 'False', a statement reading
-- element 0 of the recorded array into a freshly named variable (of the
-- recorded base name and type) is appended to @stms@, and the new name is
-- returned. In every other case @stms@ and @n@ are returned unchanged.
useScalar :: Stms GPU -> VName -> ReduceM (Stms GPU, VName)
useScalar stms n = do
  entry <- gets (IM.lookup (baseTag n) . stateMigrated)
  case entry of
    Just (name, t, arr, False) -> do
      n' <- newName (VName name 0)
      pure (stms |> bind (PatElem n' t) (eIndex arr), n')
    -- Either not migrated, or the original name is still usable.
    _ -> pure (stms, n)
-- | An expression that indexes the given array at the fixed 'Int64' index 0.
eIndex :: VName -> Exp GPU
eIndex arr = BasicOp (Index arr firstElem)
  where
    firstElem = Slice [DimFix (intConst Int64 0)]
-- | A statement that binds a single pattern element to the given expression,
-- with empty certificates and attributes.
bind :: PatElem Type -> Exp GPU -> Stm GPU
bind pe e = Let pat aux e
  where
    pat = Pat [pe]
    aux = StmAux mempty mempty ()
-- | Look up the array recorded for a subexpression in 'stateMigrated'.
-- Constants never have one. For a variable, the array of its entry is
-- returned whenever an entry exists, regardless of the entry's flag.
storedScalar :: SubExp -> ReduceM (Maybe VName)
storedScalar se = case se of
  Constant _ -> pure Nothing
  Var n -> gets (fmap recordedArray . IM.lookup (baseTag n) . stateMigrated)
  where
    recordedArray (_, _, arr, _) = arr
-- | @storeScalar stms se t@ ensures that the value of @se@ (of type @t@) is
-- held by an array, returning @stms@ extended with any statement needed to
-- create that array, together with the array's name.
--
-- The strategy depends on @se@ and the current state:
--
--   * A variable that already has an entry in 'stateMigrated' reuses the
--     recorded array; no statement is added.
--   * Another variable, when 'stateGPUBodyOk' is set, is copied inside a
--     statement produced by 'inGPUBody'.
--   * Another variable, when 'stateGPUBodyOk' is cleared, is stored via a
--     'Replicate' of shape @[1]@.
--   * A constant is stored via a single-element 'ArrayLit'.
storeScalar :: Stms GPU -> SubExp -> Type -> ReduceM (Stms GPU, VName)
storeScalar stms se t = do
  entry <- case se of
    Var n -> IM.lookup (baseTag n) <$> gets stateMigrated
    _ -> pure Nothing
  case entry of
    Just (_, _, arr, _) -> pure (stms, arr)
    Nothing -> do
      gpubody_ok <- gets stateGPUBodyOk
      case se of
        Var n | gpubody_ok -> do
          n' <- newName n
          let stm = bind (PatElem n' t) (BasicOp $ SubExp se)
          gpubody <- inGPUBody (pure stm)
          -- Match on the pattern explicitly rather than using the partial
          -- 'head', so a violated invariant is reported as a compiler bug
          -- with a meaningful message.
          case patElems (stmPat gpubody) of
            dev_pe : _ -> pure (stms |> gpubody, patElemName dev_pe)
            [] -> compilerBugS "storeScalar: GPUBody statement binds no names"
        Var n -> do
          pe <- arrayizePatElem (PatElem n t)
          let shape = Shape [intConst Int64 1]
          let stm = bind pe (BasicOp $ Replicate shape se)
          pure (stms |> stm, patElemName pe)
        _ -> do
          let n = VName (nameFromString "const") 0
          pe <- arrayizePatElem (PatElem n t)
          let stm = bind pe (BasicOp $ ArrayLit [se] t)
          pure (stms |> stm, patElemName pe)
-- | Map a variable name to the name that should be used in its place: the
-- recorded array if the variable has a migration entry whose flag is
-- 'False', and the variable itself otherwise.
resolveName :: VName -> ReduceM VName
resolveName n = gets (pick . IM.lookup (baseTag n) . stateMigrated)
  where
    pick (Just (_, _, arr, False)) = arr
    pick _ = n
-- | Like 'resolveName', but for a 'SubExp': variables are resolved, anything
-- else is returned as is.
resolveSubExp :: SubExp -> ReduceM SubExp
resolveSubExp se = case se of
  Var n -> Var <$> resolveName n
  _ -> pure se
-- | Like 'resolveSubExp', but for a 'SubExpRes'; the certificates are kept
-- unchanged.
resolveSubExpRes :: SubExpRes -> ReduceM SubExpRes
resolveSubExpRes (SubExpRes certs se) = do
  se' <- resolveSubExp se
  pure (SubExpRes certs se')
-- | Apply 'resolveSubExpRes' to every element of a 'Result', in order.
resolveResult :: Result -> ReduceM Result
resolveResult res = traverse resolveSubExpRes res
-- | @moveStm out stm@ wraps @stm@ in a statement produced by 'inGPUBody'
-- (after passing it through 'rewriteStm') and appends the result, plus any
-- statements needed to bind the original pattern, to @out@.
--
-- A single-element 'ArrayLit' bound to a single name is special-cased: the
-- element itself is bound to an @_inner@-suffixed name inside the body, and
-- the resulting statement's pattern is replaced by the original pattern, so
-- no further statements are needed.
--
-- In the general case each original pattern element is paired with the
-- corresponding pattern element of the wrapped statement and handled
-- according to the rank of the latter's type:
--
--   * rank 0: bound directly to the new variable;
--   * rank 1 and the original type is @'Prim' 'Unit'@: bound to element 0 of
--     the new array;
--   * rank 1 otherwise: handed to 'migratedTo';
--   * higher rank: bound to the slice that fixes the outermost index to 0.
moveStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
moveStm out (Let pat aux (BasicOp (ArrayLit [se] t')))
  | Pat [PatElem n _] <- pat = do
      let inner = VName (baseName n `withSuffix` "_inner") 0
          scalarStm = Let (Pat [PatElem inner t']) aux (BasicOp (SubExp se))
      gpubody <- inGPUBody (rewriteStm scalarStm)
      pure (out |> gpubody {stmPat = pat})
moveStm out stm = do
  gpubody <- inGPUBody (rewriteStm stm)
  foldM addRead (out |> gpubody) $
    zip (patElems (stmPat stm)) (patElems (stmPat gpubody))
  where
    addRead stms (pe@(PatElem _ t), PatElem dev dev_t)
      | rank == 0 = readBack (BasicOp (SubExp (Var dev)))
      | rank == 1, t == Prim Unit = readBack (eIndex dev)
      | rank == 1 = pe `migratedTo` (dev, stms)
      | otherwise =
          readBack . BasicOp . Index dev $
            fullSlice dev_t [DimFix (intConst Int64 0)]
      where
        rank = arrayRank dev_t
        -- Append a statement binding the original pattern element to @e@.
        readBack e = pure (stms |> bind pe e)
-- | Run a rewrite action that produces a single statement and wrap that
-- statement in a 'GPUBody'.
--
-- The action is run from 'initialRState'. Any prologue statements it
-- accumulated are placed in front of the produced statement inside the new
-- body. The body returns the names bound by the produced statement, and its
-- declared result types are the types of those pattern elements. The pattern
-- of the returned statement is built by passing each of those pattern
-- elements through @arrayizePatElem@ (defined elsewhere in this module).
inGPUBody :: RewriteM (Stm GPU) -> ReduceM (Stm GPU)
inGPUBody m = do
  (stm, st) <- runStateT m initialRState
  let prologue = rewritePrologue st

  let pes = patElems (stmPat stm)
  pat <- Pat <$> mapM arrayizePatElem pes
  -- The wrapping statement carries no certificates and no attributes.
  let aux = StmAux mempty mempty ()
  let types = map patElemType pes
  let res = map (SubExpRes mempty . Var . patElemName) pes
  let body = Body () (prologue |> stm) res
  let e = Op (GPUBody types body)
  pure (Let pat aux e)
-- | The monad used by the rewrite functions: 'ReduceM' extended with an
-- 'RState' that tracks renames and prologue statements.
type RewriteM = StateT RState ReduceM
-- | State threaded through 'RewriteM' computations.
data RState = RState
  { -- | Replacement names recorded so far, keyed by the 'baseTag' of the
    -- name they replace (see 'rewriteName' and 'rename').
    rewriteRenames :: IM.IntMap VName,
    -- | Statements accumulated by 'rename'; consumers such as 'inGPUBody'
    -- and 'addReadsToBody' place them before the rewritten statements.
    rewritePrologue :: Stms GPU
  }
-- | The starting state for a 'RewriteM' computation: no renames recorded and
-- an empty prologue.
initialRState :: RState
initialRState =
  RState
    { rewriteRenames = mempty,
      rewritePrologue = mempty
    }
-- | Apply 'addReadsToLambda' to the lambda of a 'SegBinOp', leaving every
-- other field of the operator untouched.
addReadsToSegBinOp :: SegBinOp GPU -> ReduceM (SegBinOp GPU)
addReadsToSegBinOp op = do
  f' <- addReadsToLambda (segBinOpLambda op)
  pure (op {segBinOpLambda = f'})
-- | Apply 'addReadsToLambda' to the operator lambda of a 'HistOp', leaving
-- every other field untouched.
addReadsToHistOp :: HistOp GPU -> ReduceM (HistOp GPU)
addReadsToHistOp op = do
  f' <- addReadsToLambda (histOp op)
  pure (op {histOp = f'})
-- | Apply 'addReadsToBody' to the body of a lambda. Parameters and return
-- types are kept as they are.
addReadsToLambda :: Lambda GPU -> ReduceM (Lambda GPU)
addReadsToLambda f = do
  body' <- addReadsToBody (lambdaBody f)
  pure (f {lambdaBody = body'})
-- | Rename the free variables of a body via 'addReadsHelper' and prepend the
-- prologue statements it returns to the statements of the renamed body.
addReadsToBody :: Body GPU -> ReduceM (Body GPU)
addReadsToBody body = do
  (body', prologue) <- addReadsHelper body
  pure body' {bodyStms = prologue >< bodyStms body'}
-- | Rename the free variables of a kernel body via 'addReadsHelper' and
-- prepend the prologue statements it returns to the statements of the
-- renamed kernel body.
addReadsToKernelBody :: KernelBody GPU -> ReduceM (KernelBody GPU)
addReadsToKernelBody kbody = do
  (kbody', prologue) <- addReadsHelper kbody
  pure kbody' {kernelBodyStms = prologue >< kernelBodyStms kbody'}
-- | Run 'rename' on every free variable of the argument, starting from
-- 'initialRState', and substitute the resulting names into it.
--
-- Returns the substituted value together with the prologue statements that
-- the renaming accumulated. The caller is responsible for placing those
-- statements where the substituted value can see them.
addReadsHelper :: (FreeIn a, Substitute a) => a -> ReduceM (a, Stms GPU)
addReadsHelper x = do
  let from = namesToList (freeIn x)
  (to, st) <- runStateT (mapM rename from) initialRState
  -- 'mapM' preserves order, so 'from' and 'to' line up pairwise.
  let rename_map = M.fromList (zip from to)
  pure (substituteNames rename_map x, rewritePrologue st)
-- | Create a new name from the given one with 'newName' and record it in
-- 'rewriteRenames' under the 'baseTag' of the given name, so that later
-- calls to 'rename' for that name return the new one.
rewriteName :: VName -> RewriteM VName
rewriteName n = do
  n' <- lift (newName n)
  modify $ \st ->
    st {rewriteRenames = IM.insert (baseTag n) n' (rewriteRenames st)}
  pure n'
-- | Rewrite the statements of a body with 'rewriteStms' and then rename its
-- result with 'renameResult'. The statements go first, so names they
-- introduce are already recorded when the result is renamed.
rewriteBody :: Body GPU -> RewriteM (Body GPU)
rewriteBody (Body _ stms res) = do
  stms' <- rewriteStms stms
  res' <- renameResult res
  pure (Body () stms' res')
-- | Rewrite a sequence of statements from left to right with 'rewriteStm'.
--
-- A rewritten statement whose expression is a 'GPUBody' is not kept as such:
-- the statements of its body are appended directly to the output, followed
-- by one binding per pattern element that ties it to the corresponding body
-- result (see @bnd@ below). All other statements are appended unchanged.
--
-- NOTE(review): 'rewriteStm' goes through 'rewriteExp', whose 'mapOnOp'
-- raises a compiler bug; confirm under which circumstances the 'GPUBody'
-- branch can actually be reached.
rewriteStms :: Stms GPU -> RewriteM (Stms GPU)
rewriteStms = foldM rewriteTo mempty
  where
    -- Rewrite one statement and append the outcome to the accumulated output.
    rewriteTo :: Stms GPU -> Stm GPU -> RewriteM (Stms GPU)
    rewriteTo out stm = do
      stm' <- rewriteStm stm
      pure $ case stmExp stm' of
        Op (GPUBody _ (Body _ stms res)) ->
          let pes = patElems (stmPat stm')
           in foldl' bnd (out >< stms) (zip pes res)
        _ -> out |> stm'

    -- Bind a pattern element to a body result. If the element's type has an
    -- outer array dimension that can be peeled off, the result is wrapped in
    -- a one-element array literal of the peeled type; otherwise it is bound
    -- as a plain subexpression. The result's certificates are carried on the
    -- new statement.
    bnd :: Stms GPU -> (PatElem Type, SubExpRes) -> Stms GPU
    bnd out (pe, SubExpRes cs se)
      | Just t' <- peelArray 1 (typeOf pe) =
          out |> Let (Pat [pe]) (StmAux cs mempty ()) (BasicOp $ ArrayLit [se] t')
      | otherwise =
          out |> Let (Pat [pe]) (StmAux cs mempty ()) (BasicOp $ SubExp se)
-- | Rewrite a single statement: its expression with 'rewriteExp', its
-- pattern with 'rewritePat' and its auxiliary information with
-- 'rewriteStmAux'.
rewriteStm :: Stm GPU -> RewriteM (Stm GPU)
rewriteStm (Let pat aux e) = do
  -- The expression is rewritten before the pattern, i.e. before the names
  -- bound by this statement are registered as renamed.
  e' <- rewriteExp e
  pat' <- rewritePat pat
  aux' <- rewriteStmAux aux
  pure (Let pat' aux' e')
-- | Rewrite every element of a pattern with 'rewritePatElem'.
rewritePat :: Pat Type -> RewriteM (Pat Type)
rewritePat pat = Pat <$> mapM rewritePatElem (patElems pat)
-- | Give a pattern element a new name via 'rewriteName' and rename the
-- variables occurring in its type via 'renameType'.
rewritePatElem :: PatElem Type -> RewriteM (PatElem Type)
rewritePatElem (PatElem n t) = do
  n' <- rewriteName n
  t' <- renameType t
  pure (PatElem n' t')
-- | Rename the certificates of a statement's auxiliary information with
-- 'renameCerts'. Attributes are kept as they are.
rewriteStmAux :: StmAux () -> RewriteM (StmAux ())
rewriteStmAux (StmAux certs attrs _) = do
  certs' <- renameCerts certs
  pure (StmAux certs' attrs ())
-- | Rewrite an expression by mapping the rewrite functions of this module
-- over its components with 'mapExpM'.
--
-- Subexpressions and variable names are renamed, nested bodies go through
-- 'rewriteBody', return and branch types through 'renameExtType', and
-- function and lambda parameters through 'rewriteParam'. The handler for
-- 'Op' raises a compiler bug instead of producing a result.
rewriteExp :: Exp GPU -> RewriteM (Exp GPU)
rewriteExp =
  mapExpM $
    Mapper
      { mapOnSubExp = renameSubExp,
        mapOnBody = const rewriteBody,
        mapOnVName = rename,
        mapOnRetType = renameExtType,
        mapOnBranchType = renameExtType,
        mapOnFParam = rewriteParam,
        mapOnLParam = rewriteParam,
        mapOnOp = const opError
      }
  where
    -- Raised when the mapper is asked to handle an 'Op'.
    opError = compilerBugS "Cannot migrate a host-only operation to device."
-- | Give a parameter a new name via 'rewriteName' and rename the variables
-- occurring in its type via 'renameType'. Attributes are kept as they are.
rewriteParam :: Param (TypeBase Shape u) -> RewriteM (Param (TypeBase Shape u))
rewriteParam (Param attrs n t) = do
  n' <- rewriteName n
  t' <- renameType t
  pure (Param attrs n' t')
-- | Return the name to use in place of the given one.
--
-- If a rename is already recorded under the name's 'baseTag', that name is
-- returned. Otherwise @useScalar@ (defined elsewhere in this module) is
-- called with the current prologue and the name. The statements it returns
-- become the new prologue, and the name it returns is recorded and returned.
rename :: VName -> RewriteM VName
rename n = do
  st <- get
  let renames = rewriteRenames st
  let idx = baseTag n
  case IM.lookup idx renames of
    Just n' -> pure n'
    _ ->
      do
        let stms = rewritePrologue st
        (stms', n') <- lift $ useScalar stms n
        -- Only a lifted 'ReduceM' action ran since 'get', so the 'renames'
        -- captured above still matches the current state.
        modify $ \st' ->
          st'
            { rewriteRenames = IM.insert idx n' renames,
              rewritePrologue = stms'
            }
        pure n'
-- | Apply 'renameSubExpRes' to every entry of a result.
renameResult :: Result -> RewriteM Result
renameResult = mapM renameSubExpRes
-- | Rename both the certificates and the subexpression of a single result
-- entry.
renameSubExpRes :: SubExpRes -> RewriteM SubExpRes
renameSubExpRes (SubExpRes certs se) = do
  certs' <- renameCerts certs
  se' <- renameSubExp se
  pure (SubExpRes certs' se')
-- | Apply 'rename' to every certificate name.
renameCerts :: Certs -> RewriteM Certs
renameCerts cs = Certs <$> mapM rename (unCerts cs)
-- | Apply 'rename' to a variable subexpression. Any other subexpression is
-- returned unchanged.
renameSubExp :: SubExp -> RewriteM SubExp
renameSubExp (Var n) = Var <$> rename n
renameSubExp se = pure se
-- | Apply 'renameSubExp' to the subexpressions of a type via 'mapOnType'.
renameType :: TypeBase Shape u -> RewriteM (TypeBase Shape u)
renameType = mapOnType renameSubExp
-- | Apply 'renameSubExp' to the subexpressions of a type with an existential
-- shape via 'mapOnExtType'.
renameExtType :: TypeBase ExtShape u -> RewriteM (TypeBase ExtShape u)
renameExtType = mapOnExtType renameSubExp