module Futhark.CodeGen.ImpGen.Multicore.SegRed
  ( compileSegRed,
    compileSegRed',
    DoSegBody,
  )
where

import Control.Monad
import Futhark.CodeGen.ImpCode.Multicore qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Multicore.Base
import Futhark.IR.MCMem
import Futhark.Transform.Rename (renameLambda)
import Prelude hiding (quot, rem)

-- | A code generator for the body whose results are to be reduced.  It
-- is handed a continuation and is expected to invoke it with the
-- values to reduce: one list per reduction operator, where every value
-- is paired with the indices used when reading it.
type DoSegBody =
  ([[(SubExp, [Imp.TExp Int64])]] -> MulticoreGen ()) -> MulticoreGen ()

-- | Generate code for a SegRed construct.
--
-- The statements of the kernel body are compiled first.  Of the body
-- results, the first @segBinOpResults reds@ are the values to reduce;
-- the remaining ones are map-out results, which are written to the
-- corresponding trailing pattern elements with 'compileThreadResult'.
-- The values to reduce are then handed, chunked per operator and each
-- paired with an empty index list, to the continuation supplied by
-- 'compileSegRed''.
compileSegRed ::
  Pat LetDecMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.MCCode
compileSegRed pat space reds kbody nsubtasks =
  compileSegRed' pat space reds nsubtasks $ \red_cont ->
    compileStms mempty (kernelBodyStms kbody) $ do
      -- The same count splits both the body results and the pattern.
      let num_red_res = segBinOpResults reds
          (red_res, map_res) = splitAt num_red_res $ kernelBodyResult kbody
          map_arrs = drop num_red_res $ patElems pat

      sComment "save map-out results" $
        zipWithM_ (compileThreadResult space) map_arrs map_res

      red_cont $
        segBinOpChunks reds $
          map (\res -> (kernelResultSubExp res, [])) red_res

-- | Like 'compileSegRed', but where the body is a monadic action.
--
-- Dispatches on the iteration space: when it has exactly one
-- dimension the non-segmented strategy is used, otherwise the
-- segmented one.  The number of subtasks is only passed on in the
-- non-segmented case.
compileSegRed' ::
  Pat LetDecMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  TV Int32 ->
  DoSegBody ->
  MulticoreGen Imp.MCCode
compileSegRed' pat space reds nsubtasks kbody
  | [_] <- unSegSpace space =
      nonsegmentedReduction pat space reds nsubtasks kbody
  | otherwise =
      segmentedReduction pat space reds kbody

-- | A SegBinOp with auxiliary information.
data SegBinOpSlug = SegBinOpSlug
  { -- | The reduction operator itself.
    slugOp :: SegBinOp MCMem,
    -- | The array in which we write the intermediate results, indexed
    -- by the flat/physical thread ID.
    slugResArrs :: [VName]
  }

-- The body of the operator's lambda.
slugBody :: SegBinOpSlug -> Body MCMem
slugBody = lambdaBody . segBinOpLambda . slugOp

-- All parameters of the operator's lambda (see 'accParams' and
-- 'nextParams' for how they are divided).
slugParams :: SegBinOpSlug -> [LParam MCMem]
slugParams = lambdaParams . segBinOpLambda . slugOp

-- The neutral elements of the operator.
slugNeutral :: SegBinOpSlug -> [SubExp]
slugNeutral = segBinOpNeutral . slugOp

-- The shape of the operator; an empty shape means the operator is
-- applied directly rather than inside a loop nest.
slugShape :: SegBinOpSlug -> Shape
slugShape = segBinOpShape . slugOp

-- The lambda parameters are split by the number of neutral elements:
-- the leading ones ('accParams') hold the accumulator, the remaining
-- ones ('nextParams') hold the next value to combine with it.
accParams, nextParams :: SegBinOpSlug -> [LParam MCMem]
accParams slug = take (length (slugNeutral slug)) $ slugParams slug
nextParams slug = drop (length (slugNeutral slug)) $ slugParams slug

-- Produce a copy of the slug whose operator lambda has been renamed
-- with 'renameLambda'.  The result arrays are left unchanged.
renameSlug :: SegBinOpSlug -> MulticoreGen SegBinOpSlug
renameSlug slug = do
  let op = slugOp slug
  lambda' <- renameLambda $ segBinOpLambda op
  pure slug {slugOp = op {segBinOpLambda = lambda'}}

-- Generate code for a reduction over a one-dimensional space.
--
-- Stage one is picked among four code generators depending on whether
-- all operators are commutative, whether only scalars are involved,
-- and whether every operator has a non-empty shape.  It writes
-- intermediate results to the arrays obtained from
-- 'groupResultArrays'.  Stage two then combines those intermediate
-- results, using a renamed copy of the operators.
nonsegmentedReduction ::
  Pat LetDecMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  TV Int32 ->
  DoSegBody ->
  MulticoreGen Imp.MCCode
nonsegmentedReduction pat space reds nsubtasks kbody = collect $ do
  thread_res_arrs <-
    groupResultArrays "reduce_stage_1_tid_res_arr" (tvSize nsubtasks) reds
  let slugs1 = zipWith SegBinOpSlug reds thread_res_arrs
      nsubtasks' = tvExp nsubtasks

  -- Are all the operators commutative?
  let comm = all ((== Commutative) . segBinOpComm) reds
  let dims = map (shapeDims . slugShape) slugs1
  let isScalar x = case x of MemPrim _ -> True; _ -> False
  -- Are we only working on scalar arrays?
  let scalars =
        all (all (isScalar . paramDec) . slugParams) slugs1
          && all null dims
  -- Are we working with vectorized inner maps?
  let inner_map = not (any null dims)

  -- The order of the guards matters: the commutative scalar case takes
  -- priority, and the fallback is used only when nothing else applies.
  let path
        | comm && scalars = reductionStage1CommScalar
        | inner_map = reductionStage1Array
        | scalars = reductionStage1NonCommScalar
        | otherwise = reductionStage1Fallback
  path space slugs1 kbody

  reds2 <- renameSegBinOp reds
  let slugs2 = zipWith SegBinOpSlug reds2 thread_res_arrs
  reductionStage2 pat space nsubtasks' slugs2

-- Generate code that declares the params for the binop.  The lambda
-- parameters of all slugs are declared together in a single scope.
genBinOpParams :: [SegBinOpSlug] -> MulticoreGen ()
genBinOpParams slugs =
  dScope Nothing $ scopeOfLParams $ concatMap slugParams slugs

-- Generate code that declares accumulators, return a list of these.
--
-- The result has one inner list per slug, with one accumulator name
-- per accumulator parameter of that slug.  Every accumulator is
-- initialised to the corresponding neutral element.
genAccumulators :: [SegBinOpSlug] -> MulticoreGen [[VName]]
genAccumulators slugs =
  forM slugs $ \slug -> do
    let shape = slugShape slug
    forM (zip (accParams slug) (slugNeutral slug)) $ \(p, ne) -> do
      -- Declare accumulator variable.  A primitive parameter gets a
      -- fresh scalar when the operator shape is empty and a freshly
      -- allocated array otherwise; for any other parameter type the
      -- parameter itself serves as the accumulator.
      acc <-
        case paramType p of
          Prim pt
            | shape == mempty ->
                tvVar <$> dPrim "local_acc" pt
            | otherwise ->
                sAllocArray "local_acc" pt shape DefaultSpace
          _ ->
            pure $ paramName p

      -- Now neutral-initialise the accumulator.
      sLoopNest shape $ \vec_is ->
        copyDWIMFix acc vec_is ne []

      pure acc

-- Datatype to represent all the different ways we can generate
-- code for a reduction.  The helpers 'getRedLoop', 'getExtract' and
-- 'getNestLoop' inspect this value to decide which wrappers to put
-- around the pieces of the reduction loop body.
data RedLoopType
  = RedSeq -- Fully sequential
  | RedComm -- Commutative scalar
  | RedNonComm -- Noncommutative scalar (uniformize loop with lane extracts)
  | RedNested -- Nested vectorized operator (vectorized loop nest over the operator shape)
  | RedUniformize -- Uniformize over scalar acc (lane extracts at the given index)

-- Given a type of reduction and the loop index, should we wrap
-- the loop body in some extra code?
--
-- For 'RedNonComm' the body is wrapped in 'generateUniformizeLoop',
-- which supplies the index it is run with.  For 'RedUniformize' the
-- body is run once with the given index.  In all other cases the body
-- is run once with index 0.
getRedLoop ::
  RedLoopType ->
  Imp.TExp Int64 ->
  (Imp.TExp Int64 -> MulticoreGen ()) ->
  MulticoreGen ()
getRedLoop RedNonComm _ = generateUniformizeLoop
getRedLoop RedUniformize uni = \body -> body uni
getRedLoop _ _ = \body -> body 0

-- Given a type of reduction, should we perform extracts on
-- the accumulator?
--
-- For 'RedNonComm' and 'RedUniformize' the index and the code are
-- passed to 'extractVectorLane'.  Otherwise the index is ignored and
-- the code is emitted as-is.
getExtract ::
  RedLoopType ->
  Imp.TExp Int64 ->
  MulticoreGen Imp.MCCode ->
  MulticoreGen ()
getExtract RedNonComm = extractVectorLane
getExtract RedUniformize = extractVectorLane
getExtract _ = \_ body -> body >>= emit

-- Given a type of reduction, should we vectorize the inner
-- map, if it exists?  Only 'RedNested' uses the vectorized loop nest;
-- every other type uses the ordinary one.
getNestLoop ::
  RedLoopType ->
  Shape ->
  ([Imp.TExp Int64] -> MulticoreGen ()) ->
  MulticoreGen ()
getNestLoop RedNested = sLoopNestVectorized
getNestLoop _ = sLoopNest

-- Given a list of accumulators, use them as the source
-- data for reduction.  Each accumulator is passed on as a variable
-- paired with an empty index list.
redSourceAccs :: [[VName]] -> DoSegBody
redSourceAccs slug_local_accs m =
  m $ map (map (\x -> (Var x, []))) slug_local_accs

-- Generate a reduction loop for uniformizing vectors.  This is
-- 'genReductionLoop' in 'RedUniformize' mode, where the values being
-- reduced are the given accumulators rather than body results.
genPostbodyReductionLoop ::
  [[VName]] ->
  [SegBinOpSlug] ->
  [[VName]] ->
  SegSpace ->
  Imp.TExp Int64 ->
  MulticoreGen ()
genPostbodyReductionLoop accs =
  genReductionLoop RedUniformize (redSourceAccs accs)

-- Generate a potentially vectorized body of code that performs reduction
-- when put inside a chunked loop.
--
-- The index variables of the space are first declared from the
-- unflattened loop index.  The body generator is then run, and for
-- each operator (inside a loop nest over the operator's shape) the
-- generated code:
--
--   1. loads the accumulator parameters from the local accumulators
--      (only for results of primitive type),
--   2. loads the remaining parameters from the values to reduce, and
--   3. runs the operator body and stores its results back into the
--      local accumulators.
--
-- The 'RedLoopType' selects the wrappers used around these steps.
genReductionLoop ::
  RedLoopType ->
  DoSegBody ->
  [SegBinOpSlug] ->
  [[VName]] ->
  SegSpace ->
  Imp.TExp Int64 ->
  MulticoreGen ()
genReductionLoop typ kbodymap slugs slug_local_accs space i = do
  let (is, ns) = unzip $ unSegSpace space
      ns' = map pe64 ns
  zipWithM_ dPrimV_ is $ unflattenIndex ns' i
  kbodymap $ \all_red_res' ->
    forM_ (zip3 all_red_res' slugs slug_local_accs) $ \(red_res, slug, local_accs) ->
      getNestLoop typ (slugShape slug) $ \vec_is -> do
        let lamtypes = lambdaReturnType $ segBinOpLambda $ slugOp slug
        -- Load accum params
        getRedLoop typ i $ \uni -> do
          sComment "Load accum params" $
            forM_ (zip3 (accParams slug) local_accs lamtypes) $
              \(p, local_acc, t) ->
                when (primType t) $
                  copyDWIMFix (paramName p) [] (Var local_acc) vec_is

          sComment "Load next params" $
            forM_ (zip (nextParams slug) red_res) $ \(p, (res, res_is)) ->
              getExtract typ uni $
                collect $
                  copyDWIMFix (paramName p) [] res (res_is ++ vec_is)

          sComment "SegRed body" $
            compileStms mempty (bodyStms $ slugBody slug) $
              forM_ (zip local_accs $ map resSubExp $ bodyResult $ slugBody slug) $
                \(local_acc, se) ->
                  copyDWIMFix local_acc vec_is se []

-- Generate code to write back results from the accumulators.  Each
-- local accumulator is copied into the slug's result array at the
-- position given by the flat index of the space.
genWriteBack :: [SegBinOpSlug] -> [[VName]] -> SegSpace -> MulticoreGen ()
genWriteBack slugs slug_local_accs space =
  forM_ (zip slugs slug_local_accs) $ \(slug, local_accs) ->
    -- Only the emitted code matters here, so no result list is built.
    forM_ (zip (slugResArrs slug) local_accs) $ \(acc, local_acc) ->
      copyDWIMFix acc [Imp.le64 $ segFlat space] (Var local_acc) []

-- The common type of the stage-one code generators: given the
-- iteration space, the operator slugs and the body generator, emit
-- the code for stage one.
type ReductionStage1 = SegSpace -> [SegBinOpSlug] -> DoSegBody -> MulticoreGen ()

-- Pure sequential codegen with no fancy vectorization.
--
-- The generated code is wrapped in a parallel loop named
-- "segred_stage_1".  Inside it, the flat index of the space is
-- declared and set from the task ID before anything else.
reductionStage1Fallback :: ReductionStage1
reductionStage1Fallback space slugs kbody = do
  fbody <- collect $ do
    dPrim_ (segFlat space) int64
    sOp $ Imp.GetTaskId (segFlat space)
    -- Declare params
    genBinOpParams slugs
    slug_local_accs <- genAccumulators slugs
    -- Generate main reduction loop
    generateChunkLoop "SegRed" Scalar $
      genReductionLoop RedSeq kbody slugs slug_local_accs space
    -- Write back results
    genWriteBack slugs slug_local_accs space
  free_params <- freeParams fbody
  emit $ Imp.Op $ Imp.ParLoop "segred_stage_1" fbody free_params

-- Codegen for noncommutative scalar reduction. We vectorize the
-- kernel body, and do the reduction sequentially.
--
-- Parameter declarations, accumulators, the chunk loop and the
-- write-back are all generated inside 'inISPC'; only the flat index
-- setup happens outside of it.
reductionStage1NonCommScalar :: ReductionStage1
reductionStage1NonCommScalar space slugs kbody = do
  fbody <- collect $ do
    dPrim_ (segFlat space) int64
    sOp $ Imp.GetTaskId (segFlat space)
    inISPC $ do
      -- Declare params
      genBinOpParams slugs
      slug_local_accs <- genAccumulators slugs
      -- Generate main reduction loop
      generateChunkLoop "SegRed" Vectorized $
        genReductionLoop RedNonComm kbody slugs slug_local_accs space
      -- Write back results
      genWriteBack slugs slug_local_accs space
  free_params <- freeParams fbody
  emit $ Imp.Op $ Imp.ParLoop "segred_stage_1" fbody free_params

-- Codegen for a commutative reduction on scalar arrays
-- In this case, we can generate an efficient interleaved reduction
--
-- Two sets of operator parameters and accumulators are declared: one
-- from renamed slugs and one from the original slugs.  The main loop
-- reduces into the accumulators of the original slugs; a uniformize
-- loop then reduces those into the accumulators of the renamed slugs,
-- which are what is finally written back.
reductionStage1CommScalar :: ReductionStage1
reductionStage1CommScalar space slugs kbody = do
  fbody <- collect $ do
    dPrim_ (segFlat space) int64
    sOp $ Imp.GetTaskId (segFlat space)
    -- Rename lambda params in slugs to get a new set of them
    slugs' <- mapM renameSlug slugs
    inISPC $ do
      -- Declare one set of params uniform
      genBinOpParams slugs'
      slug_local_accs_uni <- genAccumulators slugs'
      -- Declare the other varying
      genBinOpParams slugs
      slug_local_accs <- genAccumulators slugs
      -- Generate the main reduction loop over vectors
      generateChunkLoop "SegRed" Vectorized $
        genReductionLoop RedComm kbody slugs slug_local_accs space
      -- Now reduce over those vector accumulators to get scalar results
      generateUniformizeLoop $
        genPostbodyReductionLoop slug_local_accs slugs' slug_local_accs_uni space
      -- And write back the results.  Only the result arrays of the
      -- slugs are used here, and renaming leaves those unchanged.
      genWriteBack slugs slug_local_accs_uni space
  free_params <- freeParams fbody
  emit $ Imp.Op $ Imp.ParLoop "segred_stage_1" fbody free_params

-- Codegen for a reduction on arrays, where the body is a perfect nested map.
-- We vectorize just the inner map.
--
-- The code declaring the lambda parameters and the code setting up
-- the accumulators are collected separately, so that they can be
-- emitted in different places relative to 'inISPC'.
reductionStage1Array :: ReductionStage1
reductionStage1Array space slugs kbody = do
  fbody <- collect $ do
    dPrim_ (segFlat space) int64
    sOp $ Imp.GetTaskId (segFlat space)
    -- Declare params
    lparams <- collect $ genBinOpParams slugs
    (slug_local_accs, uniform_prebody) <- collect' $ genAccumulators slugs
    -- Put the accumulators outside of the kernel, so they are forced uniform
    emit uniform_prebody
    inISPC $ do
      -- Put the lambda params inside the kernel so they are varying
      emit lparams
      -- Generate the main reduction loop
      generateChunkLoop "SegRed" Scalar $
        genReductionLoop RedNested kbody slugs slug_local_accs space
      -- Write back results
      genWriteBack slugs slug_local_accs space
  free_params <- freeParams fbody
  emit $ Imp.Op $ Imp.ParLoop "segred_stage_1" fbody free_params

-- | Emit the sequential second stage of a reduction: fold the partial
-- results stored in each slug's result arrays into the output pattern.
--
-- The generated code first writes the neutral element of every
-- operator into the corresponding pattern elements.  It then loops
-- over @nsubtasks@; in each iteration the flat index of the segment
-- space is set to the loop counter, the current output is loaded into
-- the accumulator parameters, the entry of 'slugResArrs' selected by
-- that flat index is loaded into the "next" parameters, the operator
-- body is run, and its results are written back to the pattern.
--
-- NOTE(review): the result arrays are presumably the per-subtask
-- partial results written by stage one; confirm against the caller.
reductionStage2 ::
  Pat LetDecMem ->
  SegSpace ->
  Imp.TExp Int32 ->
  [SegBinOpSlug] ->
  MulticoreGen ()
reductionStage2 pat space nsubtasks slugs = do
  let -- Pattern elements, grouped by the operator they belong to.
      per_red_pes = segBinOpChunks (map slugOp slugs) $ patElems pat
      -- The flat index variable, read back as a 64-bit expression.
      phys_id = Imp.le64 (segFlat space)
  sComment "neutral-initialise the output" $
    forM_ (zip (map slugOp slugs) per_red_pes) $ \(red, red_res) ->
      forM_ (zip red_res $ segBinOpNeutral red) $ \(pe, ne) ->
        sLoopNest (segBinOpShape red) $ \vec_is ->
          copyDWIMFix (patElemName pe) vec_is ne []

  -- Bring the operator parameters into scope before the loop that
  -- assigns to them.
  dScope Nothing $ scopeOfLParams $ concatMap slugParams slugs

  sFor "i" nsubtasks $ \i' -> do
    -- The loop counter is 32-bit while the flat index is tagged as
    -- int64 and read via 'Imp.le64' above, so widen it explicitly
    -- instead of storing a 32-bit expression in a 64-bit variable.
    mkTV (segFlat space) int64 <-- sExt64 i'
    sComment "Apply main thread reduction" $
      forM_ (zip slugs per_red_pes) $ \(slug, red_res) ->
        sLoopNest (slugShape slug) $ \vec_is -> do
          sComment "load acc params" $
            forM_ (zip (accParams slug) red_res) $ \(p, pe) ->
              copyDWIMFix (paramName p) [] (Var $ patElemName pe) vec_is
          sComment "load next params" $
            forM_ (zip (nextParams slug) (slugResArrs slug)) $ \(p, acc) ->
              copyDWIMFix (paramName p) [] (Var acc) (phys_id : vec_is)
          sComment "red body" $
            compileStms mempty (bodyStms $ slugBody slug) $
              forM_ (zip red_res $ map resSubExp $ bodyResult $ slugBody slug) $
                \(pe, se') -> copyDWIMFix (patElemName pe) vec_is se' []

-- | Generate code for a segmented reduction.
--
-- Each thread handles a number of segments, and every segment is
-- reduced sequentially.  Maybe the work of the inner loop should be
-- selected based on the number of segments, the dimensions, etc.
--
-- The loop body produced by 'compileSegRedBody' is wrapped, together
-- with its free variables, in a single parallel loop operation; the
-- code emitted for that is collected and returned.
segmentedReduction ::
  Pat LetDecMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  DoSegBody ->
  MulticoreGen Imp.MCCode
segmentedReduction pat space reds kbody = collect $ do
  loop_code <- compileSegRedBody pat space reds kbody
  loop_free <- freeParams loop_code
  emit . Imp.Op $ Imp.ParLoop "segmented_segred" loop_code loop_free

-- | Generate the chunk-loop body of a segmented reduction.
--
-- Currently, this is only used as part of SegHist calculations, never alone.
-- NOTE(review): within this module it is also called by
-- 'segmentedReduction'; confirm whether the remark above still holds.
--
-- For every segment handed to the chunk loop, the generated code
-- writes the neutral elements into the output pattern at the
-- segment's index, and then walks the innermost dimension
-- sequentially: it runs @kbody@, loads the current output and the new
-- value into the operator parameters, runs the operator body, and
-- stores the result back into the output pattern.
compileSegRedBody ::
  Pat LetDecMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  DoSegBody ->
  MulticoreGen Imp.MCCode
compileSegRedBody pat space reds kbody = do
  let (gtids, dims) = unzip $ unSegSpace space
      dims64 = map pe64 dims
      -- Extent of the innermost dimension, i.e. the one being reduced.
      seg_len = last dims64
      -- Index variables of all dimensions except the innermost.
      outer_gtids = init gtids
      -- Position in the output: outer indexes followed by the indexes
      -- into the operator's own shape.
      segIndex vec_is = map Imp.le64 outer_gtids ++ vec_is
      -- Output pattern elements, grouped by operator.
      red_pes = segBinOpChunks reds $ patElems pat
      lamParams red = lambdaParams (segBinOpLambda red)
      accCount red = length (segBinOpNeutral red)

  dPrim_ (segFlat space) int64
  sOp $ Imp.GetTaskId (segFlat space)

  -- Perform sequential reduce on inner most dimension
  collect . inISPC $
    generateChunkLoop "SegRed" Vectorized $ \segment -> do
      flat_idx <- dPrimVE "flat_idx" $ segment * seg_len
      zipWithM_ dPrimV_ gtids $ unflattenIndex dims64 flat_idx

      sComment "neutral-initialise the accumulators" $
        forM_ (zip red_pes reds) $ \(pes, red) ->
          forM_ (zip pes (segBinOpNeutral red)) $ \(pe, ne) ->
            sLoopNest (segBinOpShape red) $ \vec_is ->
              copyDWIMFix (patElemName pe) (segIndex vec_is) ne []

      sComment "main body" $ do
        dScope Nothing $ scopeOfLParams $ concatMap lamParams reds
        sFor "i" seg_len $ \i -> do
          -- Re-assign the outer indexes from the segment number, then
          -- bind the innermost index to the sequential loop counter.
          zipWithM_ (<--) (map (`mkTV` int64) outer_gtids) $
            unflattenIndex (init dims64) (sExt64 segment)
          dPrimV_ (last gtids) i
          kbody $ \all_res ->
            forM_ (zip3 red_pes reds all_res) $ \(pes, red, res) ->
              sLoopNest (segBinOpShape red) $ \vec_is -> do
                let -- The first parameters of the operator receive the
                    -- accumulator, the remaining ones the new value.
                    (acc_params, next_params) =
                      splitAt (accCount red) (lamParams red)
                    out_is = segIndex vec_is
                    lam_body = lambdaBody (segBinOpLambda red)

                sComment "load accum" $
                  forM_ (zip acc_params pes) $ \(p, pe) ->
                    copyDWIMFix (paramName p) [] (Var $ patElemName pe) out_is

                sComment "load new val" $
                  forM_ (zip next_params res) $ \(p, (se, se_is)) ->
                    copyDWIMFix (paramName p) [] se (se_is ++ vec_is)

                sComment "apply reduction" $
                  compileStms mempty (bodyStms lam_body) $
                    sComment "write back to res" $
                      forM_ (zip pes $ map resSubExp $ bodyResult lam_body) $
                        \(pe, se) ->
                          copyDWIMFix (patElemName pe) out_is se []