-- | This module implements an optimization that migrates host
-- statements into 'GPUBody' kernels to reduce the number of
-- host-device synchronizations that occur when a scalar variable is
-- written to or read from device memory. Which statements should be
-- migrated is determined by a 'MigrationTable' produced by the
-- "Futhark.Optimise.ReduceDeviceSyncs.MigrationTable" module; this module
-- merely performs the migration and rewriting dictated by that table.
module Futhark.Optimise.ReduceDeviceSyncs (reduceDeviceSyncs) where

import Control.Monad
import Control.Monad.Trans.Class
import qualified Control.Monad.Trans.Reader as R
import Control.Monad.Trans.State.Strict hiding (State)
import Control.Parallel.Strategies (parMap, rpar)
import Data.Foldable
import qualified Data.IntMap.Strict as IM
import Data.List (unzip4, zip4)
import qualified Data.Map.Strict as M
import Data.Sequence ((<|), (><), (|>))
import qualified Data.Text as T
import Futhark.Construct (fullSlice, sliceDim)
import Futhark.Error
import qualified Futhark.FreshNames as FN
import Futhark.IR.GPU
import Futhark.MonadFreshNames (VNameSource, getNameSource, putNameSource)
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable
import Futhark.Pass
import Futhark.Transform.Substitute

-- | An optimization pass that migrates host statements into 'GPUBody' kernels
-- to reduce the number of host-device synchronizations.
--
-- The pass computes a 'MigrationTable' for the whole program with
-- 'analyseProg' and then runs 'optimizeProgram' with that table as the
-- read-only environment and a 'State' built from the current name source.
-- Afterwards the name source held by the final state is stored back with
-- 'putNameSource'.
reduceDeviceSyncs :: Pass GPU GPU
reduceDeviceSyncs =
  Pass
    "reduce device synchronizations"
    "Move host statements to device to reduce blocking memory operations."
    run
  where
    -- Run the rewrite purely: the state layer is unwrapped first, then the
    -- reader layer is supplied with the migration table.
    run prog = do
      ns <- getNameSource
      let mt = analyseProg prog
      let st = initialState ns
      let (prog', st') = R.runReader (runStateT (optimizeProgram prog) st) mt
      putNameSource (stateNameSource st')
      pure prog'

--------------------------------------------------------------------------------
--                            AD HOC OPTIMIZATION                             --
--------------------------------------------------------------------------------

-- | Optimize a whole program. The type signatures of top-level functions will
-- remain unchanged.
--
-- The program constants are optimized first, followed by every function
-- definition.
optimizeProgram :: Prog GPU -> ReduceM (Prog GPU)
optimizeProgram (Prog consts funs) = do
  consts' <- optimizeStms consts
  -- 'parMap' only builds the list of actions (sparking evaluation of each
  -- action value); 'sequence' is what runs them, in list order.
  funs' <- sequence $ parMap rpar optimizeFunDef funs
  pure (Prog consts' funs')

-- | Optimize a function definition. Its type signature will remain unchanged.
--
-- Only the statements of the function body are rewritten; the body result and
-- every other field of the 'FunDef' are kept as they are.
optimizeFunDef :: FunDef GPU -> ReduceM (FunDef GPU)
optimizeFunDef fd = do
  let body = funDefBody fd
  stms' <- optimizeStms (bodyStms body)
  pure $ fd {funDefBody = body {bodyStms = stms'}}

-- | Optimize a body. Scalar results may be replaced with single-element arrays.
--
-- The statements are rewritten with 'optimizeStms' and the result is then
-- passed through 'resolveResult'. The rebuilt body always carries the unit
-- decoration.
optimizeBody :: Body GPU -> ReduceM (Body GPU)
optimizeBody (Body _ stms res) = do
  stms' <- optimizeStms stms
  res' <- resolveResult res
  pure (Body () stms' res')

-- | Optimize a sequence of statements.
--
-- The statements are folded from left to right with 'optimizeStm', starting
-- from an empty sequence, so each rewritten statement is appended to the
-- output produced by the statements before it.
optimizeStms :: Stms GPU -> ReduceM (Stms GPU)
optimizeStms = foldM optimizeStm mempty

-- | Optimize a single statement, rewriting it into one or more statements to
-- be appended to the provided 'Stms'. Only variables with continued host usage
-- will remain in scope if their statement is migrated.
optimizeStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
optimizeStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
optimizeStm Stms GPU
out Stm GPU
stm = do
  Bool
move <- (MigrationTable -> Bool) -> ReduceM Bool
forall a. (MigrationTable -> a) -> ReduceM a
asks (Stm GPU -> MigrationTable -> Bool
shouldMoveStm Stm GPU
stm)
  if Bool
move
    then Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
moveStm Stms GPU
out Stm GPU
stm
    else case Stm GPU -> Exp GPU
forall rep. Stm rep -> Exp rep
stmExp Stm GPU
stm of
      BasicOp (Update Safety
safety VName
arr Slice SubExp
slice (Var VName
v))
        | Just [SubExp]
_ <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
slice -> do
            -- Rewrite the Update if its write value has been migrated. Copying
            -- is faster than doing a synchronous write, so we use the device
            -- array even if the value has been made available to the host.
            Maybe VName
dev <- SubExp -> ReduceM (Maybe VName)
storedScalar (VName -> SubExp
Var VName
v)
            case Maybe VName
dev of
              Maybe VName
Nothing -> Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm)
              Just VName
dst -> do
                -- Transform the single element Update into a slice Update.
                let dims :: [DimIndex SubExp]
dims = Slice SubExp -> [DimIndex SubExp]
forall d. Slice d -> [DimIndex d]
unSlice Slice SubExp
slice
                let ([DimIndex SubExp]
outer, [DimFix SubExp
i]) = Int -> [DimIndex SubExp] -> ([DimIndex SubExp], [DimIndex SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt ([DimIndex SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex SubExp]
dims Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1) [DimIndex SubExp]
dims
                let one :: SubExp
one = IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1
                let slice' :: Slice SubExp
slice' = [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp]
outer [DimIndex SubExp] -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. [a] -> [a] -> [a]
++ [SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice SubExp
i SubExp
one SubExp
one]
                let e :: Exp rep
e = BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update Safety
safety VName
arr Slice SubExp
slice' (VName -> SubExp
Var VName
dst))
                let stm' :: Stm GPU
stm' = Stm GPU
stm {stmExp :: Exp GPU
stmExp = Exp GPU
forall rep. Exp rep
e}

                Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm')
      BasicOp (Replicate (Shape [SubExp]
dims) (Var VName
v))
        | Pat [PatElem VName
n LetDec GPU
arr_t] <- Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm -> do
            -- A Replicate can be rewritten to not require its replication value
            -- to be available on host. If its value is migrated the Replicate
            -- thus needs to be transformed.
            --
            -- If the inner dimension of the replication array is one then the
            -- rewrite can be performed more efficiently than the general case.
            VName
v' <- VName -> ReduceM VName
resolveName VName
v
            let v_kept_on_device :: Bool
v_kept_on_device = VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
/= VName
v'

            Bool
gpubody_ok <- (State -> Bool) -> ReduceM Bool
forall (m :: * -> *) s a. Monad m => (s -> a) -> StateT s m a
gets State -> Bool
stateGPUBodyOk

            case Bool
v_kept_on_device of
              Bool
False -> Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm)
              Bool
True
                | (SubExp -> Bool) -> [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) [SubExp]
dims,
                  Just TypeBase (ShapeBase SubExp) NoUniqueness
t' <- Int
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> Maybe (TypeBase (ShapeBase SubExp) NoUniqueness)
forall u.
Int
-> TypeBase (ShapeBase SubExp) u
-> Maybe (TypeBase (ShapeBase SubExp) u)
peelArray Int
1 TypeBase (ShapeBase SubExp) NoUniqueness
LetDec GPU
arr_t,
                  Bool
gpubody_ok -> do
                    let n' :: VName
n' = Name -> Int -> VName
VName (VName -> Name
baseName VName
n Name -> String -> Name
`withSuffix` String
"_inner") Int
0
                    let pat' :: Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
pat' = [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
forall dec. [PatElem dec] -> Pat dec
Pat [VName
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
forall dec. VName -> dec -> PatElem dec
PatElem VName
n' TypeBase (ShapeBase SubExp) NoUniqueness
t']
                    let e' :: Exp rep
e' = BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> SubExp -> BasicOp
Replicate ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> ShapeBase SubExp) -> [SubExp] -> ShapeBase SubExp
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp]
forall a. [a] -> [a]
tail [SubExp]
dims) (VName -> SubExp
Var VName
v)
                    let stm' :: Stm GPU
stm' = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
Pat (LetDec GPU)
pat' (Stm GPU -> StmAux (ExpDec GPU)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
forall rep. Exp rep
e'

                    -- `gpu { v }` is slightly faster than `replicate 1 v` and
                    -- can merge with the GPUBody that v was computed by.
                    Stm GPU
gpubody <- RewriteM (Stm GPU) -> ReduceM (Stm GPU)
inGPUBody (Stm GPU -> RewriteM (Stm GPU)
rewriteStm Stm GPU
stm')
                    Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
gpubody {stmPat :: Pat (LetDec GPU)
stmPat = Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm})
              Bool
True
                | [SubExp] -> SubExp
forall a. [a] -> a
last [SubExp]
dims SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1 ->
                    let e' :: Exp rep
e' = BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> SubExp -> BasicOp
Replicate ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> ShapeBase SubExp) -> [SubExp] -> ShapeBase SubExp
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp]
forall a. [a] -> [a]
init [SubExp]
dims) (VName -> SubExp
Var VName
v')
                        stm' :: Stm GPU
stm' = Stm GPU
stm {stmExp :: Exp GPU
stmExp = Exp GPU
forall rep. Exp rep
e'}
                     in Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm')
              Bool
True -> do
                VName
n' <- VName -> ReduceM VName
newName VName
n
                -- v_kept_on_device implies that v is a scalar.
                let dims' :: [SubExp]
dims' = [SubExp]
dims [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ [IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1]
                let arr_t' :: TypeBase (ShapeBase SubExp) NoUniqueness
arr_t' = PrimType
-> ShapeBase SubExp
-> NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array (TypeBase (ShapeBase SubExp) NoUniqueness -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType TypeBase (ShapeBase SubExp) NoUniqueness
LetDec GPU
arr_t) ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp]
dims') NoUniqueness
NoUniqueness
                let pat' :: Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
pat' = [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
forall dec. [PatElem dec] -> Pat dec
Pat [VName
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
forall dec. VName -> dec -> PatElem dec
PatElem VName
n' TypeBase (ShapeBase SubExp) NoUniqueness
arr_t']
                let e' :: Exp rep
e' = BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> SubExp -> BasicOp
Replicate ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp]
dims) (VName -> SubExp
Var VName
v')
                let repl :: Stm GPU
repl = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
Pat (LetDec GPU)
pat' (Stm GPU -> StmAux (ExpDec GPU)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
forall rep. Exp rep
e'

                let aux :: StmAux ()
aux = Certs -> Attrs -> () -> StmAux ()
forall dec. Certs -> Attrs -> dec -> StmAux dec
StmAux Certs
forall a. Monoid a => a
mempty Attrs
forall a. Monoid a => a
mempty ()
                let slice :: [DimIndex SubExp]
slice = (SubExp -> DimIndex SubExp) -> [SubExp] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
sliceDim (TypeBase (ShapeBase SubExp) NoUniqueness -> [SubExp]
forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims TypeBase (ShapeBase SubExp) NoUniqueness
LetDec GPU
arr_t)
                let slice' :: [DimIndex SubExp]
slice' = [DimIndex SubExp]
slice [DimIndex SubExp] -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. [a] -> [a] -> [a]
++ [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0]
                let idx :: Exp rep
idx = BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
n' ([DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
slice')
                let index :: Stm GPU
index = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let (Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm) StmAux ()
StmAux (ExpDec GPU)
aux Exp GPU
forall rep. Exp rep
idx

                Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
repl Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
index)
      BasicOp {} ->
        Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm)
      Apply {} ->
        Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm)
      If SubExp
cond (Body BodyDec GPU
_ Stms GPU
tstms0 Result
tres) (Body BodyDec GPU
_ Stms GPU
fstms0 Result
fres) (IfDec [BranchType GPU]
btypes IfSort
sort) ->
        do
          -- Rewrite branches.
          Stms GPU
tstms1 <- Stms GPU -> ReduceM (Stms GPU)
optimizeStms Stms GPU
tstms0
          Stms GPU
fstms1 <- Stms GPU -> ReduceM (Stms GPU)
optimizeStms Stms GPU
fstms0

          -- Ensure return values and types match if one or both branches
          -- return a result that now resides on device.
          let bmerge :: ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
   SubExpRes, TypeBase ExtShape NoUniqueness)],
 Stms GPU, Stms GPU)
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
    SubExpRes, TypeBase ExtShape NoUniqueness)
-> StateT
     State
     (Reader MigrationTable)
     ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
        SubExpRes, TypeBase ExtShape NoUniqueness)],
      Stms GPU, Stms GPU)
bmerge ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
  SubExpRes, TypeBase ExtShape NoUniqueness)]
res, Stms GPU
tstms, Stms GPU
fstms) (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe, SubExpRes
tr, SubExpRes
fr, TypeBase ExtShape NoUniqueness
bt) =
                do
                  let onHost :: SubExp -> ReduceM Bool
onHost (Var VName
v) = (VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
==) (VName -> Bool) -> ReduceM VName -> ReduceM Bool
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ReduceM VName
resolveName VName
v
                      onHost SubExp
_ = Bool -> ReduceM Bool
forall (f :: * -> *) a. Applicative f => a -> f a
pure Bool
True

                  Bool
tr_on_host <- SubExp -> ReduceM Bool
onHost (SubExpRes -> SubExp
resSubExp SubExpRes
tr)
                  Bool
fr_on_host <- SubExp -> ReduceM Bool
onHost (SubExpRes -> SubExp
resSubExp SubExpRes
fr)

                  if Bool
tr_on_host Bool -> Bool -> Bool
&& Bool
fr_on_host
                    then -- No result resides on device ==> nothing to do.
                      ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
   SubExpRes, TypeBase ExtShape NoUniqueness)],
 Stms GPU, Stms GPU)
-> StateT
     State
     (Reader MigrationTable)
     ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
        SubExpRes, TypeBase ExtShape NoUniqueness)],
      Stms GPU, Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe, SubExpRes
tr, SubExpRes
fr, TypeBase ExtShape NoUniqueness
bt) (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
 SubExpRes, TypeBase ExtShape NoUniqueness)
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
     SubExpRes, TypeBase ExtShape NoUniqueness)]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
     SubExpRes, TypeBase ExtShape NoUniqueness)]
forall a. a -> [a] -> [a]
: [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
  SubExpRes, TypeBase ExtShape NoUniqueness)]
res, Stms GPU
tstms, Stms GPU
fstms)
                    else -- Otherwise, ensure both results are migrated.
                    do
                      let t :: TypeBase (ShapeBase SubExp) NoUniqueness
t = PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
PatElem dec -> TypeBase (ShapeBase SubExp) NoUniqueness
patElemType PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe
                      (Stms GPU
tstms', VName
tarr) <- Stms GPU
-> SubExp
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> ReduceM (Stms GPU, VName)
storeScalar Stms GPU
tstms (SubExpRes -> SubExp
resSubExp SubExpRes
tr) TypeBase (ShapeBase SubExp) NoUniqueness
t
                      (Stms GPU
fstms', VName
farr) <- Stms GPU
-> SubExp
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> ReduceM (Stms GPU, VName)
storeScalar Stms GPU
fstms (SubExpRes -> SubExp
resSubExp SubExpRes
fr) TypeBase (ShapeBase SubExp) NoUniqueness
t

                      PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe' <- PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
-> ReduceM (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
arrayizePatElem PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe
                      let bt' :: TypeBase ExtShape NoUniqueness
bt' = TypeBase (ShapeBase SubExp) NoUniqueness
-> TypeBase ExtShape NoUniqueness
forall u. TypeBase (ShapeBase SubExp) u -> TypeBase ExtShape u
staticShapes1 (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
PatElem dec -> TypeBase (ShapeBase SubExp) NoUniqueness
patElemType PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe')
                      let tr' :: SubExpRes
tr' = SubExpRes
tr {resSubExp :: SubExp
resSubExp = VName -> SubExp
Var VName
tarr}
                      let fr' :: SubExpRes
fr' = SubExpRes
fr {resSubExp :: SubExp
resSubExp = VName -> SubExp
Var VName
farr}
                      ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
   SubExpRes, TypeBase ExtShape NoUniqueness)],
 Stms GPU, Stms GPU)
-> StateT
     State
     (Reader MigrationTable)
     ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
        SubExpRes, TypeBase ExtShape NoUniqueness)],
      Stms GPU, Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe', SubExpRes
tr', SubExpRes
fr', TypeBase ExtShape NoUniqueness
bt') (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
 SubExpRes, TypeBase ExtShape NoUniqueness)
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
     SubExpRes, TypeBase ExtShape NoUniqueness)]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
     SubExpRes, TypeBase ExtShape NoUniqueness)]
forall a. a -> [a] -> [a]
: [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
  SubExpRes, TypeBase ExtShape NoUniqueness)]
res, Stms GPU
tstms', Stms GPU
fstms')

          let pes :: [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes = Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall dec. Pat dec -> [PatElem dec]
patElems (Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm)
          let zipped :: [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
  SubExpRes, TypeBase ExtShape NoUniqueness)]
zipped = [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> Result
-> Result
-> [TypeBase ExtShape NoUniqueness]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
     SubExpRes, TypeBase ExtShape NoUniqueness)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes Result
tres Result
fres [TypeBase ExtShape NoUniqueness]
[BranchType GPU]
btypes
          ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
  SubExpRes, TypeBase ExtShape NoUniqueness)]
zipped', Stms GPU
tstms2, Stms GPU
fstms2) <- (([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
    SubExpRes, TypeBase ExtShape NoUniqueness)],
  Stms GPU, Stms GPU)
 -> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
     SubExpRes, TypeBase ExtShape NoUniqueness)
 -> StateT
      State
      (Reader MigrationTable)
      ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
         SubExpRes, TypeBase ExtShape NoUniqueness)],
       Stms GPU, Stms GPU))
-> ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
      SubExpRes, SubExpRes, TypeBase ExtShape NoUniqueness)],
    Stms GPU, Stms GPU)
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
     SubExpRes, TypeBase ExtShape NoUniqueness)]
-> StateT
     State
     (Reader MigrationTable)
     ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
        SubExpRes, TypeBase ExtShape NoUniqueness)],
      Stms GPU, Stms GPU)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
   SubExpRes, TypeBase ExtShape NoUniqueness)],
 Stms GPU, Stms GPU)
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
    SubExpRes, TypeBase ExtShape NoUniqueness)
-> StateT
     State
     (Reader MigrationTable)
     ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
        SubExpRes, TypeBase ExtShape NoUniqueness)],
      Stms GPU, Stms GPU)
bmerge ([], Stms GPU
tstms1, Stms GPU
fstms1) [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
  SubExpRes, TypeBase ExtShape NoUniqueness)]
zipped
          let ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes', Result
tres', Result
fres', [TypeBase ExtShape NoUniqueness]
btypes') = [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
  SubExpRes, TypeBase ExtShape NoUniqueness)]
-> ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)], Result,
    Result, [TypeBase ExtShape NoUniqueness])
forall a b c d. [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
  SubExpRes, TypeBase ExtShape NoUniqueness)]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
     SubExpRes, TypeBase ExtShape NoUniqueness)]
forall a. [a] -> [a]
reverse [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness), SubExpRes,
  SubExpRes, TypeBase ExtShape NoUniqueness)]
zipped')

          -- Rewrite statement.
          let tbranch' :: Body GPU
tbranch' = BodyDec GPU -> Stms GPU -> Result -> Body GPU
forall rep. BodyDec rep -> Stms rep -> Result -> Body rep
Body () Stms GPU
tstms2 Result
tres'
          let fbranch' :: Body GPU
fbranch' = BodyDec GPU -> Stms GPU -> Result -> Body GPU
forall rep. BodyDec rep -> Stms rep -> Result -> Body rep
Body () Stms GPU
fstms2 Result
fres'
          let e' :: Exp GPU
e' = SubExp -> Body GPU -> Body GPU -> IfDec (BranchType GPU) -> Exp GPU
forall rep.
SubExp -> Body rep -> Body rep -> IfDec (BranchType rep) -> Exp rep
If SubExp
cond Body GPU
tbranch' Body GPU
fbranch' ([TypeBase ExtShape NoUniqueness]
-> IfSort -> IfDec (TypeBase ExtShape NoUniqueness)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [TypeBase ExtShape NoUniqueness]
btypes' IfSort
sort)
          let stm' :: Stm GPU
stm' = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes') (Stm GPU -> StmAux (ExpDec GPU)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
e'

          -- Read migrated scalars that are used on host.
          (Stms GPU
 -> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
     PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
 -> ReduceM (Stms GPU))
-> Stms GPU
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
     PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
-> ReduceM (Stms GPU)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM Stms GPU
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
    PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> ReduceM (Stms GPU)
forall dec.
Stms GPU
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
    PatElem dec)
-> ReduceM (Stms GPU)
addRead (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm') ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
     PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes')
      DoLoop [(FParam GPU, SubExp)]
ps LoopForm GPU
lf Body GPU
b -> do
        -- Enable the migration of for-in loop variables.
        ([(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)]
params, LoopForm GPU
lform, Body GPU
body) <- ([(FParam GPU, SubExp)], LoopForm GPU, Body GPU)
-> ReduceM ([(FParam GPU, SubExp)], LoopForm GPU, Body GPU)
rewriteForIn ([(FParam GPU, SubExp)]
ps, LoopForm GPU
lf, Body GPU
b)

        -- Update statement bound variables and parameters if their values
        -- have been migrated to device.
        let lmerge :: ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
   (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
 Stms GPU)
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
    (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp),
    MigrationStatus)
-> StateT
     State
     (Reader MigrationTable)
     ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
        (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
      Stms GPU)
lmerge ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
  (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
res, Stms GPU
stms) (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe, (Param Attrs
_ VName
pn TypeBase (ShapeBase SubExp) Uniqueness
pt, SubExp
pval), MigrationStatus
MoveToDevice) = do
              -- Rewrite the bound variable.
              PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe' <- PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
-> ReduceM (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
arrayizePatElem PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe

              -- Move the initial value to device if not already there.
              (Stms GPU
stms', VName
arr) <- Stms GPU
-> SubExp
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> ReduceM (Stms GPU, VName)
storeScalar Stms GPU
stms SubExp
pval (TypeBase (ShapeBase SubExp) Uniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape.
TypeBase shape Uniqueness -> TypeBase shape NoUniqueness
fromDecl TypeBase (ShapeBase SubExp) Uniqueness
pt)

              -- Rewrite the parameter.
              VName
pn' <- VName -> ReduceM VName
newName VName
pn
              let pt' :: TypeBase (ShapeBase SubExp) Uniqueness
pt' = TypeBase (ShapeBase SubExp) NoUniqueness
-> Uniqueness -> TypeBase (ShapeBase SubExp) Uniqueness
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
PatElem dec -> TypeBase (ShapeBase SubExp) NoUniqueness
patElemType PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe') Uniqueness
Nonunique
              let pval' :: SubExp
pval' = VName -> SubExp
Var VName
arr
              let param' :: (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)
param' = (Attrs
-> VName
-> TypeBase (ShapeBase SubExp) Uniqueness
-> Param (TypeBase (ShapeBase SubExp) Uniqueness)
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
forall a. Monoid a => a
mempty VName
pn' TypeBase (ShapeBase SubExp) Uniqueness
pt', SubExp
pval')

              -- Record the migration.
              VName -> TypeBase (ShapeBase SubExp) NoUniqueness -> Ident
Ident VName
pn (TypeBase (ShapeBase SubExp) Uniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape.
TypeBase shape Uniqueness -> TypeBase shape NoUniqueness
fromDecl TypeBase (ShapeBase SubExp) Uniqueness
pt) Ident -> VName -> ReduceM ()
`movedTo` VName
pn'

              ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
   (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
 Stms GPU)
-> StateT
     State
     (Reader MigrationTable)
     ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
        (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
      Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe', (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)
param') (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
 (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
     (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
     (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
forall a. a -> [a] -> [a]
: [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
  (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
res, Stms GPU
stms')
            lmerge ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
   (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
 Stms GPU)
_ (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
_, (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)
_, MigrationStatus
UsedOnHost) =
              -- Initial loop parameter value and loop result should have
              -- been made available on host instead.
              String
-> StateT
     State
     (Reader MigrationTable)
     ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
        (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
      Stms GPU)
forall a. String -> a
compilerBugS String
"optimizeStm: unhandled host usage of loop param"
            lmerge ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
  (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
res, Stms GPU
stms) (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe, (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)
param, MigrationStatus
StayOnHost) =
              ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
   (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
 Stms GPU)
-> StateT
     State
     (Reader MigrationTable)
     ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
        (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
      Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe, (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)
param) (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
 (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
     (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
     (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
forall a. a -> [a] -> [a]
: [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
  (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
res, Stms GPU
stms)

        MigrationTable
mt <- ReduceM MigrationTable
ask

        let pes :: [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes = Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall dec. Pat dec -> [PatElem dec]
patElems (Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm)
        let mss :: [MigrationStatus]
mss = ((Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)
 -> MigrationStatus)
-> [(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)]
-> [MigrationStatus]
forall a b. (a -> b) -> [a] -> [b]
map (\(Param Attrs
_ VName
n TypeBase (ShapeBase SubExp) Uniqueness
_, SubExp
_) -> VName -> MigrationTable -> MigrationStatus
statusOf VName
n MigrationTable
mt) [(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)]
params
        ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
  (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
zipped', Stms GPU
out') <- (([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
    (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
  Stms GPU)
 -> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
     (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp),
     MigrationStatus)
 -> StateT
      State
      (Reader MigrationTable)
      ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
         (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
       Stms GPU))
-> ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
      (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
    Stms GPU)
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
     (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp),
     MigrationStatus)]
-> StateT
     State
     (Reader MigrationTable)
     ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
        (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
      Stms GPU)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
   (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
 Stms GPU)
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
    (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp),
    MigrationStatus)
-> StateT
     State
     (Reader MigrationTable)
     ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
        (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))],
      Stms GPU)
lmerge ([], Stms GPU
out) ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)]
-> [MigrationStatus]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
     (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp),
     MigrationStatus)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes [(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)]
params [MigrationStatus]
mss)
        let ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes', [(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)]
params') = [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
  (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
-> ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)],
    [(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
  (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
     (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
forall a. [a] -> [a]
reverse [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
  (Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp))]
zipped')

        -- Rewrite statement.
        Body GPU
body' <- Body GPU -> ReduceM (Body GPU)
optimizeBody Body GPU
body
        let e' :: Exp GPU
e' = [(FParam GPU, SubExp)] -> LoopForm GPU -> Body GPU -> Exp GPU
forall rep.
[(FParam rep, SubExp)] -> LoopForm rep -> Body rep -> Exp rep
DoLoop [(Param (TypeBase (ShapeBase SubExp) Uniqueness), SubExp)]
[(FParam GPU, SubExp)]
params' LoopForm GPU
lform Body GPU
body'
        let stm' :: Stm GPU
stm' = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes') (Stm GPU -> StmAux (ExpDec GPU)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
e'

        -- Read migrated scalars that are used on host.
        (Stms GPU
 -> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
     PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
 -> ReduceM (Stms GPU))
-> Stms GPU
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
     PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
-> ReduceM (Stms GPU)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM Stms GPU
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
    PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> ReduceM (Stms GPU)
forall dec.
Stms GPU
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
    PatElem dec)
-> ReduceM (Stms GPU)
addRead (Stms GPU
out' Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm') ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
     PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes')
      WithAcc [WithAccInput GPU]
inputs Lambda GPU
lmd -> do
        let getAcc :: TypeBase shape u -> VName
getAcc (Acc VName
a ShapeBase SubExp
_ [TypeBase (ShapeBase SubExp) NoUniqueness]
_ u
_) = VName
a
            getAcc TypeBase shape u
_ =
              String -> VName
forall a. String -> a
compilerBugS
                String
"Type error: WithAcc expression did not return accumulator."

        let accs :: [(VName, WithAccInput GPU)]
accs = (TypeBase (ShapeBase SubExp) NoUniqueness
 -> WithAccInput GPU -> (VName, WithAccInput GPU))
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [WithAccInput GPU]
-> [(VName, WithAccInput GPU)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\TypeBase (ShapeBase SubExp) NoUniqueness
t WithAccInput GPU
i -> (TypeBase (ShapeBase SubExp) NoUniqueness -> VName
forall shape u. TypeBase shape u -> VName
getAcc TypeBase (ShapeBase SubExp) NoUniqueness
t, WithAccInput GPU
i)) (Lambda GPU -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall rep.
Lambda rep -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda GPU
lmd) [WithAccInput GPU]
inputs
        [WithAccInput GPU]
inputs' <- ((VName, WithAccInput GPU)
 -> StateT State (Reader MigrationTable) (WithAccInput GPU))
-> [(VName, WithAccInput GPU)]
-> StateT State (Reader MigrationTable) [WithAccInput GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((VName
 -> WithAccInput GPU
 -> StateT State (Reader MigrationTable) (WithAccInput GPU))
-> (VName, WithAccInput GPU)
-> StateT State (Reader MigrationTable) (WithAccInput GPU)
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry VName
-> WithAccInput GPU
-> StateT State (Reader MigrationTable) (WithAccInput GPU)
optimizeWithAccInput) [(VName, WithAccInput GPU)]
accs

        let body :: Body GPU
body = Lambda GPU -> Body GPU
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPU
lmd
        Stms GPU
stms' <- Stms GPU -> ReduceM (Stms GPU)
optimizeStms (Body GPU -> Stms GPU
forall rep. Body rep -> Stms rep
bodyStms Body GPU
body)

        let rewrite :: (SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
 PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> StateT
     State
     (Reader MigrationTable)
     (SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
      PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
rewrite (SubExpRes Certs
certs SubExp
se, TypeBase (ShapeBase SubExp) NoUniqueness
t, PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe) =
              do
                SubExp
se' <- SubExp -> ReduceM SubExp
resolveSubExp SubExp
se
                if SubExp
se SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
se'
                  then (SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
 PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> StateT
     State
     (Reader MigrationTable)
     (SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
      PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Certs -> SubExp -> SubExpRes
SubExpRes Certs
certs SubExp
se, TypeBase (ShapeBase SubExp) NoUniqueness
t, PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe)
                  else do
                    PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe' <- PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
-> ReduceM (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
arrayizePatElem PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe
                    let t' :: TypeBase (ShapeBase SubExp) NoUniqueness
t' = PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
PatElem dec -> TypeBase (ShapeBase SubExp) NoUniqueness
patElemType PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe'
                    (SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
 PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> StateT
     State
     (Reader MigrationTable)
     (SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
      PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Certs -> SubExp -> SubExpRes
SubExpRes Certs
certs SubExp
se', TypeBase (ShapeBase SubExp) NoUniqueness
t', PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe')

        -- Rewrite non-accumulator results that have been migrated.
        --
        -- Accumulator return values do not map to arrays one-to-one but
        -- one-to-many. They are not transformed however and can be mapped
        -- as a no-op.
        let len :: Int
len = [WithAccInput GPU] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [WithAccInput GPU]
inputs
        let (Result
res0, Result
res1) = Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitAt Int
len (Body GPU -> Result
forall rep. Body rep -> Result
bodyResult Body GPU
body)
        let ([TypeBase (ShapeBase SubExp) NoUniqueness]
rts0, [TypeBase (ShapeBase SubExp) NoUniqueness]
rts1) = Int
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> ([TypeBase (ShapeBase SubExp) NoUniqueness],
    [TypeBase (ShapeBase SubExp) NoUniqueness])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
len (Lambda GPU -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall rep.
Lambda rep -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda GPU
lmd)
        let pes :: [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes = Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall dec. Pat dec -> [PatElem dec]
patElems (Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm)
        let ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes0, [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes1) = Int
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)],
    [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes Int -> Int -> Int
forall a. Num a => a -> a -> a
- Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
res1) [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes
        (Result
res1', [TypeBase (ShapeBase SubExp) NoUniqueness]
rts1', [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes1') <- [(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
  PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
-> (Result, [TypeBase (ShapeBase SubExp) NoUniqueness],
    [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
   PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
 -> (Result, [TypeBase (ShapeBase SubExp) NoUniqueness],
     [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]))
-> StateT
     State
     (Reader MigrationTable)
     [(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
       PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
-> StateT
     State
     (Reader MigrationTable)
     (Result, [TypeBase (ShapeBase SubExp) NoUniqueness],
      [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
  PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
 -> StateT
      State
      (Reader MigrationTable)
      (SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
       PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)))
-> [(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
     PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
-> StateT
     State
     (Reader MigrationTable)
     [(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
       PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
 PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> StateT
     State
     (Reader MigrationTable)
     (SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
      PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
rewrite (Result
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [(SubExpRes, TypeBase (ShapeBase SubExp) NoUniqueness,
     PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Result
res1 [TypeBase (ShapeBase SubExp) NoUniqueness]
rts1 [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes1)
        let res' :: Result
res' = Result
res0 Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
res1'
        let rts' :: [TypeBase (ShapeBase SubExp) NoUniqueness]
rts' = [TypeBase (ShapeBase SubExp) NoUniqueness]
rts0 [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a. [a] -> [a] -> [a]
++ [TypeBase (ShapeBase SubExp) NoUniqueness]
rts1'
        let pes' :: [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes' = [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes0 [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes1'

        -- Rewrite statement.
        let body' :: Body GPU
body' = BodyDec GPU -> Stms GPU -> Result -> Body GPU
forall rep. BodyDec rep -> Stms rep -> Result -> Body rep
Body () Stms GPU
stms' Result
res'
        let lmd' :: Lambda GPU
lmd' = Lambda GPU
lmd {lambdaBody :: Body GPU
lambdaBody = Body GPU
body', lambdaReturnType :: [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType = [TypeBase (ShapeBase SubExp) NoUniqueness]
rts'}
        let e' :: Exp GPU
e' = [WithAccInput GPU] -> Lambda GPU -> Exp GPU
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [WithAccInput GPU]
inputs' Lambda GPU
lmd'
        let stm' :: Stm GPU
stm' = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes') (Stm GPU -> StmAux (ExpDec GPU)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
e'

        -- Read migrated scalars that are used on host.
        (Stms GPU
 -> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
     PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
 -> ReduceM (Stms GPU))
-> Stms GPU
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
     PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
-> ReduceM (Stms GPU)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM Stms GPU
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
    PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))
-> ReduceM (Stms GPU)
forall dec.
Stms GPU
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
    PatElem dec)
-> ReduceM (Stms GPU)
addRead (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm') ([PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [(PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
     PatElem (TypeBase (ShapeBase SubExp) NoUniqueness))]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes [PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)]
pes')
      Op Op GPU
op -> do
        HostOp GPU (SOAC GPU)
op' <- HostOp GPU (SOAC GPU) -> ReduceM (HostOp GPU (SOAC GPU))
forall op. HostOp GPU op -> ReduceM (HostOp GPU op)
optimizeHostOp Op GPU
HostOp GPU (SOAC GPU)
op
        Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm {stmExp :: Exp GPU
stmExp = Op GPU -> Exp GPU
forall rep. Op rep -> Exp rep
Op Op GPU
HostOp GPU (SOAC GPU)
op'})
  where
    addRead :: Stms GPU
-> (PatElem (TypeBase (ShapeBase SubExp) NoUniqueness),
    PatElem dec)
-> ReduceM (Stms GPU)
addRead Stms GPU
stms (pe :: PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe@(PatElem VName
n TypeBase (ShapeBase SubExp) NoUniqueness
_), PatElem VName
dev dec
_)
      | VName
n VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
dev = Stms GPU -> ReduceM (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure Stms GPU
stms
      | Bool
otherwise = PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
pe PatElem (TypeBase (ShapeBase SubExp) NoUniqueness)
-> (VName, Stms GPU) -> ReduceM (Stms GPU)
`migratedTo` (VName
dev, Stms GPU
stms)

-- | Rewrite a for-in loop such that relevant source array reads can be delayed.
--
-- A 'WhileLoop' is returned unchanged. For a 'ForLoop', every
-- @(element parameter, source array)@ pair whose parameter is not reported as
-- used on host by the 'MigrationTable' is dropped from the loop form. In its
-- place a statement binding the parameter name to an explicit 'Index' of the
-- source array is prepended to the loop body. Merge parameters are passed
-- through untouched.
rewriteForIn ::
  ([(FParam GPU, SubExp)], LoopForm GPU, Body GPU) ->
  ReduceM ([(FParam GPU, SubExp)], LoopForm GPU, Body GPU)
rewriteForIn loop@(_, WhileLoop {}, _) =
  pure loop
rewriteForIn (params, ForLoop i t n elems, body) = do
  mt <- ask
  -- A right fold keeps the retained pairs in their original order, and also
  -- keeps the prepended index statements in the order of 'elems'.
  let (elems', stms') = foldr (inline mt) ([], bodyStms body) elems
  pure (params, ForLoop i t n elems', body {bodyStms = stms'})
  where
    -- Either keep the pair @(x, arr)@ as a loop element, or replace it with an
    -- explicit index statement at the front of the body statements.
    inline mt (x, arr) (arrs, stms)
      | pn <- paramName x,
        not (usedOnHost pn mt) =
          let pt = typeOf x
              -- NOTE(review): 'bind' is defined elsewhere in this module;
              -- presumably it builds the statement binding the pattern
              -- element to the expression - confirm.
              stm = bind (PatElem pn pt) (BasicOp $ index arr pt)
           in (arrs, stm <| stms)
      | otherwise =
          ((x, arr) : arrs, stms)

    -- Index @arr@ at the loop variable @i@ in its first dimension, and apply
    -- 'sliceDim' to every dimension of the element type @of_type@.
    index arr of_type =
      Index arr $ Slice $ DimFix (Var i) : map sliceDim (arrayDims of_type)

-- | Optimize an accumulator input. The 'VName' is the accumulator token.
--
-- An input without an operator (@Nothing@ in the third component) is returned
-- unchanged. For an input with an operator, the 'MigrationTable' is queried
-- with 'shouldMove' on the accumulator token. If it answers @True@ the
-- operator lambda is rewritten with 'addReadsToLambda'. Otherwise only the
-- statements of the operator body are optimized, with 'noGPUBody' in effect.
-- The shape, the arrays and the second component of the operator pair are
-- passed through untouched in every case.
optimizeWithAccInput :: VName -> WithAccInput GPU -> ReduceM (WithAccInput GPU)
optimizeWithAccInput _ (shape, arrs, Nothing) = pure (shape, arrs, Nothing)
optimizeWithAccInput acc (shape, arrs, Just (op, nes)) = do
  device_only <- asks (shouldMove acc)
  if device_only
    then do
      op' <- addReadsToLambda op
      pure (shape, arrs, Just (op', nes))
    else do
      let body = lambdaBody op
      -- To pass type check neither parameters nor results can change.
      --
      -- op may be used on both host and device so we must avoid introducing
      -- any GPUBody statements.
      stms' <- noGPUBody $ optimizeStms (bodyStms body)
      let op' = op {lambdaBody = body {bodyStms = stms'}}
      pure (shape, arrs, Just (op', nes))

-- | Optimize a host operation. Kernel code that depends on migrated scalars
-- gets 'Index' statements added so that it reads them from device memory.
optimizeHostOp :: HostOp GPU op -> ReduceM (HostOp GPU op)
optimizeHostOp (SegOp segop) =
  SegOp <$> case segop of
    SegMap lvl space types kbody ->
      SegMap lvl space types <$> addReadsToKernelBody kbody
    -- For the remaining segmented operations the operators are rewritten
    -- first and the kernel body second.
    SegRed lvl space ops types kbody -> do
      ops' <- traverse addReadsToSegBinOp ops
      kbody' <- addReadsToKernelBody kbody
      pure (SegRed lvl space ops' types kbody')
    SegScan lvl space ops types kbody -> do
      ops' <- traverse addReadsToSegBinOp ops
      kbody' <- addReadsToKernelBody kbody
      pure (SegScan lvl space ops' types kbody')
    SegHist lvl space ops types kbody -> do
      ops' <- traverse addReadsToHistOp ops
      kbody' <- addReadsToKernelBody kbody
      pure (SegHist lvl space ops' types kbody')
optimizeHostOp (SizeOp op) = pure (SizeOp op)
optimizeHostOp OtherOp {} =
  -- These should all have been taken care of in the unstreamGPU pass.
  compilerBugS "optimizeHostOp: unhandled OtherOp"
optimizeHostOp (GPUBody types body) =
  fmap (GPUBody types) (addReadsToBody body)

--------------------------------------------------------------------------------
--                               COMMON HELPERS                               --
--------------------------------------------------------------------------------

-- | @name `withSuffix` sfx@ is @name@ with the string @sfx@ appended.
withSuffix :: Name -> String -> Name
withSuffix name sfx = nameFromText (nameToText name <> T.pack sfx)

--------------------------------------------------------------------------------
--                             MIGRATION - TYPES                              --
--------------------------------------------------------------------------------

-- | The monad used to perform migration-based synchronization reductions: a
-- 'StateT' carrying the rewriting 'State', layered over a reader that
-- provides the 'MigrationTable'.
type ReduceM = StateT State (R.Reader MigrationTable)

-- | The state threaded through a 'ReduceM' computation.
data State = State
  { -- | The supply that fresh 'VName's are generated from.
    stateNameSource :: VNameSource,
    -- | Variables of the original program that have been migrated to device,
    -- keyed by variable tag. Each entry is a tuple of:
    --
    --   * the 'baseName' of the original variable,
    --   * the type of the original variable,
    --   * the name of the single element array holding the migrated value,
    --   * whether the original variable can still be used on the host.
    stateMigrated :: IM.IntMap (Name, Type, VName, Bool),
    -- | Whether non-migration optimizations are allowed to introduce
    -- 'GPUBody' kernels at the current location.
    stateGPUBodyOk :: Bool
  }

--------------------------------------------------------------------------------
--                           MIGRATION - PRIMITIVES                           --
--------------------------------------------------------------------------------

-- | The state a 'ReduceM' computation starts from: the given name source, no
-- migrated variables, and 'GPUBody' kernels permitted.
initialState :: VNameSource -> State
initialState ns =
  State {stateNameSource = ns, stateMigrated = IM.empty, stateGPUBodyOk = True}

-- | Apply a projection to the 'MigrationTable' environment and return the
-- result.
asks :: (MigrationTable -> a) -> ReduceM a
asks f = lift (R.asks f)

-- | Return the whole 'MigrationTable' environment.
ask :: ReduceM MigrationTable
ask = lift (R.asks id)

-- | Run an action with 'stateGPUBodyOk' cleared, so that non-migration
-- optimizations performed by it introduce no GPUBody kernels. The previous
-- value of the flag is restored afterwards.
noGPUBody :: ReduceM a -> ReduceM a
noGPUBody action = do
  saved <- gets stateGPUBodyOk
  setGPUBodyOk False
  -- Run the action, reinstate the saved flag, and yield the action's result.
  action <* setGPUBodyOk saved
  where
    setGPUBodyOk :: Bool -> ReduceM ()
    setGPUBodyOk ok = modify (\st -> st {stateGPUBodyOk = ok})

-- | Generate a fresh name from the given template name, advancing the name
-- source held in the state.
newName :: VName -> ReduceM VName
newName template = do
  (fresh, src') <- gets (\st -> FN.newName (stateNameSource st) template)
  modify (\st -> st {stateNameSource = src'})
  pure fresh

-- | Create the 'PatElem' that binds the array of a migrated variable binding.
-- The array gets a fresh name based on the variable's base name with a
-- @_dev@ suffix, and its type is a one-row array of the variable's type.
arrayizePatElem :: PatElem Type -> ReduceM (PatElem Type)
arrayizePatElem (PatElem n t) =
  fmap (\dev -> PatElem dev dev_t) (newName (VName dev_base 0))
  where
    dev_base = withSuffix (baseName n) "_dev"
    dev_t = arrayOfRow t (intConst Int64 1)

-- | @x `movedTo` arr@ records that the value of @x@ now lives in @arr[0]@ and
-- that the original binding can no longer be used on host.
movedTo :: Ident -> VName -> ReduceM ()
movedTo x arr = recordMigration False x arr

-- | @x `aliasedBy` arr@ records that the value of @x@ is additionally
-- available on device as @arr[0]@; the original binding stays usable on host.
aliasedBy :: Ident -> VName -> ReduceM ()
aliasedBy x arr = recordMigration True x arr

-- | @recordMigration host x arr@ registers in 'stateMigrated' that variable
-- @x@ has been migrated to @arr[0]@. When @host@ is true the original binding
-- remains usable on host. An existing entry for @x@ is replaced.
recordMigration :: Bool -> Ident -> VName -> ReduceM ()
recordMigration host (Ident x t) arr =
  modify (\st -> st {stateMigrated = register (stateMigrated st)})
  where
    -- Entries are keyed by the tag of the original variable.
    register = IM.insert (baseTag x) (baseName x, t, arr, host)

-- | @pe `migratedTo` (dev, stms)@ registers that the variable bound by @pe@
-- in the original program has been migrated to @dev@. If the migration table
-- reports the variable as used on host, an index statement that rebinds @pe@
-- from @dev@ is appended to @stms@; otherwise @stms@ is returned unchanged.
migratedTo :: PatElem Type -> (VName, Stms GPU) -> ReduceM (Stms GPU)
migratedTo pe (dev, stms) = do
  used <- asks (usedOnHost (patElemName pe))
  -- A host-used variable is merely aliased by the device array (flag True);
  -- any other variable is moved to it (flag False).
  recordMigration used (patElemIdent pe) dev
  pure $
    if used
      then stms |> bind pe (eIndex dev)
      else stms

-- | @useScalar stms n@ returns a variable that binds the result bound by @n@
-- in the original program. If the variable has been migrated to device and
-- is not usable on host, a new binding that reads it back from its array is
-- appended to the provided statements and the new variable is returned.
-- Otherwise the statements and @n@ are returned unchanged.
useScalar :: Stms GPU -> VName -> ReduceM (Stms GPU, VName)
useScalar stms n = do
  entry <- gets (IM.lookup (baseTag n) . stateMigrated)
  case entry of
    -- Only available on device: read the value out of its array.
    Just (name, t, arr, False) -> do
      n' <- newName (VName name 0)
      pure (stms |> bind (PatElem n' t) (eIndex arr), n')
    -- Never migrated, or still usable on host.
    _ -> pure (stms, n)

-- | Build the expression @arr[0]@, i.e. an 'Index' of the given array with a
-- single fixed dimension index of zero.
eIndex :: VName -> Exp GPU
eIndex arr = BasicOp (Index arr first_elem)
  where
    first_elem = Slice [DimFix (intConst Int64 0)]

-- | Bind a single pattern element to an expression, using an empty statement
-- auxiliary (no certificates and no attributes).
bind :: PatElem Type -> Exp GPU -> Stm GPU
bind pe e = Let (Pat [pe]) aux e
  where
    aux = StmAux mempty mempty ()

-- | Return the array alias of @se@ if it is a variable that has been recorded
-- as migrated to device. Constants and unrecorded variables give @Nothing@.
storedScalar :: SubExp -> ReduceM (Maybe VName)
storedScalar se = case se of
  Constant _ -> pure Nothing
  Var n -> gets (fmap arrayOf . IM.lookup (baseTag n) . stateMigrated)
  where
    arrayOf (_, _, arr, _) = arr

-- | @storeScalar stms se t@ returns a variable that binds a single element
-- array that contains the value of @se@ in the original program. If @se@ is a
-- variable that has been migrated to device, its existing array alias will be
-- used. Otherwise a new variable binding will be added to the provided
-- statements and be returned. @t@ is the type of @se@.
storeScalar :: Stms GPU -> SubExp -> Type -> ReduceM (Stms GPU, VName)
storeScalar stms se t = do
  stored <- storedScalar se
  case stored of
    Just arr -> pure (stms, arr)
    Nothing -> do
      -- How to most efficiently create an array containing the given value
      -- depends on whether it is a variable or a constant. Creating a constant
      -- array is a runtime copy of static memory, while creating an array that
      -- contains a variable results in a synchronous write. The latter is thus
      -- replaced with either a mergeable GPUBody kernel or a Replicate.
      --
      -- Whether it makes sense to hoist arrays out of bodies to enable CSE is
      -- left to the simplifier to figure out. Duplicates will be eliminated if
      -- a scalar is stored multiple times within a body.
      --
      -- TODO: Enable the simplifier to hoist non-consumed, non-returned arrays
      --       out of top-level function definitions. All constant arrays
      --       produced here are in principle top-level hoistable.
      gpubody_ok <- gets stateGPUBodyOk
      case se of
        Var n | gpubody_ok -> do
          n' <- newName n
          let stm = bind (PatElem n' t) (BasicOp $ SubExp se)

          gpubody <- inGPUBody (pure stm)
          -- The array holding the value is the first pattern element bound
          -- by the GPUBody statement. Match on the list instead of calling
          -- the partial 'head' so that a broken invariant is reported as a
          -- compiler bug rather than as an anonymous empty-list error.
          case patElems (stmPat gpubody) of
            pe : _ -> pure (stms |> gpubody, patElemName pe)
            [] ->
              compilerBugS
                "storeScalar: GPUBody statement binds no pattern elements"
        Var n -> do
          -- GPUBody kernels are not allowed here; use a Replicate instead.
          pe <- arrayizePatElem (PatElem n t)
          let shape = Shape [intConst Int64 1]
          let stm = bind pe (BasicOp $ Replicate shape se)
          pure (stms |> stm, patElemName pe)
        _ -> do
          -- A constant is stored through a one-element array literal.
          let n = VName (nameFromString "const") 0
          pe <- arrayizePatElem (PatElem n t)
          let stm = bind pe (BasicOp $ ArrayLit [se] t)
          pure (stms |> stm, patElemName pe)

-- | Map a variable name to itself or, if the variable no longer can be used on
-- host, the name of a single element array containing its value.
--
-- The name is looked up by its 'baseTag' in the 'stateMigrated' table:
--
-- * no entry: the name is returned unchanged;
-- * an entry whose flag is 'True': the name is returned unchanged
--   (presumably the flag marks variables still usable on host);
-- * any other entry: the array name recorded in the entry is returned.
resolveName :: VName -> ReduceM VName
resolveName n = do
  entry <- IM.lookup (baseTag n) <$> gets stateMigrated
  case entry of
    Nothing -> pure n
    Just (_, _, _, True) -> pure n
    Just (_, _, arr, _) -> pure arr

-- | Like 'resolveName' but for a t'SubExp'. Constants are mapped to themselves.
--
-- Only the 'Var' case consults 'resolveName'; every other t'SubExp' is
-- returned untouched.
resolveSubExp :: SubExp -> ReduceM SubExp
resolveSubExp (Var n) = Var <$> resolveName n
resolveSubExp cnst = pure cnst

-- | Like 'resolveSubExp' but for a 'SubExpRes'.
--
-- Only the wrapped t'SubExp' is resolved; the certificates are kept as they
-- are.
resolveSubExpRes :: SubExpRes -> ReduceM SubExpRes
resolveSubExpRes (SubExpRes certs se) =
  -- Certificates are always read back to host.
  SubExpRes certs <$> resolveSubExp se

-- | Apply 'resolveSubExpRes' to a list of results, in order.
resolveResult :: Result -> ReduceM Result
resolveResult = mapM resolveSubExpRes

-- | Migrate a statement to device, ensuring all its bound variables used on
-- host will remain available with the same names.
--
-- @moveStm out stm@ returns @out@ extended with the 'GPUBody' statement that
-- replaces @stm@, followed by whatever statements are needed to rebind the
-- names of the original pattern.
--
-- A single-element 'ArrayLit' bound to a single pattern element is special
-- cased: the element itself is computed inside the 'GPUBody' and the
-- 'GPUBody' is bound directly to the original pattern, so no read is added.
moveStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
moveStm out (Let pat aux (BasicOp (ArrayLit [se] t')))
  | Pat [PatElem n _] <- pat =
      do
        -- Save an 'Index' by rewriting the 'ArrayLit' rather than migrating it.
        -- The placeholder name (tag 0) is replaced with a fresh one by
        -- 'rewriteStm'.
        let n' = VName (baseName n `withSuffix` "_inner") 0
            pat' = Pat [PatElem n' t']
            stm' = Let pat' aux (BasicOp (SubExp se))

        gpubody <- inGPUBody (rewriteStm stm')
        -- Bind the kernel result under the original pattern.
        pure (out |> gpubody {stmPat = pat})
moveStm out stm = do
  -- Move the statement to device.
  gpubody <- inGPUBody (rewriteStm stm)

  -- Read non-scalars and scalars that are used on host.
  let arrs = zip (patElems $ stmPat stm) (patElems $ stmPat gpubody)
  foldM addRead (out |> gpubody) arrs
  where
    -- Rebind one original pattern element from its device counterpart. The
    -- rank of the device-side type decides how.
    addRead :: Stms GPU -> (PatElem Type, PatElem Type) -> ReduceM (Stms GPU)
    addRead stms (pe@(PatElem _ t), PatElem dev dev_t) =
      let add' e = pure $ stms |> bind pe e
          add = add' . BasicOp
       in case arrayRank dev_t of
            -- Alias non-arrays with their prior name.
            0 -> add $ SubExp (Var dev)
            -- Read all certificates for free.
            1 | t == Prim Unit -> add' (eIndex dev)
            -- Record the device alias of each scalar variable and read them
            -- if used on host.
            1 -> pe `migratedTo` (dev, stms)
            -- Drop the added dimension of multidimensional arrays.
            _ -> add $ Index dev (fullSlice dev_t [DimFix $ intConst Int64 0])

-- | Create a GPUBody kernel that executes a single statement and stores its
-- results in single element arrays.
--
-- The rewrite action is run from 'initialRState'. The kernel body consists of
-- the prologue accumulated during the rewrite followed by the resulting
-- statement, and it returns every variable bound by that statement. The
-- pattern of the returned 'GPUBody' statement is obtained by passing each of
-- those pattern elements through 'arrayizePatElem'. The returned statement
-- carries no certificates and no attributes.
inGPUBody :: RewriteM (Stm GPU) -> ReduceM (Stm GPU)
inGPUBody m = do
  (stm, st) <- runStateT m initialRState
  let prologue = rewritePrologue st
      pes = patElems (stmPat stm)

  pat <- Pat <$> mapM arrayizePatElem pes

  let aux = StmAux mempty mempty ()
      -- Result types are those of the statement as bound inside the kernel.
      types = map patElemType pes
      res = map (SubExpRes mempty . Var . patElemName) pes
      body = Body () (prologue |> stm) res
  pure (Let pat aux (Op (GPUBody types body)))

--------------------------------------------------------------------------------
--                          KERNEL REWRITING - TYPES                          --
--------------------------------------------------------------------------------

-- | The monad used to rewrite (migrated) kernel code: a 'StateT' layer
-- carrying an 'RState' on top of 'ReduceM'.
type RewriteM = StateT RState ReduceM

-- | The state used by a 'RewriteM' monad.
data RState = RState
  { -- | Maps variables in the original program to names to be used by rewrites.
    -- Keys are the 'baseTag's of the original names (see 'rewriteName').
    rewriteRenames :: IM.IntMap VName,
    -- | Statements to be added as a prologue before rewritten statements.
    rewritePrologue :: Stms GPU
  }

--------------------------------------------------------------------------------
--                        KERNEL REWRITING - FUNCTIONS                        --
--------------------------------------------------------------------------------

-- | An initial state to use when running a 'RewriteM' monad: no renames
-- registered and an empty prologue.
initialRState :: RState
initialRState =
  RState
    { rewriteRenames = mempty,
      rewritePrologue = mempty
    }

-- | Rewrite 'SegBinOp' dependencies to scalars that have been migrated.
--
-- Only the operator's lambda is rewritten (via 'addReadsToLambda'); all other
-- fields are kept.
addReadsToSegBinOp :: SegBinOp GPU -> ReduceM (SegBinOp GPU)
addReadsToSegBinOp op = do
  f' <- addReadsToLambda (segBinOpLambda op)
  pure (op {segBinOpLambda = f'})

-- | Rewrite 'HistOp' dependencies to scalars that have been migrated.
--
-- Only the operator's lambda is rewritten (via 'addReadsToLambda'); all other
-- fields are kept.
addReadsToHistOp :: HistOp GPU -> ReduceM (HistOp GPU)
addReadsToHistOp op = do
  f' <- addReadsToLambda (histOp op)
  pure (op {histOp = f'})

-- | Rewrite generic lambda dependencies to scalars that have been migrated.
--
-- Only the lambda body is rewritten (via 'addReadsToBody'); parameters and
-- return types are kept.
addReadsToLambda :: Lambda GPU -> ReduceM (Lambda GPU)
addReadsToLambda f = do
  body' <- addReadsToBody (lambdaBody f)
  pure (f {lambdaBody = body'})

-- | Rewrite generic body dependencies to scalars that have been migrated.
--
-- The statements produced by 'addReadsHelper' are placed in front of the
-- statements of the rewritten body.
addReadsToBody :: Body GPU -> ReduceM (Body GPU)
addReadsToBody body = do
  (body', prologue) <- addReadsHelper body
  pure body' {bodyStms = prologue >< bodyStms body'}

-- | Rewrite kernel body dependencies to scalars that have been migrated.
--
-- The statements produced by 'addReadsHelper' are placed in front of the
-- statements of the rewritten kernel body.
addReadsToKernelBody :: KernelBody GPU -> ReduceM (KernelBody GPU)
addReadsToKernelBody kbody = do
  (kbody', prologue) <- addReadsHelper kbody
  pure kbody' {kernelBodyStms = prologue >< kernelBodyStms kbody'}

-- | Rewrite migrated scalar dependencies within anything. The returned
-- statements must be added to the scope of the rewritten construct.
--
-- Every free variable of the argument is passed through 'rename', starting
-- from 'initialRState'. The resulting names are then substituted for the
-- original ones, and the prologue accumulated while renaming is returned
-- alongside the rewritten construct.
addReadsHelper :: (FreeIn a, Substitute a) => a -> ReduceM (a, Stms GPU)
addReadsHelper x = do
  let from = namesToList (freeIn x)
  (to, st) <- runStateT (mapM rename from) initialRState
  let rename_map = M.fromList (zip from to)
  pure (substituteNames rename_map x, rewritePrologue st)

-- | Create a fresh name, registering which name it is a rewrite of.
--
-- The fresh name is obtained from 'newName' in the underlying 'ReduceM'
-- monad and recorded in 'rewriteRenames' under the 'baseTag' of the original
-- name, replacing any earlier entry for that tag.
rewriteName :: VName -> RewriteM VName
rewriteName n = do
  n' <- lift (newName n)
  modify $ \st ->
    st {rewriteRenames = IM.insert (baseTag n) n' (rewriteRenames st)}
  pure n'

-- | Rewrite all bindings introduced by a body (to ensure they are unique) and
-- fix any dependencies that are broken as a result of migration or rewriting.
--
-- The statements are rewritten before the result is renamed, so the result
-- sees the names introduced for the statements' bindings.
rewriteBody :: Body GPU -> RewriteM (Body GPU)
rewriteBody (Body _ stms res) = do
  stms' <- rewriteStms stms
  res' <- renameResult res
  pure (Body () stms' res')

-- | Rewrite all bindings introduced by a sequence of statements (to ensure they
-- are unique) and fix any dependencies that are broken as a result of migration
-- or rewriting.
--
-- Statements are rewritten in order with 'rewriteStm'. A rewritten statement
-- whose expression is a 'GPUBody' is not kept as such: the statements of its
-- body are spliced into the output and each of its pattern elements is bound
-- to the corresponding body result.
rewriteStms :: Stms GPU -> RewriteM (Stms GPU)
rewriteStms = foldM rewriteTo mempty
  where
    rewriteTo :: Stms GPU -> Stm GPU -> RewriteM (Stms GPU)
    rewriteTo out stm = do
      stm' <- rewriteStm stm
      pure $ case stmExp stm' of
        Op (GPUBody _ (Body _ stms res)) ->
          let pes = patElems (stmPat stm')
           in foldl' bnd (out >< stms) (zip pes res)
        _ -> out |> stm'

    -- Bind a pattern element of an inlined 'GPUBody' to its body result. If
    -- the pattern element has an array type the result is wrapped in a
    -- single-element 'ArrayLit' of the row type, otherwise it is bound as is.
    -- The certificates of the result are attached to the new statement.
    bnd :: Stms GPU -> (PatElem Type, SubExpRes) -> Stms GPU
    bnd out (pe, SubExpRes cs se) =
      out |> Let (Pat [pe]) (StmAux cs mempty ()) (BasicOp op)
      where
        op = maybe (SubExp se) (ArrayLit [se]) (peelArray 1 (typeOf pe))

-- | Rewrite all bindings introduced by a single statement (to ensure they are
-- unique) and fix any dependencies that are broken as a result of migration or
-- rewriting.
--
-- The expression is rewritten before the pattern, so the new names of the
-- statement's own bindings are registered only after its expression has been
-- processed.
--
-- NOTE: GPUBody kernels must be rewritten through 'rewriteStms'.
rewriteStm :: Stm GPU -> RewriteM (Stm GPU)
rewriteStm (Let pat aux e) = do
  e' <- rewriteExp e
  pat' <- rewritePat pat
  aux' <- rewriteStmAux aux
  pure (Let pat' aux' e')

-- | Rewrite all bindings introduced by a pattern (to ensure they are unique)
-- and fix any dependencies that are broken as a result of migration or
-- rewriting.
--
-- Pattern elements are processed left to right with 'rewritePatElem'.
rewritePat :: Pat Type -> RewriteM (Pat Type)
rewritePat pat = Pat <$> mapM rewritePatElem (patElems pat)

-- | Rewrite the binding introduced by a single pattern element (to ensure it is
-- unique) and fix any dependencies that are broken as a result of migration or
-- rewriting.
--
-- The bound name is replaced by a fresh one from 'rewriteName' and the type
-- is passed through 'renameType'.
rewritePatElem :: PatElem Type -> RewriteM (PatElem Type)
rewritePatElem (PatElem n t) = do
  n' <- rewriteName n
  t' <- renameType t
  pure (PatElem n' t')

-- | Fix any 'StmAux' certificate references that are broken as a result of
-- migration or rewriting.
--
-- Only the certificates are renamed (via 'renameCerts'); the attributes are
-- kept unchanged.
rewriteStmAux :: StmAux () -> RewriteM (StmAux ())
rewriteStmAux (StmAux certs attrs _) = do
  certs' <- renameCerts certs
  pure (StmAux certs' attrs ())

-- | Rewrite the bindings introduced by an expression (to ensure they are
-- unique) and fix any dependencies that are broken as a result of migration or
-- rewriting.
--
-- The expression is traversed with 'mapExpM': sub-expressions, variable names
-- and types are renamed, nested bodies go through 'rewriteBody' and parameters
-- through 'rewriteParam'. Encountering an 'Op' is reported as a compiler bug.
rewriteExp :: Exp GPU -> RewriteM (Exp GPU)
rewriteExp =
  mapExpM $
    Mapper
      { mapOnSubExp = renameSubExp,
        mapOnBody = const rewriteBody,
        mapOnVName = rename,
        mapOnRetType = renameExtType,
        mapOnBranchType = renameExtType,
        mapOnFParam = rewriteParam,
        mapOnLParam = rewriteParam,
        mapOnOp = const opError
      }
  where
    -- This indicates that something fundamentally is wrong with the migration
    -- table produced by the
    -- "Futhark.Optimise.ReduceDeviceSyncs.MigrationTable" module.
    opError = compilerBugS "Cannot migrate a host-only operation to device."

-- | Rewrite the binding introduced by a single parameter (to ensure it is
-- unique) and fix any dependencies that are broken as a result of migration or
-- rewriting.
--
-- The parameter name is processed by 'rewriteName' before its type is
-- renamed by 'renameType'; the attributes are kept as they are.
rewriteParam :: Param (TypeBase Shape u) -> RewriteM (Param (TypeBase Shape u))
rewriteParam (Param attrs n t) =
  Param attrs <$> rewriteName n <*> renameType t

-- | Return the name to use for a rewritten dependency.
--
-- The name is looked up by its 'baseTag' in the rename table of the current
-- rewrite state. If an entry exists it is returned as is. Otherwise
-- 'useScalar' is run with the current prologue statements and the name; the
-- name it returns is recorded in the rename table and the statements it
-- returns replace the prologue.
rename :: VName -> RewriteM VName
rename n = do
  st <- get
  let renames = rewriteRenames st
      idx = baseTag n
  case IM.lookup idx renames of
    Just n' -> pure n'
    Nothing -> do
      -- NOTE(review): presumably 'useScalar' extends the prologue with
      -- whatever statements are needed to make @n@ usable here -- confirm
      -- against its definition.
      (stms', n') <- lift $ useScalar (rewritePrologue st) n
      modify $ \st' ->
        st'
          { rewriteRenames = IM.insert idx n' renames,
            rewritePrologue = stms'
          }
      pure n'

-- | Update the variable names within a 'Result' to account for migration and
-- rewriting.
--
-- Each element is processed in order by 'renameSubExpRes'.
renameResult :: Result -> RewriteM Result
renameResult = mapM renameSubExpRes

-- | Update the variable names within a 'SubExpRes' to account for migration and
-- rewriting.
--
-- The certificates are renamed first, then the result sub-expression.
renameSubExpRes :: SubExpRes -> RewriteM SubExpRes
renameSubExpRes (SubExpRes certs se) =
  SubExpRes <$> renameCerts certs <*> renameSubExp se

-- | Update the variable names of certificates to account for migration and
-- rewriting.
--
-- Every certificate name is passed through 'rename', in order.
renameCerts :: Certs -> RewriteM Certs
renameCerts cs = Certs <$> mapM rename (unCerts cs)

-- | Update any variable name within a t'SubExp' to account for migration and
-- rewriting.
--
-- A 'Var' has its name passed through 'rename'; any other sub-expression is
-- returned unchanged.
renameSubExp :: SubExp -> RewriteM SubExp
renameSubExp (Var n) = Var <$> rename n
renameSubExp se = pure se

-- | Update the variable names within a type to account for migration and
-- rewriting.
--
-- Note: 'mapOnType' also maps the 'VName' token of accumulators.
renameType :: TypeBase Shape u -> RewriteM (TypeBase Shape u)
renameType = mapOnType renameSubExp

-- | Update the variable names within an existential type to account for
-- migration and rewriting.
--
-- Note: 'mapOnExtType' also maps the 'VName' token of accumulators.
renameExtType :: TypeBase ExtShape u -> RewriteM (TypeBase ExtShape u)
renameExtType = mapOnExtType renameSubExp