{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Futhark.IR.GPU.Simplify
  ( simplifyGPU,
    simplifyLambda,
    GPU,

    -- * Building blocks
    simplifyKernelOp,
  )
where

import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR.GPU
import Futhark.IR.SOACS.Simplify qualified as SOAC
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify qualified as Simplify
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules
import Futhark.Pass
import Futhark.Tools
import Futhark.Util (focusNth)

-- | The simplifier operations used for the 'GPU' representation: the
-- bindable defaults from "Futhark.Optimise.Simplify", with host
-- operations handled by 'simplifyKernelOp' and any nested SOAC
-- delegated to 'SOAC.simplifySOAC'.
simpleGPU :: Simplify.SimpleOps GPU
simpleGPU = Simplify.bindableSimpleOps $ simplifyKernelOp SOAC.simplifySOAC

-- | Simplify an entire 'GPU' program with 'simpleGPU' and
-- 'kernelRules', adding no extra hoist blockers.
simplifyGPU :: Prog GPU -> PassM (Prog GPU)
simplifyGPU =
  Simplify.simplifyProg simpleGPU kernelRules Simplify.noExtraHoistBlockers

-- | Simplify a single 'GPU' lambda with 'simpleGPU' and 'kernelRules',
-- adding no extra hoist blockers.  The scope needed to interpret the
-- lambda's free variables is taken from the surrounding monad.
simplifyLambda ::
  (HasScope GPU m, MonadFreshNames m) =>
  Lambda GPU ->
  m (Lambda GPU)
simplifyLambda =
  Simplify.simplifyLambda simpleGPU kernelRules Engine.noExtraHoistBlockers

-- | Simplify a single 'HostOp'.
--
-- The first argument simplifies the representation-specific payload of
-- an 'OtherOp'; it is ignored for every other constructor.  The result
-- is the simplified operation paired with the statements produced
-- alongside it: whatever the underlying simplifier returned for
-- 'OtherOp', 'SegOp' and 'GPUBody', and 'mempty' for every 'SizeOp'.
simplifyKernelOp ::
  ( Engine.SimplifiableRep rep,
    BodyDec rep ~ ()
  ) =>
  Simplify.SimplifyOp rep (op (Wise rep)) ->
  HostOp op (Wise rep) ->
  Engine.SimpleM rep (HostOp op (Wise rep), Stms (Wise rep))
simplifyKernelOp f (OtherOp op) = do
  (op', stms) <- f op
  pure (OtherOp op', stms)
simplifyKernelOp _ (SegOp op) = do
  (op', hoisted) <- simplifySegOp op
  pure (SegOp op', hoisted)
-- 'GetSize' and 'GetSizeMax' carry no operands that are simplified
-- here, so they are rebuilt unchanged.
simplifyKernelOp _ (SizeOp (GetSize key size_class)) =
  pure (SizeOp $ GetSize key size_class, mempty)
simplifyKernelOp _ (SizeOp (GetSizeMax size_class)) =
  pure (SizeOp $ GetSizeMax size_class, mempty)
simplifyKernelOp _ (SizeOp (CmpSizeLe key size_class x)) = do
  x' <- Engine.simplify x
  pure (SizeOp $ CmpSizeLe key size_class x', mempty)
simplifyKernelOp _ (SizeOp (CalcNumBlocks w max_num_tblocks tblock_size)) = do
  -- NOTE(review): only @w@ is simplified; @tblock_size@ is also a
  -- 'SubExp' but is passed through unchanged.  Confirm this is
  -- intentional.
  w' <- Engine.simplify w
  pure (SizeOp $ CalcNumBlocks w' max_num_tblocks tblock_size, mempty)
simplifyKernelOp _ (GPUBody ts body) = do
  ts' <- Engine.simplify ts
  -- One empty usage entry per declared result type.
  (hoisted, body') <-
    Engine.simplifyBody keepOnGPU mempty (map (const mempty) ts) body
  pure (GPUBody ts' body', hoisted)
  where
    -- Block predicate handed to 'Engine.simplifyBody'.  It ignores its
    -- first two arguments and looks only at the statement's expression.
    -- Presumably statements for which it holds stay inside the body
    -- instead of being hoisted; confirm against 'Engine.simplifyBody'.
    keepOnGPU _ _ = keepExpOnGPU . stmExp

    -- True for indexing, for array literals whose type argument is
    -- primitive, and for loops; False for everything else.
    keepExpOnGPU (BasicOp Index {}) = True
    keepExpOnGPU (BasicOp (ArrayLit _ t)) | primType t = True
    keepExpOnGPU Loop {} = True
    keepExpOnGPU _ = False

-- | Statement traversal for GPU operations: the host-level traverser,
-- using 'traverseSOACStms' for the nested SOAC operations.
instance TraverseOpStms (Wise GPU) where
  traverseOpStms = traverseHostOpStms traverseSOACStms

-- No methods are overridden here, so the class's default method
-- implementations (if any) apply.
instance BuilderOps (Wise GPU)

-- | GPU operations contain segmented operations at level 'SegLevel',
-- wrapped in the 'SegOp' constructor of 'HostOp'.
instance HasSegOp (Wise GPU) where
  type SegOpLevel (Wise GPU) = SegLevel

  -- Only the 'SegOp' constructor holds a segmented operation.
  asSegOp (SegOp op) = Just op
  asSegOp _ = Nothing

  segOp = SegOp

-- | In the GPU representation a SOAC is carried by the 'OtherOp'
-- constructor of 'HostOp'.
instance SOAC.HasSOAC (Wise GPU) where
  -- Only the 'OtherOp' constructor holds a SOAC.
  asSOAC (OtherOp soac) = Just soac
  asSOAC _ = Nothing

  soacOp = OtherOp

-- | The rule book used when simplifying 'GPU' programs: the standard
-- rules, the segmented-operation rules, and a handful of SOAC and
-- GPU-specific rules.  The first list given to 'ruleBook' holds the
-- top-down rules and the second the bottom-up rules.
kernelRules :: RuleBook (Wise GPU)
kernelRules =
  standardRules
    <> segOpRules
    <> ruleBook
      [ RuleOp SOAC.simplifyKnownIterationSOAC,
        RuleOp SOAC.removeReplicateMapping,
        RuleOp SOAC.liftIdentityMapping,
        RuleOp SOAC.simplifyMapIota,
        RuleOp SOAC.removeUnusedSOACInput,
        RuleBasicOp removeScalarCopy
      ]
      [ RuleBasicOp removeUnnecessaryCopy,
        RuleOp removeDeadGPUBodyResult
      ]

-- | Remove the unused return values of a GPUBody.
--
-- The rule fires only for a 'GPUBody' where at least one name bound by
-- the pattern is not directly used according to the usage table.  It
-- then rebinds the statement with the dead pattern elements, result
-- types and body results dropped in lockstep.  Every other operation
-- is skipped.
removeDeadGPUBodyResult :: BottomUpRuleOp (Wise GPU)
removeDeadGPUBodyResult (_, used) pat aux (GPUBody types body)
  | -- Figure out which of the names in 'pat' are used...
    pat_used <- map (`UT.isUsedDirectly` used) $ patNames pat,
    -- If they are not all used, then this rule applies.
    not (and pat_used) =
      -- Remove the parts of the GPUBody results that correspond to dead
      -- return value bindings.  Note that this leaves dead code in the
      -- kernel, but that will be removed later.
      let -- Keep exactly those elements whose position is marked as
          -- used in 'pat_used'.
          pick :: [a] -> [a]
          pick = map snd . filter fst . zip pat_used
          pat' = pick (patElems pat)
          types' = pick types
          body' = body {bodyResult = pick (bodyResult body)}
       in Simplify $ auxing aux $ letBind (Pat pat') $ Op $ GPUBody types' body'
  | otherwise = Skip
removeDeadGPUBodyResult _ _ _ _ = Skip

-- | If we see an Update with a scalar where the value to be written is
-- the result of indexing some other array, then we convert it into an
-- Update with a slice of that array.  This matters when the arrays
-- are far away (on the GPU, say), because it avoids a copy of the
-- scalar to and from the host.
--
-- Concretely, the last fixed index of both the destination slice and
-- the source slice is replaced by a 'DimSlice' with the same start and
-- with 1 for both remaining arguments.  The source is re-read through
-- a fresh @<v>_slice@ binding, and the original pattern is rebound to
-- an 'Update' that writes that binding.
removeScalarCopy :: (BuilderOps rep) => TopDownRuleBasicOp rep
removeScalarCopy vtable pat aux (Update safety arr_x (Slice slice_x) (Var v))
  | -- Only fire when 'sliceIndices' succeeds on the destination slice
    -- (presumably: every dimension is a fixed index, i.e. a scalar
    -- write).
    Just _ <- sliceIndices (Slice slice_x),
    -- The written value must itself come from indexing an array.
    Just (Index arr_y (Slice slice_y), cs_y) <- ST.lookupBasicOp v vtable,
    ST.available arr_y vtable,
    -- Source and destination must not alias each other.
    not $ ST.aliases arr_x arr_y vtable,
    -- Split off the last dimension of each slice; it must be a fixed
    -- index.
    Just (slice_x_bef, DimFix i, []) <- focusNth (length slice_x - 1) slice_x,
    Just (slice_y_bef, DimFix j, []) <- focusNth (length slice_y - 1) slice_y =
      Simplify $ do
        let one = intConst Int64 1
            slice_x' = Slice $ slice_x_bef ++ [DimSlice i one one]
            slice_y' = Slice $ slice_y_bef ++ [DimSlice j one one]
        v' <- letExp (baseString v ++ "_slice") $ BasicOp $ Index arr_y slice_y'
        -- The certificates of the original index and the auxiliary
        -- information of the original statement go on the new Update.
        certifying cs_y . auxing aux $
          letBind pat $
            BasicOp $
              Update safety arr_x slice_x' $
                Var v'
removeScalarCopy _ _ _ _ =
  Skip