{-# LANGUAGE TypeFamilies #-}

-- | Turn certain uses of accumulators into SegHists.
module Futhark.Optimise.HistAccs (histAccsGPU) where

import Control.Monad.Reader
import Control.Monad.State
import Data.Map.Strict qualified as M
import Futhark.IR.GPU
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename
import Prelude hiding (quot)

-- | A mapping from accumulator variables to their source.
--
-- Keys are the names of accumulator parameters bound by an enclosing
-- 'WithAcc' lambda; values are the 'WithAccInput' (shape, destination
-- arrays, optional operator with neutral elements) that the accumulator
-- was created from.
type Accs rep = M.Map VName (WithAccInput rep)

-- | The monad the pass runs in: read access to the current type scope,
-- plus a name source for generating fresh names.
type OptM = ReaderT (Scope GPU) (State VNameSource)

-- | Optimise every statement of a body, keeping its result unchanged.
optimiseBody :: Accs GPU -> Body GPU -> OptM (Body GPU)
optimiseBody accs body = do
  stms' <- optimiseStms accs (bodyStms body)
  pure $ mkBody stms' (bodyResult body)

-- | Recurse into the nested bodies of an expression (e.g. branches and
-- loop bodies), optimising each one under the scope the mapper supplies.
-- Other parts of the expression, including 'Op's, are left untouched.
optimiseExp :: Accs GPU -> Exp GPU -> OptM (Exp GPU)
optimiseExp accs = mapExpM bodyMapper
  where
    bodyMapper :: Mapper GPU GPU OptM
    bodyMapper = identityMapper {mapOnBody = onBody}

    onBody scope = localScope scope . optimiseBody accs

-- | Split out the statement that produces @v@ via an 'UpdateAcc'.
--
-- Scans @stms@ front to back for the first statement whose pattern is the
-- single name @v@ bound to @UpdateAcc acc is vs@. On success, returns
-- @((input, acc, is, vs), rest)@. Here @input@ is @acc@'s entry in @accs@,
-- and @rest@ holds all other statements in their original order.
--
-- Returns 'Nothing' if no such statement exists. It also returns 'Nothing'
-- if the statement is found but @acc@ is not a key of @accs@; the search
-- does not continue past it in that case.
extractUpdate ::
  Accs rep ->
  VName ->
  Stms rep ->
  Maybe ((WithAccInput rep, VName, [SubExp], [SubExp]), Stms rep)
extractUpdate accs v stms = go mempty stms
  where
    -- @seen@ holds the statements already skipped, in order.
    go seen remaining = do
      (stm, rest) <- stmsHead remaining
      case stm of
        Let (Pat [PatElem pe_v _]) _ (BasicOp (UpdateAcc acc is vs))
          | pe_v == v -> do
              acc_input <- M.lookup acc accs
              Just ((acc_input, acc, is, vs), seen <> rest)
        _ -> go (seen <> oneStm stm) rest

-- | Turn a kernel body that ends in a single accumulator update into a
-- histogram-style body.
--
-- Only applies when the body has exactly one result, a variable @v@ that
-- is produced by an 'UpdateAcc' on an accumulator known in @accs@. The
-- update statement is removed. The new body instead returns the update's
-- indices followed by its values, using the original result's manifest
-- and certificates.
--
-- Also returns the accumulator's 'WithAccInput' and its name. Returns
-- 'Nothing' when the body does not have this shape.
mkHistBody :: Accs GPU -> KernelBody GPU -> Maybe (KernelBody GPU, WithAccInput GPU, VName)
mkHistBody accs (KernelBody () stms [Returns rm cs (Var v)]) =
  repack <$> extractUpdate accs v stms
  where
    repack ((acc_input, acc, is, vs), stms') =
      (KernelBody () stms' $ map (Returns rm cs) (is ++ vs), acc_input, acc)
mkHistBody _ _ = Nothing

-- | Adapt a 'WithAcc' combining operator for use as a histogram operator.
--
-- Drops the leading parameters, one per dimension of @shape@. These are
-- presumably the index parameters of the accumulator operator. The
-- remaining lambda is then renamed so all its bound names are fresh.
withAccLamToHistLam :: MonadFreshNames m => Shape -> Lambda GPU -> m (Lambda GPU)
withAccLamToHistLam shape lam =
  renameLambda (lam {lambdaParams = hist_params})
  where
    hist_params = drop (shapeRank shape) (lambdaParams lam)

-- | Build a 'SegMap' over @shape@ that adds the arrays @arrs@ into @acc@.
--
-- Each thread reads the element of every array in @arrs@ at its own
-- index. It then issues one 'UpdateAcc' on @acc@ at that same index with
-- those elements. The 'SegMap' returns the updated accumulator and has
-- @acc@'s type.
addArrsToAcc ::
  (MonadBuilder m, Rep m ~ GPU) =>
  SegLevel ->
  Shape ->
  [VName] ->
  VName ->
  m (Exp GPU)
addArrsToAcc lvl shape arrs acc = do
  flat <- newVName "phys_tid"
  gtids <- replicateM (shapeRank shape) (newVName "gtid")
  let space = SegSpace flat $ zip gtids $ shapeDims shape
      idx = map Var gtids
      -- Read the element of @arr@ at the current thread's index.
      readElem arr = do
        arr_t <- lookupType arr
        letSubExp (baseString arr <> "_elem") . BasicOp . Index arr $
          fullSlice arr_t (map DimFix idx)

  (acc', kstms) <-
    localScope (scopeOfSegSpace space) . collectStms $ do
      elems <- mapM readElem arrs
      letExp (baseString acc <> "_upd") . BasicOp $ UpdateAcc acc idx elems

  acc_t <- lookupType acc
  let kres = Returns ResultMaySimplify mempty (Var acc')
  pure $ Op $ SegOp $ SegMap lvl space [acc_t] $ KernelBody () kstms [kres]

-- | Collapse a multi-dimensional segment space into a one-dimensional one.
--
-- The new space keeps the original flat thread name. It has a single
-- fresh index whose extent is the product of the original dimensions.
-- The original per-dimension thread indices are recomputed from the flat
-- index with 'unflattenIndex' and bound at the start of the kernel body.
-- The body's original statements follow those bindings.
flatKernelBody ::
  MonadBuilder m =>
  SegSpace ->
  KernelBody (Rep m) ->
  m (SegSpace, KernelBody (Rep m))
flatKernelBody space kbody = do
  gtid <- newVName "gtid"
  let dims = segSpaceDims space
  dims_prod <-
    letSubExp "dims_prod"
      =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) dims

  let flat_space = SegSpace (segFlat space) [(gtid, dims_prod)]
      -- Original multi-dimensional indices, expressed in the flat index.
      orig_idxs = unflattenIndex (map pe64 dims) (pe64 (Var gtid))
      orig_gtids = map fst (unSegSpace space)

  stms' <- localScope (scopeOfSegSpace flat_space) . collectStms_ $ do
    -- Convert every index to an expression first, then bind them.
    idx_exps <- mapM toExp orig_idxs
    zipWithM_ (letBindNames . pure) orig_gtids idx_exps
    addStms (kernelBodyStms kbody)

  pure (flat_space, kbody {kernelBodyStms = stms'})

-- | Optimise a single statement, possibly expanding it into several.
--
-- * A 'WithAcc' statement extends the known accumulators. Its lambda
--   parameters after the first @length inputs@ are paired positionally
--   with @inputs@. The lambda body is then optimised with the extended map.
--
-- * A 'SegMap' whose body ends in a single 'UpdateAcc' is rewritten (see
--   'mkHistBody'). This requires a known accumulator that has an operator
--   and primitive element types. The rewrite produces fresh destinations
--   replicated from the neutral elements, then a flattened 'SegHist' into
--   them, then a 'SegMap' that adds the histogram into the accumulator.
--   That last 'SegMap' is bound to the original pattern.
--
-- * Any other statement has its nested bodies optimised via 'optimiseExp'.
optimiseStm :: Accs GPU -> Stm GPU -> OptM (Stms GPU)
-- TODO: this is very restricted currently, but shows the idea.
optimiseStm accs (Let pat aux (WithAcc inputs lam)) =
  localScope (scopeOfLParams params) $ do
    body' <- optimiseBody inner_accs (lambdaBody lam)
    pure . oneStm . Let pat aux . WithAcc inputs $ lam {lambdaBody = body'}
  where
    params = lambdaParams lam
    acc_names = map paramName $ drop (length inputs) params
    -- Left-biased union: the newly bound accumulators take precedence.
    inner_accs = M.fromList (zip acc_names inputs) <> accs
optimiseStm accs (Let pat aux (Op (SegOp (SegMap lvl space _ kbody))))
  | not (M.null accs),
    Just (kbody', (acc_shape, _, Just (acc_lam, acc_nes)), acc) <-
      mkHistBody accs kbody,
    all primType (lambdaReturnType acc_lam) =
      runBuilder_ $ do
        hist_dests <-
          forM acc_nes $ letExp "hist_dest" . BasicOp . Replicate acc_shape

        hist_lam <- withAccLamToHistLam acc_shape acc_lam

        -- SegHist results: one i64 index per histogram dimension, then the
        -- operator's value types.
        let hist_ts =
              replicate (shapeRank acc_shape) (Prim int64)
                ++ lambdaReturnType acc_lam
            hist_op =
              HistOp
                { histShape = acc_shape,
                  histRaceFactor = intConst Int64 1,
                  histDest = hist_dests,
                  histNeutral = acc_nes,
                  histOpShape = mempty,
                  histOp = hist_lam
                }

        (flat_space, flat_kbody) <- flatKernelBody space kbody'

        hist_res <-
          letTupExp "hist_dest_upd" . Op . SegOp $
            SegHist lvl flat_space [hist_op] hist_ts flat_kbody

        acc_exp <- addArrsToAcc lvl acc_shape hist_res acc
        addStm $ Let pat aux acc_exp
optimiseStm accs (Let pat aux e) = do
  e' <- optimiseExp accs e
  pure $ oneStm $ Let pat aux e'

-- | Optimise a sequence of statements in order.
--
-- The statements' own bindings are added to the scope first. Each
-- statement's (possibly expanded) output is concatenated back together.
optimiseStms :: Accs GPU -> Stms GPU -> OptM (Stms GPU)
optimiseStms accs stms = localScope (scopeOf stms) $ do
  stms' <- forM (stmsToList stms) (optimiseStm accs)
  pure (mconcat stms')

-- | The pass for GPU kernels.
--
-- Applied to each function body independently. It starts with no known
-- accumulators and runs the optimisation against the pass's name source.
histAccsGPU :: Pass GPU GPU
histAccsGPU =
  Pass pass_name pass_desc $ intraproceduralTransformation onStms
  where
    pass_name = "hist accs"
    pass_desc = "Turn certain accumulations into histograms"

    onStms :: Scope GPU -> Stms GPU -> PassM (Stms GPU)
    onStms scope stms =
      modifyNameSource $ runState $ runReaderT (optimiseStms mempty stms) scope