-- | This module implements an optimization that migrates host
-- statements into 'GPUBody' kernels to reduce the number of
-- host-device synchronizations that occur when a scalar variable is
-- written to or read from device memory. Which statements should be
-- migrated is determined by a 'MigrationTable' produced by the
-- "Futhark.Optimise.ReduceDeviceSyncs.MigrationTable" module; this module
-- merely performs the migration and rewriting dictated by that table.
module Futhark.Optimise.ReduceDeviceSyncs (reduceDeviceSyncs) where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State hiding (State)
import Data.Bifunctor (second)
import Data.Foldable
import Data.IntMap.Strict qualified as IM
import Data.List (transpose, zip4)
import Data.Map.Strict qualified as M
import Data.Sequence ((<|), (><), (|>))
import Data.Text qualified as T
import Futhark.Construct (fullSlice, mkBody, sliceDim)
import Futhark.Error
import Futhark.IR.GPU
import Futhark.MonadFreshNames
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable
import Futhark.Pass
import Futhark.Transform.Substitute

-- | An optimization pass that migrates host statements into 'GPUBody' kernels
-- to reduce the number of host-device synchronizations.
--
-- The pass first computes the set of host-only functions and a
-- 'MigrationTable' for the program constants, then rewrites the constants
-- and every function definition. Each function is rewritten under the
-- constants' table combined with a table computed for that function alone.
reduceDeviceSyncs :: Pass GPU GPU
reduceDeviceSyncs =
  Pass
    "reduce device synchronizations"
    "Move host statements to device to reduce blocking memory operations."
    $ \prog -> do
      let hof = hostOnlyFunDefs $ progFuns prog
          consts_mt = analyseConsts hof (progFuns prog) (progConsts prog)
      consts <- onConsts consts_mt $ progConsts prog
      funs <- parPass (onFun hof consts_mt) (progFuns prog)
      pure $ prog {progConsts = consts, progFuns = funs}
  where
    -- Rewrite the program constants under the constants' migration table.
    onConsts ::
      (MonadFreshNames m) =>
      MigrationTable ->
      Stms GPU ->
      m (Stms GPU)
    onConsts consts_mt stms =
      runReduceM consts_mt (optimizeStms stms)

    -- Rewrite one function definition. The table for the constants is
    -- combined with the table obtained by analysing the function itself.
    onFun ::
      (MonadFreshNames m) =>
      HostOnlyFuns ->
      MigrationTable ->
      FunDef GPU ->
      m (FunDef GPU)
    onFun hof consts_mt fd = do
      let mt = consts_mt <> analyseFunDef hof fd
      runReduceM mt (optimizeFunDef fd)

--------------------------------------------------------------------------------
--                            AD HOC OPTIMIZATION                             --
--------------------------------------------------------------------------------

-- | Optimize a function definition. Its type signature will remain unchanged.
--
-- Only the statements of the function body are rewritten (via
-- 'optimizeStms'); the body result and all other fields of the definition
-- are kept as they are.
optimizeFunDef :: FunDef GPU -> ReduceM (FunDef GPU)
optimizeFunDef fd = do
  let body = funDefBody fd
  stms' <- optimizeStms (bodyStms body)
  pure $ fd {funDefBody = body {bodyStms = stms'}}

-- | Optimize a body. Scalar results may be replaced with single-element arrays.
--
-- The statements are rewritten with 'optimizeStms' first, and the result is
-- then passed through 'resolveResult'. The rebuilt body gets a fresh unit
-- body decoration.
optimizeBody :: Body GPU -> ReduceM (Body GPU)
optimizeBody (Body _ stms res) = do
  stms' <- optimizeStms stms
  res' <- resolveResult res
  pure (Body () stms' res')

-- | Optimize a sequence of statements.
--
-- The statements are processed in order with 'optimizeStm', starting from an
-- empty sequence; each step returns the statements accumulated so far
-- together with the rewrite of the current statement.
optimizeStms :: Stms GPU -> ReduceM (Stms GPU)
optimizeStms = foldM optimizeStm mempty

-- | Optimize a single statement, rewriting it into one or more statements to
-- be appended to the provided 'Stms'. Only variables with continued host usage
-- will remain in scope if their statement is migrated.
optimizeStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
optimizeStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
optimizeStm Stms GPU
out Stm GPU
stm = do
  Bool
move <- forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks (Stm GPU -> MigrationTable -> Bool
shouldMoveStm Stm GPU
stm)
  if Bool
move
    then Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
moveStm Stms GPU
out Stm GPU
stm
    else case forall {k} (rep :: k). Stm rep -> Exp rep
stmExp Stm GPU
stm of
      BasicOp (Update Safety
safety VName
arr Slice SubExp
slice (Var VName
v))
        | Just [SubExp]
_ <- forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
slice -> do
            -- Rewrite the Update if its write value has been migrated. Copying
            -- is faster than doing a synchronous write, so we use the device
            -- array even if the value has been made available to the host.
            Maybe VName
dev <- SubExp -> ReduceM (Maybe VName)
storedScalar (VName -> SubExp
Var VName
v)
            case Maybe VName
dev of
              Maybe VName
Nothing -> forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out forall a. Seq a -> a -> Seq a
|> Stm GPU
stm)
              Just VName
dst -> do
                -- Transform the single element Update into a slice Update.
                let dims :: [DimIndex SubExp]
dims = forall d. Slice d -> [DimIndex d]
unSlice Slice SubExp
slice
                let ([DimIndex SubExp]
outer, [DimFix SubExp
i]) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex SubExp]
dims forall a. Num a => a -> a -> a
- Int
1) [DimIndex SubExp]
dims
                let one :: SubExp
one = IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1
                let slice' :: Slice SubExp
slice' = forall d. [DimIndex d] -> Slice d
Slice forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp]
outer forall a. [a] -> [a] -> [a]
++ [forall d. d -> d -> d -> DimIndex d
DimSlice SubExp
i SubExp
one SubExp
one]
                let e :: Exp rep
e = forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp (Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update Safety
safety VName
arr Slice SubExp
slice' (VName -> SubExp
Var VName
dst))
                let stm' :: Stm GPU
stm' = Stm GPU
stm {stmExp :: Exp GPU
stmExp = forall {k} {rep :: k}. Exp rep
e}

                forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out forall a. Seq a -> a -> Seq a
|> Stm GPU
stm')
      BasicOp (Replicate (Shape [SubExp]
dims) (Var VName
v))
        | Pat [PatElem VName
n LetDec GPU
arr_t] <- forall {k} (rep :: k). Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm -> do
            -- A Replicate can be rewritten to not require its replication value
            -- to be available on host. If its value is migrated the Replicate
            -- thus needs to be transformed.
            --
            -- If the inner dimension of the replication array is one then the
            -- rewrite can be performed more efficiently than the general case.
            VName
v' <- VName -> ReduceM VName
resolveName VName
v
            let v_kept_on_device :: Bool
v_kept_on_device = VName
v forall a. Eq a => a -> a -> Bool
/= VName
v'

            Bool
gpubody_ok <- forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets State -> Bool
stateGPUBodyOk

            case Bool
v_kept_on_device of
              Bool
False -> forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out forall a. Seq a -> a -> Seq a
|> Stm GPU
stm)
              Bool
True
                | forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (forall a. Eq a => a -> a -> Bool
== IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) [SubExp]
dims,
                  Just Type
t' <- forall u.
Int
-> TypeBase (ShapeBase SubExp) u
-> Maybe (TypeBase (ShapeBase SubExp) u)
peelArray Int
1 LetDec GPU
arr_t,
                  Bool
gpubody_ok -> do
                    let n' :: VName
n' = Name -> Int -> VName
VName (VName -> Name
baseName VName
n Name -> String -> Name
`withSuffix` String
"_inner") Int
0
                    let pat' :: Pat Type
pat' = forall dec. [PatElem dec] -> Pat dec
Pat [forall dec. VName -> dec -> PatElem dec
PatElem VName
n' Type
t']
                    let e' :: Exp rep
e' = forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape forall a b. (a -> b) -> a -> b
$ forall a. [a] -> [a]
tail [SubExp]
dims) (VName -> SubExp
Var VName
v)
                    let stm' :: Stm GPU
stm' = forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
pat' (forall {k} (rep :: k). Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) forall {k} {rep :: k}. Exp rep
e'

                    -- `gpu { v }` is slightly faster than `replicate 1 v` and
                    -- can merge with the GPUBody that v was computed by.
                    Stm GPU
gpubody <- RewriteM (Stm GPU) -> ReduceM (Stm GPU)
inGPUBody (Stm GPU -> RewriteM (Stm GPU)
rewriteStm Stm GPU
stm')
                    forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out forall a. Seq a -> a -> Seq a
|> Stm GPU
gpubody {stmPat :: Pat (LetDec GPU)
stmPat = forall {k} (rep :: k). Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm})
              Bool
True
                | forall a. [a] -> a
last [SubExp]
dims forall a. Eq a => a -> a -> Bool
== IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1 ->
                    let e' :: Exp rep
e' = forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape forall a b. (a -> b) -> a -> b
$ forall a. [a] -> [a]
init [SubExp]
dims) (VName -> SubExp
Var VName
v')
                        stm' :: Stm GPU
stm' = Stm GPU
stm {stmExp :: Exp GPU
stmExp = forall {k} {rep :: k}. Exp rep
e'}
                     in forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out forall a. Seq a -> a -> Seq a
|> Stm GPU
stm')
              Bool
True -> do
                VName
n' <- forall (m :: * -> *). MonadFreshNames m => VName -> m VName
newName VName
n
                -- v_kept_on_device implies that v is a scalar.
                let dims' :: [SubExp]
dims' = [SubExp]
dims forall a. [a] -> [a] -> [a]
++ [IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1]
                let arr_t' :: Type
arr_t' = forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array (forall shape u. TypeBase shape u -> PrimType
elemType LetDec GPU
arr_t) (forall d. [d] -> ShapeBase d
Shape [SubExp]
dims') NoUniqueness
NoUniqueness
                let pat' :: Pat Type
pat' = forall dec. [PatElem dec] -> Pat dec
Pat [forall dec. VName -> dec -> PatElem dec
PatElem VName
n' Type
arr_t']
                let e' :: Exp rep
e' = forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape [SubExp]
dims) (VName -> SubExp
Var VName
v')
                let repl :: Stm GPU
repl = forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
pat' (forall {k} (rep :: k). Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) forall {k} {rep :: k}. Exp rep
e'

                let aux :: StmAux ()
aux = forall dec. Certs -> Attrs -> dec -> StmAux dec
StmAux forall a. Monoid a => a
mempty forall a. Monoid a => a
mempty ()
                let slice :: [DimIndex SubExp]
slice = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
sliceDim (forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims LetDec GPU
arr_t)
                let slice' :: [DimIndex SubExp]
slice' = [DimIndex SubExp]
slice forall a. [a] -> [a] -> [a]
++ [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0]
                let idx :: Exp rep
idx = forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
n' (forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
slice')
                let index :: Stm GPU
index = forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let (forall {k} (rep :: k). Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm) StmAux ()
aux forall {k} {rep :: k}. Exp rep
idx

                forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out forall a. Seq a -> a -> Seq a
|> Stm GPU
repl forall a. Seq a -> a -> Seq a
|> Stm GPU
index)
      BasicOp {} ->
        forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out forall a. Seq a -> a -> Seq a
|> Stm GPU
stm)
      Apply {} ->
        forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out forall a. Seq a -> a -> Seq a
|> Stm GPU
stm)
      Match [SubExp]
ses [Case (Body GPU)]
cases Body GPU
defbody (MatchDec [BranchType GPU]
btypes MatchSort
sort) -> do
        -- Rewrite branches.
        [Stms GPU]
cases_stms <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Stms GPU -> ReduceM (Stms GPU)
optimizeStms forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). Body rep -> Stms rep
bodyStms forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall body. Case body -> body
caseBody) [Case (Body GPU)]
cases
        let cases_res :: [Result]
cases_res = forall a b. (a -> b) -> [a] -> [b]
map (forall {k} (rep :: k). Body rep -> Result
bodyResult forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall body. Case body -> body
caseBody) [Case (Body GPU)]
cases
        Stms GPU
defbody_stms <- Stms GPU -> ReduceM (Stms GPU)
optimizeStms forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Stms rep
bodyStms Body GPU
defbody
        let defbody_res :: Result
defbody_res = forall {k} (rep :: k). Body rep -> Result
bodyResult Body GPU
defbody

        -- Ensure return values and types match if one or both branches
        -- return a result that now resides on device.
        let bmerge :: ([(PatElem Type, Result, ExtType)], [Stms GPU])
-> (PatElem Type, Result, ExtType)
-> ReduceM ([(PatElem Type, Result, ExtType)], [Stms GPU])
bmerge ([(PatElem Type, Result, ExtType)]
acc, [Stms GPU]
all_stms) (PatElem Type
pe, Result
reses, ExtType
bt) = do
              let onHost :: SubExp -> ReduceM Bool
onHost (Var VName
v) = (VName
v ==) forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ReduceM VName
resolveName VName
v
                  onHost SubExp
_ = forall (f :: * -> *) a. Applicative f => a -> f a
pure Bool
True

              Bool
on_host <- forall (t :: * -> *). Foldable t => t Bool -> Bool
and forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SubExp -> ReduceM Bool
onHost forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExpRes -> SubExp
resSubExp) Result
reses

              if Bool
on_host
                then -- No result resides on device ==> nothing to do.
                  forall (f :: * -> *) a. Applicative f => a -> f a
pure ((PatElem Type
pe, Result
reses, ExtType
bt) forall a. a -> [a] -> [a]
: [(PatElem Type, Result, ExtType)]
acc, [Stms GPU]
all_stms)
                else do
                  -- Otherwise, ensure all results are migrated.
                  ([Stms GPU]
all_stms', [VName]
arrs) <-
                    forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$
                      forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. [a] -> [b] -> [(a, b)]
zip [Stms GPU]
all_stms Result
reses) forall a b. (a -> b) -> a -> b
$ \(Stms GPU
stms, SubExpRes
res) ->
                        Stms GPU -> SubExp -> Type -> ReduceM (Stms GPU, VName)
storeScalar Stms GPU
stms (SubExpRes -> SubExp
resSubExp SubExpRes
res) (forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem Type
pe)

                  PatElem Type
pe' <- PatElem Type -> ReduceM (PatElem Type)
arrayizePatElem PatElem Type
pe
                  let bt' :: ExtType
bt' = forall u.
TypeBase (ShapeBase SubExp) u -> TypeBase (ShapeBase ExtSize) u
staticShapes1 (forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem Type
pe')
                      reses' :: Result
reses' = forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Certs -> SubExp -> SubExpRes
SubExpRes (forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> Certs
resCerts Result
reses) (forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
arrs)
                  forall (f :: * -> *) a. Applicative f => a -> f a
pure ((PatElem Type
pe', Result
reses', ExtType
bt') forall a. a -> [a] -> [a]
: [(PatElem Type, Result, ExtType)]
acc, [Stms GPU]
all_stms')

            pes :: [PatElem Type]
pes = forall dec. Pat dec -> [PatElem dec]
patElems (forall {k} (rep :: k). Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm)
        ([(PatElem Type, Result, ExtType)]
acc, ~(Stms GPU
defbody_stms' : [Stms GPU]
cases_stms')) <-
          forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([(PatElem Type, Result, ExtType)], [Stms GPU])
-> (PatElem Type, Result, ExtType)
-> ReduceM ([(PatElem Type, Result, ExtType)], [Stms GPU])
bmerge ([], Stms GPU
defbody_stms forall a. a -> [a] -> [a]
: [Stms GPU]
cases_stms) forall a b. (a -> b) -> a -> b
$
            forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem Type]
pes (forall a. [[a]] -> [[a]]
transpose forall a b. (a -> b) -> a -> b
$ Result
defbody_res forall a. a -> [a] -> [a]
: [Result]
cases_res) [BranchType GPU]
btypes
        let ([PatElem Type]
pes', [Result]
reses, [ExtType]
btypes') = forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 (forall a. [a] -> [a]
reverse [(PatElem Type, Result, ExtType)]
acc)

        -- Rewrite statement.
        let cases' :: [Case (Body GPU)]
cases' =
              forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith forall body. [Maybe PrimValue] -> body -> Case body
Case (forall a b. (a -> b) -> [a] -> [b]
map forall body. Case body -> [Maybe PrimValue]
casePat [Case (Body GPU)]
cases) forall a b. (a -> b) -> a -> b
$
                forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith forall {k} (rep :: k).
Buildable rep =>
Stms rep -> Result -> Body rep
mkBody [Stms GPU]
cases_stms' forall a b. (a -> b) -> a -> b
$
                  forall a. Int -> [a] -> [a]
drop Int
1 forall a b. (a -> b) -> a -> b
$
                    forall a. [[a]] -> [[a]]
transpose [Result]
reses
            defbody' :: Body GPU
defbody' = forall {k} (rep :: k).
Buildable rep =>
Stms rep -> Result -> Body rep
mkBody Stms GPU
defbody_stms' forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall a. [a] -> a
head [Result]
reses
            e' :: Exp GPU
e' = forall {k} (rep :: k).
[SubExp]
-> [Case (Body rep)]
-> Body rep
-> MatchDec (BranchType rep)
-> Exp rep
Match [SubExp]
ses [Case (Body GPU)]
cases' Body GPU
defbody' (forall rt. [rt] -> MatchSort -> MatchDec rt
MatchDec [ExtType]
btypes' MatchSort
sort)
            stm' :: Stm GPU
stm' = forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let (forall dec. [PatElem dec] -> Pat dec
Pat [PatElem Type]
pes') (forall {k} (rep :: k). Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
e'

        -- Read migrated scalars that are used on host.
        forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM forall {dec}.
Stms GPU -> (PatElem Type, PatElem dec) -> ReduceM (Stms GPU)
addRead (Stms GPU
out forall a. Seq a -> a -> Seq a
|> Stm GPU
stm') (forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem Type]
pes [PatElem Type]
pes')
      DoLoop [(FParam GPU, SubExp)]
ps LoopForm GPU
lf Body GPU
b -> do
        -- Enable the migration of for-in loop variables.
        ([(Param DeclType, SubExp)]
params, LoopForm GPU
lform, Body GPU
body) <- ([(FParam GPU, SubExp)], LoopForm GPU, Body GPU)
-> ReduceM ([(FParam GPU, SubExp)], LoopForm GPU, Body GPU)
rewriteForIn ([(FParam GPU, SubExp)]
ps, LoopForm GPU
lf, Body GPU
b)

        -- Update statement bound variables and parameters if their values
        -- have been migrated to device.
        let lmerge :: ([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU)
-> (PatElem Type, (Param DeclType, SubExp), MigrationStatus)
-> ReduceM
     ([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU)
lmerge ([(PatElem Type, (Param DeclType, SubExp))]
res, Stms GPU
stms, Stms GPU
rebinds) (PatElem Type
pe, (Param DeclType, SubExp)
param, MigrationStatus
StayOnHost) =
              forall (f :: * -> *) a. Applicative f => a -> f a
pure ((PatElem Type
pe, (Param DeclType, SubExp)
param) forall a. a -> [a] -> [a]
: [(PatElem Type, (Param DeclType, SubExp))]
res, Stms GPU
stms, Stms GPU
rebinds)
            lmerge ([(PatElem Type, (Param DeclType, SubExp))]
res, Stms GPU
stms, Stms GPU
rebinds) (PatElem Type
pe, (Param Attrs
_ VName
pn DeclType
pt, SubExp
pval), MigrationStatus
_) = do
              -- Migrate the bound variable.
              PatElem Type
pe' <- PatElem Type -> ReduceM (PatElem Type)
arrayizePatElem PatElem Type
pe

              -- Move the initial value to device if not already there to
              -- ensure that the parameter argument and loop return value
              -- converge.
              (Stms GPU
stms', VName
arr) <- Stms GPU -> SubExp -> Type -> ReduceM (Stms GPU, VName)
storeScalar Stms GPU
stms SubExp
pval (forall shape.
TypeBase shape Uniqueness -> TypeBase shape NoUniqueness
fromDecl DeclType
pt)

              -- Migrate the parameter.
              VName
pn' <- forall (m :: * -> *). MonadFreshNames m => VName -> m VName
newName VName
pn
              let pt' :: DeclType
pt' = forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl (forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem Type
pe') Uniqueness
Nonunique
              let pval' :: SubExp
pval' = VName -> SubExp
Var VName
arr
              let param' :: (Param DeclType, SubExp)
param' = (forall dec. Attrs -> VName -> dec -> Param dec
Param forall a. Monoid a => a
mempty VName
pn' DeclType
pt', SubExp
pval')

              -- Record the migration and rebind the parameter inside the
              -- loop body if necessary.
              Stms GPU
rebinds' <- (PatElem Type
pe {patElemName :: VName
patElemName = VName
pn}) PatElem Type -> (VName, Stms GPU) -> ReduceM (Stms GPU)
`migratedTo` (VName
pn', Stms GPU
rebinds)

              forall (f :: * -> *) a. Applicative f => a -> f a
pure ((PatElem Type
pe', (Param DeclType, SubExp)
param') forall a. a -> [a] -> [a]
: [(PatElem Type, (Param DeclType, SubExp))]
res, Stms GPU
stms', Stms GPU
rebinds')

        MigrationTable
mt <- forall r (m :: * -> *). MonadReader r m => m r
ask
        let pes :: [PatElem Type]
pes = forall dec. Pat dec -> [PatElem dec]
patElems (forall {k} (rep :: k). Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm)
        let mss :: [MigrationStatus]
mss = forall a b. (a -> b) -> [a] -> [b]
map (\(Param Attrs
_ VName
n DeclType
_, SubExp
_) -> VName -> MigrationTable -> MigrationStatus
statusOf VName
n MigrationTable
mt) [(Param DeclType, SubExp)]
params
        ([(PatElem Type, (Param DeclType, SubExp))]
zipped', Stms GPU
out', Stms GPU
rebinds) <-
          forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU)
-> (PatElem Type, (Param DeclType, SubExp), MigrationStatus)
-> ReduceM
     ([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU)
lmerge ([], Stms GPU
out, forall a. Monoid a => a
mempty) (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem Type]
pes [(Param DeclType, SubExp)]
params [MigrationStatus]
mss)
        let ([PatElem Type]
pes', [(Param DeclType, SubExp)]
params') = forall a b. [(a, b)] -> ([a], [b])
unzip (forall a. [a] -> [a]
reverse [(PatElem Type, (Param DeclType, SubExp))]
zipped')

        -- Rewrite body.
        let body1 :: Body GPU
body1 = Body GPU
body {bodyStms :: Stms GPU
bodyStms = Stms GPU
rebinds forall a. Seq a -> Seq a -> Seq a
>< forall {k} (rep :: k). Body rep -> Stms rep
bodyStms Body GPU
body}
        Body GPU
body2 <- Body GPU -> ReduceM (Body GPU)
optimizeBody Body GPU
body1
        let zipped :: [(MigrationStatus, SubExpRes, SubExp, Type)]
zipped =
              forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4
                [MigrationStatus]
mss
                (forall {k} (rep :: k). Body rep -> Result
bodyResult Body GPU
body2)
                (forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Result
bodyResult Body GPU
body)
                (forall a b. (a -> b) -> [a] -> [b]
map forall dec. Typed dec => PatElem dec -> Type
patElemType [PatElem Type]
pes)
        let rstore :: (Stms GPU, Result)
-> (MigrationStatus, SubExpRes, SubExp, Type)
-> ReduceM (Stms GPU, Result)
rstore (Stms GPU
bstms, Result
res) (MigrationStatus
StayOnHost, SubExpRes
r, SubExp
_, Type
_) =
              forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
bstms, SubExpRes
r forall a. a -> [a] -> [a]
: Result
res)
            rstore (Stms GPU
bstms, Result
res) (MigrationStatus
_, SubExpRes Certs
certs SubExp
_, SubExp
se, Type
t) = do
              (Stms GPU
bstms', VName
dev) <- Stms GPU -> SubExp -> Type -> ReduceM (Stms GPU, VName)
storeScalar Stms GPU
bstms SubExp
se Type
t
              forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
bstms', Certs -> SubExp -> SubExpRes
SubExpRes Certs
certs (VName -> SubExp
Var VName
dev) forall a. a -> [a] -> [a]
: Result
res)
        (Stms GPU
bstms, Result
res) <- forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Stms GPU, Result)
-> (MigrationStatus, SubExpRes, SubExp, Type)
-> ReduceM (Stms GPU, Result)
rstore (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms Body GPU
body2, []) [(MigrationStatus, SubExpRes, SubExp, Type)]
zipped
        let body3 :: Body GPU
body3 = Body GPU
body2 {bodyStms :: Stms GPU
bodyStms = Stms GPU
bstms, bodyResult :: Result
bodyResult = forall a. [a] -> [a]
reverse Result
res}

        -- Rewrite statement.
        let e' :: Exp GPU
e' = forall {k} (rep :: k).
[(FParam rep, SubExp)] -> LoopForm rep -> Body rep -> Exp rep
DoLoop [(Param DeclType, SubExp)]
params' LoopForm GPU
lform Body GPU
body3
        let stm' :: Stm GPU
stm' = forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let (forall dec. [PatElem dec] -> Pat dec
Pat [PatElem Type]
pes') (forall {k} (rep :: k). Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
e'

        -- Read migrated scalars that are used on host.
        forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM forall {dec}.
Stms GPU -> (PatElem Type, PatElem dec) -> ReduceM (Stms GPU)
addRead (Stms GPU
out' forall a. Seq a -> a -> Seq a
|> Stm GPU
stm') (forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem Type]
pes [PatElem Type]
pes')
      WithAcc [WithAccInput GPU]
inputs Lambda GPU
lmd -> do
        let getAcc :: TypeBase shape u -> VName
getAcc (Acc VName
a ShapeBase SubExp
_ [Type]
_ u
_) = VName
a
            getAcc TypeBase shape u
_ =
              forall a. String -> a
compilerBugS
                String
"Type error: WithAcc expression did not return accumulator."

        let accs :: [(VName, WithAccInput GPU)]
accs = forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\Type
t WithAccInput GPU
i -> (forall {shape} {u}. TypeBase shape u -> VName
getAcc Type
t, WithAccInput GPU
i)) (forall {k} (rep :: k). Lambda rep -> [Type]
lambdaReturnType Lambda GPU
lmd) [WithAccInput GPU]
inputs
        [WithAccInput GPU]
inputs' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry VName -> WithAccInput GPU -> ReduceM (WithAccInput GPU)
optimizeWithAccInput) [(VName, WithAccInput GPU)]
accs

        let body :: Body GPU
body = forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPU
lmd
        Stms GPU
stms' <- Stms GPU -> ReduceM (Stms GPU)
optimizeStms (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms Body GPU
body)

        let rewrite :: (SubExpRes, Type, PatElem Type)
-> ReduceM (SubExpRes, Type, PatElem Type)
rewrite (SubExpRes Certs
certs SubExp
se, Type
t, PatElem Type
pe) =
              do
                SubExp
se' <- SubExp -> ReduceM SubExp
resolveSubExp SubExp
se
                if SubExp
se forall a. Eq a => a -> a -> Bool
== SubExp
se'
                  then forall (f :: * -> *) a. Applicative f => a -> f a
pure (Certs -> SubExp -> SubExpRes
SubExpRes Certs
certs SubExp
se, Type
t, PatElem Type
pe)
                  else do
                    PatElem Type
pe' <- PatElem Type -> ReduceM (PatElem Type)
arrayizePatElem PatElem Type
pe
                    let t' :: Type
t' = forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem Type
pe'
                    forall (f :: * -> *) a. Applicative f => a -> f a
pure (Certs -> SubExp -> SubExpRes
SubExpRes Certs
certs SubExp
se', Type
t', PatElem Type
pe')

        -- Rewrite non-accumulator results that have been migrated.
        --
        -- Accumulator return values do not map to arrays one-to-one but
        -- one-to-many. They are not transformed however and can be mapped
        -- as a no-op.
        let len :: Int
len = forall (t :: * -> *) a. Foldable t => t a -> Int
length [WithAccInput GPU]
inputs
        let (Result
res0, Result
res1) = forall a. Int -> [a] -> ([a], [a])
splitAt Int
len (forall {k} (rep :: k). Body rep -> Result
bodyResult Body GPU
body)
        let ([Type]
rts0, [Type]
rts1) = forall a. Int -> [a] -> ([a], [a])
splitAt Int
len (forall {k} (rep :: k). Lambda rep -> [Type]
lambdaReturnType Lambda GPU
lmd)
        let pes :: [PatElem Type]
pes = forall dec. Pat dec -> [PatElem dec]
patElems (forall {k} (rep :: k). Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm)
        let ([PatElem Type]
pes0, [PatElem Type]
pes1) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem Type]
pes forall a. Num a => a -> a -> a
- forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
res1) [PatElem Type]
pes
        (Result
res1', [Type]
rts1', [PatElem Type]
pes1') <- forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SubExpRes, Type, PatElem Type)
-> ReduceM (SubExpRes, Type, PatElem Type)
rewrite (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Result
res1 [Type]
rts1 [PatElem Type]
pes1)
        let res' :: Result
res' = Result
res0 forall a. [a] -> [a] -> [a]
++ Result
res1'
        let rts' :: [Type]
rts' = [Type]
rts0 forall a. [a] -> [a] -> [a]
++ [Type]
rts1'
        let pes' :: [PatElem Type]
pes' = [PatElem Type]
pes0 forall a. [a] -> [a] -> [a]
++ [PatElem Type]
pes1'

        -- Rewrite statement.
        let body' :: Body GPU
body' = forall {k} (rep :: k).
BodyDec rep -> Stms rep -> Result -> Body rep
Body () Stms GPU
stms' Result
res'
        let lmd' :: Lambda GPU
lmd' = Lambda GPU
lmd {lambdaBody :: Body GPU
lambdaBody = Body GPU
body', lambdaReturnType :: [Type]
lambdaReturnType = [Type]
rts'}
        let e' :: Exp GPU
e' = forall {k} (rep :: k). [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [WithAccInput GPU]
inputs' Lambda GPU
lmd'
        let stm' :: Stm GPU
stm' = forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let (forall dec. [PatElem dec] -> Pat dec
Pat [PatElem Type]
pes') (forall {k} (rep :: k). Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
e'

        -- Read migrated scalars that are used on host.
        forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM forall {dec}.
Stms GPU -> (PatElem Type, PatElem dec) -> ReduceM (Stms GPU)
addRead (Stms GPU
out forall a. Seq a -> a -> Seq a
|> Stm GPU
stm') (forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem Type]
pes [PatElem Type]
pes')
      Op Op GPU
op -> do
        HostOp GPU (SOAC GPU)
op' <- forall op. HostOp GPU op -> ReduceM (HostOp GPU op)
optimizeHostOp Op GPU
op
        forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out forall a. Seq a -> a -> Seq a
|> Stm GPU
stm {stmExp :: Exp GPU
stmExp = forall {k} (rep :: k). Op rep -> Exp rep
Op HostOp GPU (SOAC GPU)
op'})
  where
    addRead :: Stms GPU -> (PatElem Type, PatElem dec) -> ReduceM (Stms GPU)
addRead Stms GPU
stms (pe :: PatElem Type
pe@(PatElem VName
n Type
_), PatElem VName
dev dec
_)
      | VName
n forall a. Eq a => a -> a -> Bool
== VName
dev = forall (f :: * -> *) a. Applicative f => a -> f a
pure Stms GPU
stms
      | Bool
otherwise = PatElem Type
pe PatElem Type -> (VName, Stms GPU) -> ReduceM (Stms GPU)
`migratedTo` (VName
dev, Stms GPU
stms)

-- | Rewrite a for-in loop such that relevant source array reads can be delayed.
--
-- While-loops are returned unchanged. For a for-loop, every element parameter
-- that the migration table reports as not used on host is removed from the
-- loop form and is instead bound by an explicit 'Index' statement placed at
-- the front of the loop body.
rewriteForIn ::
  ([(FParam GPU, SubExp)], LoopForm GPU, Body GPU) ->
  ReduceM ([(FParam GPU, SubExp)], LoopForm GPU, Body GPU)
rewriteForIn loop@(_, WhileLoop {}, _) =
  pure loop
rewriteForIn (params, ForLoop i t n elems, body) = do
  mt <- ask
  -- A right fold keeps the remaining element parameters and the prepended
  -- index statements in the same relative order as the original @elems@.
  let (elems', stms') = foldr (inline mt) ([], bodyStms body) elems
  pure (params, ForLoop i t n elems', body {bodyStms = stms'})
  where
    -- Either keep the element parameter in the loop form, or drop it and
    -- prepend a statement that reads it from its source array.
    inline mt (x, arr) (arrs, stms)
      | pn <- paramName x,
        not (usedOnHost pn mt) =
          let pt = typeOf x
              stm = bind (PatElem pn pt) (BasicOp $ index arr pt)
           in (arrs, stm <| stms)
      | otherwise =
          ((x, arr) : arrs, stms)

    -- Index @arr@ at the loop counter @i@, taking full slices of the
    -- dimensions of the element type.
    index arr of_type =
      Index arr $ Slice $ DimFix (Var i) : map sliceDim (arrayDims of_type)

-- | Optimize an accumulator input. The 'VName' is the accumulator token.
--
-- Inputs without a combining operator are returned unchanged. Otherwise the
-- operator lambda is rewritten: if the migration table says the accumulator
-- should be moved, reads are added to the lambda; if not, only its
-- statements are optimized, with 'GPUBody' introduction disabled.
optimizeWithAccInput :: VName -> WithAccInput GPU -> ReduceM (WithAccInput GPU)
optimizeWithAccInput _ (shape, arrs, Nothing) = pure (shape, arrs, Nothing)
optimizeWithAccInput acc (shape, arrs, Just (op, nes)) = do
  device_only <- asks (shouldMove acc)
  if device_only
    then do
      op' <- addReadsToLambda op
      pure (shape, arrs, Just (op', nes))
    else do
      let body = lambdaBody op
      -- To pass type check neither parameters nor results can change.
      --
      -- op may be used on both host and device so we must avoid introducing
      -- any GPUBody statements.
      stms' <- noGPUBody $ optimizeStms (bodyStms body)
      let op' = op {lambdaBody = body {bodyStms = stms'}}
      pure (shape, arrs, Just (op', nes))

-- | Optimize a host operation. 'Index' statements are added to kernel code
-- that depends on migrated scalars.
--
-- Segmented operations have their kernel body (and, where present, their
-- reduction/scan/histogram operators) rewritten. 'SizeOp's are returned
-- unchanged, and the body of an existing 'GPUBody' is rewritten in place.
-- An 'OtherOp' is reported as a compiler bug.
optimizeHostOp :: HostOp GPU op -> ReduceM (HostOp GPU op)
optimizeHostOp (SegOp (SegMap lvl space types kbody)) =
  SegOp . SegMap lvl space types <$> addReadsToKernelBody kbody
optimizeHostOp (SegOp (SegRed lvl space ops types kbody)) = do
  ops' <- mapM addReadsToSegBinOp ops
  SegOp . SegRed lvl space ops' types <$> addReadsToKernelBody kbody
optimizeHostOp (SegOp (SegScan lvl space ops types kbody)) = do
  ops' <- mapM addReadsToSegBinOp ops
  SegOp . SegScan lvl space ops' types <$> addReadsToKernelBody kbody
optimizeHostOp (SegOp (SegHist lvl space ops types kbody)) = do
  ops' <- mapM addReadsToHistOp ops
  SegOp . SegHist lvl space ops' types <$> addReadsToKernelBody kbody
optimizeHostOp (SizeOp op) =
  pure (SizeOp op)
optimizeHostOp OtherOp {} =
  -- These should all have been taken care of in the unstreamGPU pass.
  compilerBugS "optimizeHostOp: unhandled OtherOp"
optimizeHostOp (GPUBody types body) =
  GPUBody types <$> addReadsToBody body

--------------------------------------------------------------------------------
--                               COMMON HELPERS                               --
--------------------------------------------------------------------------------

-- | Append the given string to a name.
withSuffix :: Name -> String -> Name
withSuffix name sfx = nameFromText $ T.append (nameToText name) (T.pack sfx)

--------------------------------------------------------------------------------
--                             MIGRATION - TYPES                              --
--------------------------------------------------------------------------------

-- | The monad used to perform migration-based synchronization reductions.
--
-- It is a state monad over 'State' layered on a reader of the
-- 'MigrationTable' that dictates the migration.
newtype ReduceM a = ReduceM (StateT State (Reader MigrationTable) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState State,
      MonadReader MigrationTable
    )

-- | Run a 'ReduceM' computation with the given 'MigrationTable'.
--
-- The name source of the surrounding monad seeds the initial state, and the
-- name source left in the final state is written back afterwards.
runReduceM :: MonadFreshNames m => MigrationTable -> ReduceM a -> m a
runReduceM mt (ReduceM m) = modifyNameSource $ \src ->
  second stateNameSource (runReader (runStateT m (initialState src)) mt)

-- | Fresh names are drawn from the name source kept in the 'State'.
instance MonadFreshNames ReduceM where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \s -> s {stateNameSource = src}

-- | The state used by a 'ReduceM' monad.
data State = State
  { -- | A source to generate new 'VName's from.
    stateNameSource :: VNameSource,
    -- | A table of variables in the original program which have been migrated
    -- to device. Each variable maps to a tuple that describes:
    --   * 'baseName' of the original variable.
    --   * Type of the original variable.
    --   * Name of the single element array holding the migrated value.
    --   * Whether the original variable still can be used on the host.
    stateMigrated :: IM.IntMap (Name, Type, VName, Bool),
    -- | Whether non-migration optimizations may introduce 'GPUBody' kernels at
    -- the current location.
    stateGPUBodyOk :: Bool
  }

--------------------------------------------------------------------------------
--                           MIGRATION - PRIMITIVES                           --
--------------------------------------------------------------------------------

-- | An initial state to use when running a 'ReduceM' monad.
--
-- No variables are recorded as migrated, and 'GPUBody' kernels may be
-- introduced.
initialState :: VNameSource -> State
initialState ns =
  State
    { stateNameSource = ns,
      stateMigrated = mempty,
      stateGPUBodyOk = True
    }

-- | Perform non-migration optimizations without introducing any GPUBody
-- kernels.
--
-- The 'stateGPUBodyOk' flag is cleared while the given action runs and is
-- restored to its previous value afterwards, so nested uses compose.
noGPUBody :: ReduceM a -> ReduceM a
noGPUBody m = do
  prev <- gets stateGPUBodyOk
  modify $ \st -> st {stateGPUBodyOk = False}
  res <- m
  modify $ \st -> st {stateGPUBodyOk = prev}
  pure res

-- | Create a 'PatElem' that binds the array of a migrated variable binding.
--
-- The new element gets a fresh name based on the original base name with a
-- @_dev@ suffix, and its type is a one-row array of the original type.
arrayizePatElem :: PatElem Type -> ReduceM (PatElem Type)
arrayizePatElem (PatElem n t) = do
  let name = baseName n `withSuffix` "_dev"
  dev <- newName (VName name 0)
  let dev_t = t `arrayOfRow` intConst Int64 1
  pure (PatElem dev dev_t)

-- | @x `movedTo` arr@ registers that the value of @x@ has been migrated to
-- @arr[0]@.
--
-- The migration is recorded with the host-availability flag set to 'False'.
movedTo :: Ident -> VName -> ReduceM ()
movedTo = recordMigration False

-- | @x `aliasedBy` arr@ registers that the value of @x@ also is available on
-- device as @arr[0]@.
--
-- The migration is recorded with the host-availability flag set to 'True'.
aliasedBy :: Ident -> VName -> ReduceM ()
aliasedBy = recordMigration True

-- | @recordMigration host x arr@ records the migration of variable @x@ to
-- @arr[0]@. If @host@ then the original binding can still be used on host.
--
-- The entry is keyed by the 'baseTag' of @x@ and replaces any entry already
-- stored under that tag.
recordMigration :: Bool -> Ident -> VName -> ReduceM ()
recordMigration host (Ident x t) arr =
  modify $ \st ->
    let migrated = stateMigrated st
        entry = (baseName x, t, arr, host)
        migrated' = IM.insert (baseTag x) entry migrated
     in st {stateMigrated = migrated'}

-- | @pe `migratedTo` (dev, stms)@ registers that the variable @pe@ in the
-- original program has been migrated to @dev@ and rebinds the variable if
-- deemed necessary, adding an index statement to the given statements.
--
-- The rebinding happens exactly when the migration table reports @pe@ as
-- used on host; otherwise the statements are returned unchanged.
migratedTo :: PatElem Type -> (VName, Stms GPU) -> ReduceM (Stms GPU)
migratedTo pe (dev, stms) = do
  used <- asks (usedOnHost $ patElemName pe)
  if used
    then patElemIdent pe `aliasedBy` dev >> pure (stms |> bind pe (eIndex dev))
    else patElemIdent pe `movedTo` dev >> pure stms

-- | @useScalar stms n@ returns a variable that binds the result bound by @n@
-- in the original program. If the variable has been migrated to device and have
-- not been copied back to host a new variable binding will be added to the
-- provided statements and be returned.
useScalar :: Stms GPU -> VName -> ReduceM (Stms GPU, VName)
useScalar stms n = do
  entry <- gets $ IM.lookup (baseTag n) . stateMigrated
  case entry of
    -- Never migrated: the original name is still valid.
    Nothing ->
      pure (stms, n)
    -- Migrated, but the original binding remains usable on host.
    Just (_, _, _, True) ->
      pure (stms, n)
    -- Only available on device: read it back under a fresh name.
    Just (name, t, arr, _) ->
      do
        n' <- newName (VName name 0)
        let stm = bind (PatElem n' t) (eIndex arr)
        pure (stms |> stm, n')

-- | Create an expression that reads the first element of a 1-dimensional array.
eIndex :: VName -> Exp GPU
eIndex arr = BasicOp $ Index arr (Slice [DimFix $ intConst Int64 0])

-- | A shorthand for binding a single variable to an expression.
--
-- The statement is created with empty certificates and attributes.
bind :: PatElem Type -> Exp GPU -> Stm GPU
bind pe = Let (Pat [pe]) (StmAux mempty mempty ())

-- | Returns the array alias of @se@ if it is a variable that has been migrated
-- to device. Otherwise returns @Nothing@.
storedScalar :: SubExp -> ReduceM (Maybe VName)
storedScalar (Constant _) = pure Nothing
storedScalar (Var n) = do
  entry <- gets $ IM.lookup (baseTag n) . stateMigrated
  pure $ fmap (\(_, _, arr, _) -> arr) entry

-- | @storeScalar stms se t@ returns a variable that binds a single element
-- array that contains the value of @se@ in the original program. If @se@ is a
-- variable that has been migrated to device, its existing array alias will be
-- used. Otherwise a new variable binding will be added to the provided
-- statements and be returned. @t@ is the type of @se@.
storeScalar :: Stms GPU -> SubExp -> Type -> ReduceM (Stms GPU, VName)
storeScalar stms se t = do
  entry <- case se of
    Var n -> gets $ IM.lookup (baseTag n) . stateMigrated
    _ -> pure Nothing
  case entry of
    Just (_, _, arr, _) -> pure (stms, arr)
    Nothing -> do
      -- How to most efficiently create an array containing the given value
      -- depends on whether it is a variable or a constant. Creating a constant
      -- array is a runtime copy of static memory, while creating an array that
      -- contains a variable results in a synchronous write. The latter is thus
      -- replaced with either a mergeable GPUBody kernel or a Replicate.
      --
      -- Whether it makes sense to hoist arrays out of bodies to enable CSE is
      -- left to the simplifier to figure out. Duplicates will be eliminated if
      -- a scalar is stored multiple times within a body.
      --
      -- TODO: Enable the simplifier to hoist non-consumed, non-returned arrays
      --       out of top-level function definitions. All constant arrays
      --       produced here are in principle top-level hoistable.
      gpubody_ok <- gets stateGPUBodyOk
      case se of
        Var n | gpubody_ok -> do
          n' <- newName n
          let stm = bind (PatElem n' t) (BasicOp $ SubExp se)

          gpubody <- inGPUBody (pure stm)
          -- The array is whatever the wrapping statement binds first. Match
          -- totally instead of using 'head' so that a violated expectation is
          -- reported as a compiler bug rather than a bare empty-list error.
          case patElems (stmPat gpubody) of
            dev_pe : _ ->
              pure (stms |> gpubody, patElemName dev_pe)
            [] ->
              compilerBugS "storeScalar: GPUBody statement binds no variables"
        Var n -> do
          -- GPUBody kernels are not allowed here; fall back to a Replicate.
          pe <- arrayizePatElem (PatElem n t)
          let shape = Shape [intConst Int64 1]
          let stm = bind pe (BasicOp $ Replicate shape se)
          pure (stms |> stm, patElemName pe)
        _ -> do
          -- Constants become a single element array literal.
          let n = VName (nameFromString "const") 0
          pe <- arrayizePatElem (PatElem n t)
          let stm = bind pe (BasicOp $ ArrayLit [se] t)
          pure (stms |> stm, patElemName pe)

-- | Map a variable name to itself or, if the variable no longer can be used on
-- host, the name of a single element array containing its value.
resolveName :: VName -> ReduceM VName
resolveName n = do
  entry <- gets $ IM.lookup (baseTag n) . stateMigrated
  case entry of
    -- Not migrated.
    Nothing -> pure n
    -- Migrated but still usable on host.
    Just (_, _, _, True) -> pure n
    -- Only available through its device array.
    Just (_, _, arr, _) -> pure arr

-- | Like 'resolveName' but for a t'SubExp'. Constants are mapped to themselves.
resolveSubExp :: SubExp -> ReduceM SubExp
resolveSubExp (Var n) = Var <$> resolveName n
resolveSubExp cnst = pure cnst

-- | Like 'resolveSubExp' but for a 'SubExpRes'.
--
-- Only the result expression is resolved; the certificates are kept as is.
resolveSubExpRes :: SubExpRes -> ReduceM SubExpRes
resolveSubExpRes (SubExpRes certs se) =
  -- Certificates are always read back to host.
  SubExpRes certs <$> resolveSubExp se

-- | Apply 'resolveSubExpRes' to a list of results.
resolveResult :: Result -> ReduceM Result
resolveResult = mapM resolveSubExpRes

-- | Migrate a statement to device, ensuring all its bound variables used on
-- host will remain available with the same names.
--
-- The first argument is the sequence of statements produced so far; the
-- result is that sequence extended with the migrated statement and any
-- statements needed to make its results available under their original names.
moveStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
moveStm out (Let pat aux (BasicOp (ArrayLit [se] t')))
  | Pat [PatElem n _] <- pat =
      do
        -- Save an 'Index' by rewriting the 'ArrayLit' rather than migrating it.
        --
        -- The single element is bound as a scalar inside the kernel; the
        -- array that 'inGPUBody' wraps it in then takes the place of the
        -- original array literal. The tag 0 of the inner name is only a
        -- placeholder, as 'rewriteStm' gives every bound name a fresh one.
        let n' = VName (baseName n `withSuffix` "_inner") 0
        let pat' = Pat [PatElem n' t']
        let e' = BasicOp (SubExp se)
        let stm' = Let pat' aux e'

        gpubody <- inGPUBody (rewriteStm stm')
        -- Reuse the original pattern so the array keeps its original name.
        pure (out |> gpubody {stmPat = pat})
moveStm out stm = do
  -- Move the statement to device.
  gpubody <- inGPUBody (rewriteStm stm)

  -- Read non-scalars and scalars that are used on host.
  let arrs = zip (patElems $ stmPat stm) (patElems $ stmPat gpubody)
  foldM addRead (out |> gpubody) arrs
  where
    -- Pairs an original pattern element with the pattern element that the
    -- GPUBody binds in its place, and decides how the original name is
    -- made available again based on the rank of the device-side type.
    addRead stms (pe@(PatElem _ t), PatElem dev dev_t) =
      let add' e = pure $ stms |> bind pe e
          add = add' . BasicOp
       in case arrayRank dev_t of
            -- Alias non-arrays with their prior name.
            0 -> add $ SubExp (Var dev)
            -- Read all certificates for free.
            1 | t == Prim Unit -> add' (eIndex dev)
            -- Record the device alias of each scalar variable and read them
            -- if used on host.
            1 -> pe `migratedTo` (dev, stms)
            -- Drop the added dimension of multidimensional arrays.
            _ -> add $ Index dev (fullSlice dev_t [DimFix $ intConst Int64 0])

-- | Create a GPUBody kernel that executes a single statement and stores its
-- results in single element arrays.
--
-- The rewrite action is run from 'initialRState'. Any prologue statements it
-- accumulates are placed before the rewritten statement inside the kernel
-- body. The kernel returns every name bound by the rewritten statement, and
-- the pattern of the resulting statement is that of the rewritten statement
-- with each element passed through 'arrayizePatElem'.
inGPUBody :: RewriteM (Stm GPU) -> ReduceM (Stm GPU)
inGPUBody m = do
  (stm, st) <- runStateT m initialRState
  let prologue = rewritePrologue st

  let pes = patElems (stmPat stm)
  pat <- Pat <$> mapM arrayizePatElem pes
  -- The GPUBody statement itself carries no certificates or attributes.
  let aux = StmAux mempty mempty ()
  let types = map patElemType pes
  let res = map (SubExpRes mempty . Var . patElemName) pes
  let body = Body () (prologue |> stm) res
  let e = Op (GPUBody types body)
  pure (Let pat aux e)

--------------------------------------------------------------------------------
--                          KERNEL REWRITING - TYPES                          --
--------------------------------------------------------------------------------

-- | The monad used to rewrite (migrated) kernel code. It layers an 'RState'
-- on top of 'ReduceM'.
type RewriteM = StateT RState ReduceM

-- | The state used by a 'RewriteM' monad.
data RState = RState
  { -- | Maps variables in the original program to names to be used by rewrites.
    -- The keys are the 'baseTag's of the original names.
    rewriteRenames :: IM.IntMap VName,
    -- | Statements to be added as a prologue before rewritten statements.
    rewritePrologue :: Stms GPU
  }

--------------------------------------------------------------------------------
--                        KERNEL REWRITING - FUNCTIONS                        --
--------------------------------------------------------------------------------

-- | An initial state to use when running a 'RewriteM' monad: no renames have
-- been registered and the prologue is empty.
initialRState :: RState
initialRState =
  RState
    { rewriteRenames = mempty,
      rewritePrologue = mempty
    }

-- | Rewrite 'SegBinOp' dependencies to scalars that have been migrated. Only
-- the operator lambda is rewritten (via 'addReadsToLambda'); all other fields
-- are left untouched.
addReadsToSegBinOp :: SegBinOp GPU -> ReduceM (SegBinOp GPU)
addReadsToSegBinOp op = do
  f' <- addReadsToLambda (segBinOpLambda op)
  pure (op {segBinOpLambda = f'})

-- | Rewrite 'HistOp' dependencies to scalars that have been migrated. Only
-- the operator lambda is rewritten (via 'addReadsToLambda'); all other fields
-- are left untouched.
addReadsToHistOp :: HistOp GPU -> ReduceM (HistOp GPU)
addReadsToHistOp op = do
  f' <- addReadsToLambda (histOp op)
  pure (op {histOp = f'})

-- | Rewrite generic lambda dependencies to scalars that have been migrated.
-- Only the lambda body is rewritten (via 'addReadsToBody').
addReadsToLambda :: Lambda GPU -> ReduceM (Lambda GPU)
addReadsToLambda f = do
  body' <- addReadsToBody (lambdaBody f)
  pure (f {lambdaBody = body'})

-- | Rewrite generic body dependencies to scalars that have been migrated.
-- The statements produced by 'addReadsHelper' are prepended to the statements
-- of the rewritten body.
addReadsToBody :: Body GPU -> ReduceM (Body GPU)
addReadsToBody body = do
  (body', prologue) <- addReadsHelper body
  pure body' {bodyStms = prologue >< bodyStms body'}

-- | Rewrite kernel body dependencies to scalars that have been migrated.
-- The statements produced by 'addReadsHelper' are prepended to the statements
-- of the rewritten kernel body.
addReadsToKernelBody :: KernelBody GPU -> ReduceM (KernelBody GPU)
addReadsToKernelBody kbody = do
  (kbody', prologue) <- addReadsHelper kbody
  pure kbody' {kernelBodyStms = prologue >< kernelBodyStms kbody'}

-- | Rewrite migrated scalar dependencies within anything. The returned
-- statements must be added to the scope of the rewritten construct.
--
-- Every free variable of the argument is passed through 'rename', starting
-- from 'initialRState', and the resulting mapping is then substituted into
-- the argument. The returned statements are the prologue that the renaming
-- accumulated.
addReadsHelper :: (FreeIn a, Substitute a) => a -> ReduceM (a, Stms GPU)
addReadsHelper x = do
  let from = namesToList (freeIn x)
  (to, st) <- runStateT (mapM rename from) initialRState
  let rename_map = M.fromList (zip from to)
  pure (substituteNames rename_map x, rewritePrologue st)

-- | Create a fresh name, registering which name it is a rewrite of.
--
-- The fresh name is recorded in 'rewriteRenames' under the 'baseTag' of the
-- original name, replacing any rename previously registered for that tag.
rewriteName :: VName -> RewriteM VName
rewriteName n = do
  n' <- lift (newName n)
  modify $ \st -> st {rewriteRenames = IM.insert (baseTag n) n' (rewriteRenames st)}
  pure n'

-- | Rewrite all bindings introduced by a body (to ensure they are unique) and
-- fix any dependencies that are broken as a result of migration or rewriting.
--
-- The statements are rewritten before the result, so the result is renamed
-- using the renames that the statements registered.
rewriteBody :: Body GPU -> RewriteM (Body GPU)
rewriteBody (Body _ stms res) = do
  stms' <- rewriteStms stms
  res' <- renameResult res
  pure (Body () stms' res')

-- | Rewrite all bindings introduced by a sequence of statements (to ensure they
-- are unique) and fix any dependencies that are broken as a result of migration
-- or rewriting.
--
-- A rewritten statement that is a 'GPUBody' is not kept as a kernel: its body
-- statements are spliced directly into the output and each of its results is
-- rebound to the corresponding pattern element.
--
-- NOTE(review): 'rewriteStm' goes through 'rewriteExp', whose 'mapOnOp'
-- raises a compiler bug for every 'Op'. From this code alone it is unclear
-- how a 'GPUBody' can reach the inlining branch below -- confirm.
rewriteStms :: Stms GPU -> RewriteM (Stms GPU)
rewriteStms = foldM rewriteTo mempty
  where
    -- Rewrite one statement and append the outcome to the output so far.
    rewriteTo out stm = do
      stm' <- rewriteStm stm
      pure $ case stmExp stm' of
        Op (GPUBody _ (Body _ stms res)) ->
          let pes = patElems (stmPat stm')
           in foldl' bnd (out >< stms) (zip pes res)
        _ -> out |> stm'

    -- Bind a pattern element of an inlined GPUBody to its result. If the
    -- element has an array type, the result is wrapped in a single element
    -- array literal of the row type; otherwise it is bound directly. The
    -- certificates of the result are attached to the new statement.
    bnd :: Stms GPU -> (PatElem Type, SubExpRes) -> Stms GPU
    bnd out (pe, SubExpRes cs se)
      | Just t' <- peelArray 1 (typeOf pe) =
          out |> Let (Pat [pe]) (StmAux cs mempty ()) (BasicOp $ ArrayLit [se] t')
      | otherwise =
          out |> Let (Pat [pe]) (StmAux cs mempty ()) (BasicOp $ SubExp se)

-- | Rewrite all bindings introduced by a single statement (to ensure they are
-- unique) and fix any dependencies that are broken as a result of migration or
-- rewriting.
--
-- The expression is rewritten first, then the pattern, then the auxiliary
-- information. Each step reads and updates the shared 'RState', so the order
-- is significant.
--
-- NOTE: GPUBody kernels must be rewritten through 'rewriteStms'.
rewriteStm :: Stm GPU -> RewriteM (Stm GPU)
rewriteStm (Let pat aux e) = do
  e' <- rewriteExp e
  pat' <- rewritePat pat
  aux' <- rewriteStmAux aux
  pure (Let pat' aux' e')

-- | Rewrite all bindings introduced by a pattern (to ensure they are unique)
-- and fix any dependencies that are broken as a result of migration or
-- rewriting. Applies 'rewritePatElem' to each element in order.
rewritePat :: Pat Type -> RewriteM (Pat Type)
rewritePat pat = Pat <$> mapM rewritePatElem (patElems pat)

-- | Rewrite the binding introduced by a single pattern element (to ensure it is
-- unique) and fix any dependencies that are broken as a result of migration or
-- rewriting. The name is replaced by a fresh one via 'rewriteName' and the
-- variables within its type are renamed via 'renameType'.
rewritePatElem :: PatElem Type -> RewriteM (PatElem Type)
rewritePatElem (PatElem n t) = do
  n' <- rewriteName n
  t' <- renameType t
  pure (PatElem n' t')

-- | Fix any 'StmAux' certificate references that are broken as a result of
-- migration or rewriting. Attributes are passed through unchanged.
rewriteStmAux :: StmAux () -> RewriteM (StmAux ())
rewriteStmAux (StmAux certs attrs _) = do
  certs' <- renameCerts certs
  pure (StmAux certs' attrs ())

-- | Rewrite the bindings introduced by an expression (to ensure they are
-- unique) and fix any dependencies that are broken as a result of migration or
-- rewriting.
--
-- Parameters are given fresh names, bodies are rewritten recursively, and
-- every other variable occurrence is passed through 'rename'. Any 'Op'
-- encountered is treated as a compiler bug.
rewriteExp :: Exp GPU -> RewriteM (Exp GPU)
rewriteExp =
  mapExpM $
    Mapper
      { mapOnSubExp = renameSubExp,
        mapOnBody = const rewriteBody,
        mapOnVName = rename,
        mapOnRetType = renameExtType,
        mapOnBranchType = renameExtType,
        mapOnFParam = rewriteParam,
        mapOnLParam = rewriteParam,
        mapOnOp = const opError
      }
  where
    -- This indicates that something fundamentally is wrong with the migration
    -- table produced by the
    -- "Futhark.Optimise.ReduceDeviceSyncs.MigrationTable" module.
    opError = compilerBugS "Cannot migrate a host-only operation to device."

-- | Rewrite the binding introduced by a single parameter (to ensure it is
-- unique) and fix any dependencies that are broken as a result of migration or
-- rewriting. The name is replaced by a fresh one via 'rewriteName' and the
-- variables within its type are renamed via 'renameType'; attributes are
-- kept as they are.
rewriteParam :: Param (TypeBase Shape u) -> RewriteM (Param (TypeBase Shape u))
rewriteParam (Param attrs n t) = do
  n' <- rewriteName n
  t' <- renameType t
  pure (Param attrs n' t')

-- | Return the name to use for a rewritten dependency.
--
-- If a rename is already registered for the 'baseTag' of the name, it is
-- returned. Otherwise 'useScalar' is asked for the name to use, given the
-- current prologue; the possibly extended prologue it returns and the chosen
-- name are stored in the state, so later lookups of the same name reuse the
-- result.
rename :: VName -> RewriteM VName
rename n = do
  st <- get
  let renames = rewriteRenames st
  let idx = baseTag n
  case IM.lookup idx renames of
    Just n' -> pure n'
    _ ->
      do
        let stms = rewritePrologue st
        (stms', n') <- lift $ useScalar stms n
        modify $ \st' ->
          st'
            { rewriteRenames = IM.insert idx n' renames,
              rewritePrologue = stms'
            }
        pure n'

-- | Update the variable names within a 'Result' to account for migration and
-- rewriting. Applies 'renameSubExpRes' to each result in order.
renameResult :: Result -> RewriteM Result
renameResult = mapM renameSubExpRes

-- | Update the variable names within a 'SubExpRes' to account for migration and
-- rewriting. Both the certificates and the result operand are renamed.
renameSubExpRes :: SubExpRes -> RewriteM SubExpRes
renameSubExpRes (SubExpRes certs se) = do
  certs' <- renameCerts certs
  se' <- renameSubExp se
  pure (SubExpRes certs' se')

-- | Update the variable names of certificates to account for migration and
-- rewriting. Applies 'rename' to each certificate name in order.
renameCerts :: Certs -> RewriteM Certs
renameCerts cs = Certs <$> mapM rename (unCerts cs)

-- | Update any variable name within a t'SubExp' to account for migration and
-- rewriting. Constants are returned unchanged.
renameSubExp :: SubExp -> RewriteM SubExp
renameSubExp (Var n) = Var <$> rename n
renameSubExp se = pure se

-- | Update the variable names within a type to account for migration and
-- rewriting.
renameType :: TypeBase Shape u -> RewriteM (TypeBase Shape u)
-- Note: mapOnType also maps the VName token of accumulators
renameType = mapOnType renameSubExp

-- | Update the variable names within an existential type to account for
-- migration and rewriting.
renameExtType :: TypeBase ExtShape u -> RewriteM (TypeBase ExtShape u)
-- Note: mapOnExtType also maps the VName token of accumulators
renameExtType = mapOnExtType renameSubExp