{-# LANGUAGE TypeFamilies #-}

-- | Turn certain uses of accumulators into SegHists.
module Futhark.Optimise.HistAccs (histAccsGPU) where

import Control.Monad.Reader
import Control.Monad.State
import qualified Data.Map.Strict as M
import Futhark.IR.GPU
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename
import Prelude hiding (quot)

-- | A mapping from accumulator variables to the 'WithAccInput' that
-- introduced them (populated when entering a 'WithAcc' lambda).
type Accs rep = M.Map VName (WithAccInput rep)

-- | The monad in which the optimisation runs: read-only access to the
-- current type scope, plus a source of fresh names.
type OptM = ReaderT (Scope GPU) (State VNameSource)

-- | Optimise the statements of a body with 'optimiseStms', leaving the
-- body's result unchanged.
optimiseBody :: Accs GPU -> Body GPU -> OptM (Body GPU)
optimiseBody accs body =
  mkBody <$> optimiseStms accs (bodyStms body) <*> pure (bodyResult body)

-- | Optimise the bodies nested directly inside an expression, each one in
-- the scope that 'mapExpM' supplies for it.
--
-- Only 'mapOnBody' is overridden; every other part of the expression is
-- handled by 'identityMapper'.
--
-- NOTE(review): this presumably means the contents of 'Op's (e.g. kernel
-- bodies) are not visited from here; 'optimiseStm' only inspects the
-- top-level 'SegMap' statement itself - confirm this is intended.
optimiseExp :: Accs GPU -> Exp GPU -> OptM (Exp GPU)
optimiseExp accs = mapExpM mapper
  where
    mapper =
      identityMapper
        { mapOnBody = \scope body -> localScope scope $ optimiseBody accs body
        }

-- | Search @stms@ for the statement that binds @v@ with an 'UpdateAcc' on
-- an accumulator that is present in @accs@.
--
-- On success, returns:
--
--   * a tuple of the accumulator's 'WithAccInput', the accumulator's name,
--     the update indices and the update values;
--   * the original statements with the update statement removed. The
--     remaining statements keep their relative order.
--
-- Returns 'Nothing' if no statement binds @v@ this way, or if the updated
-- accumulator is not a key of @accs@.
extractUpdate ::
  Accs rep ->
  VName ->
  Stms rep ->
  Maybe ((WithAccInput rep, VName, [SubExp], [SubExp]), Stms rep)
extractUpdate accs v stms = do
  (stm, stms') <- stmsHead stms
  case stm of
    Let (Pat [PatElem pe_v _]) _ (BasicOp (UpdateAcc acc is vs))
      | pe_v == v -> do
          -- Fails (Nothing) when the accumulator is not one we track.
          acc_input <- M.lookup acc accs
          pure ((acc_input, acc, is, vs), stms')
    _ -> do
      (x, stms'') <- extractUpdate accs v stms'
      -- Put back the statement we skipped over, in its original position.
      pure (x, oneStm stm <> stms'')

-- | Try to turn a kernel body that returns a single accumulator into a body
-- suitable for a histogram.
--
-- The body must have exactly one 'Returns' result, and that result must be
-- a variable bound by an 'UpdateAcc' on an accumulator in @accs@ (see
-- 'extractUpdate').
--
-- The update statement is removed. The new body instead returns the update
-- indices followed by the update values, all using the original result's
-- manifest and certificates. The accumulator's 'WithAccInput' and its name
-- are returned alongside the new body.
--
-- Any other body shape yields 'Nothing'.
mkHistBody :: Accs GPU -> KernelBody GPU -> Maybe (KernelBody GPU, WithAccInput GPU, VName)
mkHistBody accs (KernelBody () stms [Returns rm cs (Var v)]) = do
  ((acc_input, acc, is, vs), stms') <- extractUpdate accs v stms
  pure
    ( KernelBody () stms' $ map (Returns rm cs) (is ++ vs),
      acc_input,
      acc
    )
mkHistBody _ _ = Nothing

-- | Convert an accumulator's combining operator into a histogram operator.
--
-- The first @shapeRank shape@ parameters are dropped, and the result is
-- renamed so that it binds fresh names.
--
-- NOTE(review): the dropped parameters are presumably the index parameters
-- that accumulator operators receive but 'HistOp' operators do not - confirm.
withAccLamToHistLam :: MonadFreshNames m => Shape -> Lambda GPU -> m (Lambda GPU)
withAccLamToHistLam shape lam =
  renameLambda $ lam {lambdaParams = drop (shapeRank shape) (lambdaParams lam)}

-- | Build, but do not bind, a 'SegMap' at level @lvl@ that adds the arrays
-- @arrs@ into the accumulator @acc@.
--
-- The kernel ranges over @shape@. At each index @gtids@ it reads element
-- @gtids@ of every array in @arrs@. It then applies 'UpdateAcc' to @acc@ at
-- index @gtids@ with those elements.
--
-- The kernel has a single result, whose type is the type of @acc@. Binding
-- the returned expression is left to the caller.
addArrsToAcc ::
  (MonadBuilder m, Rep m ~ GPU) =>
  SegLevel ->
  Shape ->
  [VName] ->
  VName ->
  m (Exp GPU)
addArrsToAcc lvl shape arrs acc = do
  flat <- newVName "phys_tid"
  gtids <- replicateM (shapeRank shape) (newVName "gtid")
  let space = SegSpace flat $ zip gtids $ shapeDims shape

  -- The kernel body, built in a scope where the segment indices are bound.
  (acc', stms) <- localScope (scopeOfSegSpace space) . collectStms $ do
    vs <- forM arrs $ \arr -> do
      arr_t <- lookupType arr
      letSubExp (baseString arr <> "_elem") $
        BasicOp $
          Index arr $
            fullSlice arr_t $
              map (DimFix . Var) gtids
    letExp (baseString acc <> "_upd") $
      BasicOp $
        UpdateAcc acc (map Var gtids) vs

  acc_t <- lookupType acc
  pure . Op . SegOp . SegMap lvl space [acc_t] $
    KernelBody () stms [Returns ResultMaySimplify mempty (Var acc')]

-- | Collapse a possibly multi-dimensional segment space into one dimension.
--
-- The new space keeps the flat thread id of @space@. It has a single fresh
-- index whose extent is the product of @space@'s dimensions. That product is
-- bound in the /surrounding/ builder as @dims_prod@, not inside the kernel.
--
-- The returned body first recomputes every original index from the new flat
-- index with 'unflattenIndex'. It then runs the original body's statements,
-- so those statements can keep referring to the original indices. The
-- body's results are unchanged.
flatKernelBody ::
  MonadBuilder m =>
  SegSpace ->
  KernelBody (Rep m) ->
  m (SegSpace, KernelBody (Rep m))
flatKernelBody space kbody = do
  gtid <- newVName "gtid"
  dims_prod <-
    letSubExp "dims_prod"
      =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) (segSpaceDims space)

  let space' = SegSpace (segFlat space) [(gtid, dims_prod)]

  kbody_stms <- localScope (scopeOfSegSpace space') . collectStms_ $ do
    -- Rebind each original index to its value derived from the flat index.
    let new_inds =
          unflattenIndex (map pe64 (segSpaceDims space)) (pe64 $ Var gtid)
    zipWithM_ letBindNames (map (pure . fst) (unSegSpace space))
      =<< mapM toExp new_inds
    addStms $ kernelBodyStms kbody

  pure (space', kbody {kernelBodyStms = kbody_stms})

-- | Optimise a single statement, possibly expanding it into several.
--
-- For a 'WithAcc', the lambda body is optimised with an extended accumulator
-- map. The lambda's accumulator parameters (those after the first
-- @length inputs@ parameters) are mapped to their corresponding
-- 'WithAccInput'.
--
-- For a 'SegMap', the statement is rewritten when both of these hold:
--
--   * 'mkHistBody' finds that its body's only result comes from an
--     'UpdateAcc' on a known accumulator;
--   * that accumulator has a combining operator whose return types are all
--     primitive.
--
-- The rewrite produces three pieces:
--
--   * destination arrays of the accumulator's shape, filled with the
--     neutral elements;
--   * a 'SegHist' over the flattened kernel space that fills those arrays;
--   * a 'SegMap', bound to the original pattern, that adds the histogram
--     into the original accumulator.
--
-- Any other statement has the bodies nested in its expression optimised via
-- 'optimiseExp'.
optimiseStm :: Accs GPU -> Stm GPU -> OptM (Stms GPU)
-- TODO: this is very restricted currently, but shows the idea.
optimiseStm accs (Let pat aux (WithAcc inputs lam)) =
  localScope (scopeOfLParams (lambdaParams lam)) $ do
    body' <- optimiseBody accs' $ lambdaBody lam
    let lam' = lam {lambdaBody = body'}
    pure $ oneStm $ Let pat aux $ WithAcc inputs lam'
  where
    acc_names = map paramName $ drop (length inputs) $ lambdaParams lam
    accs' = M.fromList (zip acc_names inputs) <> accs
optimiseStm accs (Let pat aux (Op (SegOp (SegMap lvl space _ kbody))))
  | not (M.null accs),
    Just (kbody', (acc_shape, _, Just (acc_lam, acc_nes)), acc) <-
      mkHistBody accs kbody,
    all primType $ lambdaReturnType acc_lam = runBuilder_ $ do
      -- One destination array per neutral element, pre-filled with it.
      hist_dests <- forM acc_nes $ \ne ->
        letExp "hist_dest" $ BasicOp $ Replicate acc_shape ne

      acc_lam' <- withAccLamToHistLam acc_shape acc_lam

      -- The histogram body (from mkHistBody) returns one i64 bucket index
      -- per accumulator dimension, followed by the values to combine.
      let ts' =
            replicate (shapeRank acc_shape) (Prim int64)
              ++ lambdaReturnType acc_lam
          histop =
            HistOp
              { histShape = acc_shape,
                histRaceFactor = intConst Int64 1,
                histDest = hist_dests,
                histNeutral = acc_nes,
                histOpShape = mempty,
                histOp = acc_lam'
              }

      (space', kbody'') <- flatKernelBody space kbody'

      hist_dest_upd <-
        letTupExp "hist_dest_upd" $
          Op $
            SegOp $
              SegHist lvl space' [histop] ts' kbody''

      -- Fold the computed histogram back into the original accumulator,
      -- binding the result to the SegMap's original pattern.
      addStm . Let pat aux =<< addArrsToAcc lvl acc_shape hist_dest_upd acc
optimiseStm accs (Let pat aux e) =
  oneStm . Let pat aux <$> optimiseExp accs e

-- | Optimise each statement in order with 'optimiseStm' and concatenate the
-- resulting statement sequences.
--
-- Everything bound by @stms@ is added to the scope before any statement is
-- processed.
optimiseStms :: Accs GPU -> Stms GPU -> OptM (Stms GPU)
optimiseStms accs stms =
  localScope (scopeOf stms) $
    mconcat <$> mapM (optimiseStm accs) (stmsToList stms)

-- | The pass for GPU kernels.
--
-- Applies 'optimiseStms' through 'intraproceduralTransformation', starting
-- with an empty accumulator map and the scope provided by the pass
-- framework.
histAccsGPU :: Pass GPU GPU
histAccsGPU =
  Pass "hist accs" "Turn certain accumulations into histograms" $
    intraproceduralTransformation onStms
  where
    onStms scope stms =
      modifyNameSource . runState $
        runReaderT (optimiseStms mempty stms) scope