{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Pass.ExtractKernels.StreamKernel
  ( segThreadCapped,
    streamRed,
    streamMap,
  )
where

import Control.Monad
import Control.Monad.Writer
import Data.List ()
import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.IR.Kernels hiding
  ( BasicOp,
    Body,
    Exp,
    FParam,
    FunDef,
    LParam,
    Lambda,
    PatElem,
    Pattern,
    Prog,
    RetType,
    Stm,
  )
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.ToKernels
import Futhark.Tools
import Prelude hiding (quot)

-- | How the work of a chunked (stream) kernel is divided: how many
-- elements each thread processes, and how many threads there are in
-- total.
data KernelSize = KernelSize
  { -- | Int64
    kernelElementsPerThread :: SubExp,
    -- | Int64 ('blockedKernelSize' fills this with the result of an
    -- Int64 multiplication).
    kernelNumThreads :: SubExp
  }
  deriving (Eq, Ord, Show)

-- | Bind the number of groups and the total number of threads to use
-- for processing @w@ elements with the given group size.
--
-- A fresh name derived from @desc ++ "_num_groups"@ is used as the key
-- handed to 'CalcNumGroups'.  The thread count is the Int64 product of
-- the number of groups and the group size (overflow is undefined).
--
-- Returns @(num_groups, num_threads)@.
numberOfGroups ::
  (MonadBinder m, Op (Lore m) ~ HostOp (Lore m) inner) =>
  String ->
  SubExp ->
  SubExp ->
  m (SubExp, SubExp)
numberOfGroups desc w group_size = do
  max_num_groups_key <- nameFromString . pretty <$> newVName (desc ++ "_num_groups")
  num_groups <-
    letSubExp "num_groups" $
      Op $ SizeOp $ CalcNumGroups w max_num_groups_key group_size
  num_threads <-
    letSubExp "num_threads" $
      BasicOp $ BinOp (Mul Int64 OverflowUndef) num_groups group_size
  return (num_groups, num_threads)

-- | Compute a 'KernelSize' for processing @w@ elements.
--
-- The group size is obtained with 'getSize' under the key
-- @desc ++ "_group_size"@ (class 'SizeGroup'), the thread count comes
-- from 'numberOfGroups', and the number of elements per thread is @w@
-- divided by the thread count, rounded up ('SDivUp' on Int64).
blockedKernelSize ::
  (MonadBinder m, Lore m ~ Kernels) =>
  String ->
  SubExp ->
  m KernelSize
blockedKernelSize desc w = do
  group_size <- getSize (desc ++ "_group_size") SizeGroup

  -- Only the total thread count is needed here; the group count is
  -- discarded.
  (_, num_threads) <- numberOfGroups desc w group_size

  per_thread_elements <-
    letSubExp "per_thread_elements"
      =<< eBinOp (SDivUp Int64 Unsafe) (eSubExp w) (eSubExp num_threads)

  return $ KernelSize per_thread_elements num_threads

-- | Bind the chunk size for iteration @i@ and slice each input array
-- accordingly.
--
-- @splitArrays chunk_size split_bound ordering w i elems_per_i arrs@
-- first binds @chunk_size@ to a 'SplitSpace' size operation, and then
-- binds each name in @split_bound@ to a slice of the corresponding
-- array in @arrs@ (the two lists are zipped pairwise):
--
-- * 'SplitContiguous': the slice starts at @i * elems_per_i@, has
--   @chunk_size@ elements and stride 1.
--
-- * 'SplitStrided' @stride@: the slice starts at @i@, has
--   @chunk_size@ elements and the given stride.
splitArrays ::
  (MonadBinder m, Lore m ~ Kernels) =>
  VName ->
  [VName] ->
  SplitOrdering ->
  SubExp ->
  SubExp ->
  SubExp ->
  [VName] ->
  m ()
splitArrays chunk_size split_bound ordering w i elems_per_i arrs = do
  letBindNames [chunk_size] $ Op $ SizeOp $ SplitSpace ordering w i elems_per_i
  case ordering of
    SplitContiguous -> do
      offset <- letSubExp "slice_offset" $ BasicOp $ BinOp (Mul Int64 OverflowUndef) i elems_per_i
      zipWithM_ (contiguousSlice offset) split_bound arrs
    SplitStrided stride -> zipWithM_ (stridedSlice stride) split_bound arrs
  where
    -- Bind slice_name to chunk_size consecutive rows of arr, starting
    -- at offset.
    contiguousSlice offset slice_name arr = do
      arr_t <- lookupType arr
      let slice = fullSlice arr_t [DimSlice offset (Var chunk_size) (constant (1 :: Int64))]
      letBindNames [slice_name] $ BasicOp $ Index arr slice

    -- Bind slice_name to chunk_size rows of arr, starting at i and
    -- stepping by stride.
    stridedSlice stride slice_name arr = do
      arr_t <- lookupType arr
      let slice = fullSlice arr_t [DimSlice i (Var chunk_size) stride]
      letBindNames [slice_name] $ BasicOp $ Index arr slice

-- | Split the parameters of a chunked kernel fold lambda into the
-- name of the first parameter, the second (chunk) parameter, the next
-- @num_accs@ parameters (accumulators), and all remaining parameters
-- (arrays).
--
-- Calls 'error' if the list contains fewer than two parameters.
partitionChunkedKernelFoldParameters ::
  Int ->
  [Param dec] ->
  (VName, Param dec, [Param dec], [Param dec])
partitionChunkedKernelFoldParameters num_accs (i_param : chunk_param : params) =
  let (acc_params, arr_params) = splitAt num_accs params
   in (paramName i_param, chunk_param, acc_params, arr_params)
partitionChunkedKernelFoldParameters _ _ =
  error "partitionChunkedKernelFoldParameters: lambda takes too few parameters"

-- | Emit the per-thread part of a chunked kernel: split the input
-- arrays for thread @thread_gtid@, inline the body of @lam@, and bind
-- its results to fresh pattern elements.
--
-- The first @num_nonconcat@ results of @lam@ are bound to names based
-- on @"chunk_fold_red"@ with the lambda's own return types; the
-- remaining results are bound to names based on @"chunk_fold_map"@,
-- typed as arrays of the chunk size over the row type of the
-- corresponding return type.
--
-- Returns the two lists of pattern elements, in that order.
blockedPerThread ::
  (MonadBinder m, Lore m ~ Kernels) =>
  VName ->
  SubExp ->
  KernelSize ->
  StreamOrd ->
  Lambda (Lore m) ->
  Int ->
  [VName] ->
  m ([PatElemT Type], [PatElemT Type])
blockedPerThread thread_gtid w kernel_size ordering lam num_nonconcat arrs = do
  -- The lambda is partitioned with zero accumulators, so the
  -- accumulator list is matched against [].
  let (_, chunk_size, [], arr_params) =
        partitionChunkedKernelFoldParameters 0 $ lambdaParams lam

      ordering' =
        case ordering of
          InOrder -> SplitContiguous
          Disorder -> SplitStrided $ kernelNumThreads kernel_size
      red_ts = take num_nonconcat $ lambdaReturnType lam
      map_ts = map rowType $ drop num_nonconcat $ lambdaReturnType lam

  per_thread <- asIntS Int64 $ kernelElementsPerThread kernel_size
  splitArrays
    (paramName chunk_size)
    (map paramName arr_params)
    ordering'
    w
    (Var thread_gtid)
    per_thread
    arrs

  chunk_red_pes <- forM red_ts $ \red_t -> do
    pe_name <- newVName "chunk_fold_red"
    return $ PatElem pe_name red_t
  chunk_map_pes <- forM map_ts $ \map_t -> do
    pe_name <- newVName "chunk_fold_map"
    return $ PatElem pe_name $ map_t `arrayOfRow` Var (paramName chunk_size)

  let (chunk_red_ses, chunk_map_ses) =
        splitAt num_nonconcat $ bodyResult $ lambdaBody lam

  -- Inline the lambda body, then bind each of its results to the
  -- corresponding fresh pattern element.
  addStms $
    bodyStms (lambdaBody lam)
      <> stmsFromList
        [ Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se
          | (pe, se) <- zip chunk_red_pes chunk_red_ses
        ]
      <> stmsFromList
        [ Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se
          | (pe, se) <- zip chunk_map_pes chunk_map_ses
        ]

  return (chunk_red_pes, chunk_map_pes)

-- | Given a chunked fold lambda that takes its initial accumulator
-- value as parameters, bind those parameters to the neutral element
-- instead.
--
-- The resulting lambda takes a fresh @thread_index@ parameter of type
-- int64, followed by the chunk parameter and the input parameters;
-- the accumulator parameters are removed from the parameter list and
-- are instead bound by statements inserted at the start of the body.
kerneliseLambda ::
  MonadFreshNames m =>
  [SubExp] ->
  Lambda Kernels ->
  m (Lambda Kernels)
kerneliseLambda nes lam = do
  thread_index <- newVName "thread_index"
  let thread_index_param = Param thread_index $ Prim int64
      (fold_chunk_param, fold_acc_params, fold_inp_params) =
        partitionChunkedFoldParameters (length nes) $ lambdaParams lam

      -- A neutral element that is a variable and whose parameter is
      -- not of primitive type is bound via Copy; everything else is
      -- bound directly.
      mkAccInit p (Var v)
        | not $ primType $ paramType p =
          mkLet [] [paramIdent p] $ BasicOp $ Copy v
      mkAccInit p x = mkLet [] [paramIdent p] $ BasicOp $ SubExp x
      acc_init_bnds = stmsFromList $ zipWith mkAccInit fold_acc_params nes
  return
    lam
      { lambdaBody =
          insertStms acc_init_bnds $
            lambdaBody lam,
        lambdaParams =
          thread_index_param :
          fold_chunk_param :
          fold_inp_params
      }

-- | Build the kernel space and kernel body for a chunked stream.
--
-- The commutativity decides how the input is split between threads:
-- 'Commutative' gives a strided split (stride = number of threads)
-- with 'Disorder' ordering; 'Noncommutative' gives a contiguous split
-- with 'InOrder' ordering.
--
-- The fold lambda is first passed through 'kerneliseLambda', and the
-- per-thread code is produced by 'blockedPerThread' inside a space
-- consisting of @ispace@ extended with a fresh @gtid@ dimension of
-- size @num_threads@.  The first @length nes@ results are returned
-- with 'Returns', the rest with 'ConcatReturns'.
--
-- Returns the number of threads, the space, the result types (the
-- first @length nes@ return types of @fold_lam@ unchanged, the rest
-- replaced by their row types), and the kernel body.
prepareStream ::
  (MonadBinder m, Lore m ~ Kernels) =>
  KernelSize ->
  [(VName, SubExp)] ->
  SubExp ->
  Commutativity ->
  Lambda Kernels ->
  [SubExp] ->
  [VName] ->
  m (SubExp, SegSpace, [Type], KernelBody Kernels)
prepareStream size ispace w comm fold_lam nes arrs = do
  let (KernelSize elems_per_thread num_threads) = size
  let (ordering, split_ordering) =
        case comm of
          Commutative -> (Disorder, SplitStrided num_threads)
          Noncommutative -> (InOrder, SplitContiguous)

  fold_lam' <- kerneliseLambda nes fold_lam

  gtid <- newVName "gtid"
  space <- mkSegSpace $ ispace ++ [(gtid, num_threads)]
  kbody <- fmap (uncurry (flip (KernelBody ()))) $
    runBinder $
      localScope (scopeOfSegSpace space) $ do
        (chunk_red_pes, chunk_map_pes) <-
          blockedPerThread gtid w size ordering fold_lam' (length nes) arrs
        let concatReturns pe =
              ConcatReturns split_ordering w elems_per_thread $ patElemName pe
        return
          ( map (Returns ResultMaySimplify . Var . patElemName) chunk_red_pes
              ++ map concatReturns chunk_map_pes
          )

  let (redout_ts, mapout_ts) = splitAt (length nes) $ lambdaReturnType fold_lam
      ts = redout_ts ++ map rowType mapout_ts

  return (num_threads, space, ts, kbody)

-- | Rephrase a stream reduction as a non-segmented 'SegRed' that does
-- explicit chunking within its body, and return the statements that
-- compute it.
--
-- The pattern @pat@ is split after @length nes@ elements: the leading
-- elements are the reduction outputs, and the remaining ones are bound
-- unchanged by the generated 'SegRed'.  The 'SegRed' combines
-- per-thread results with @red_lam@ (neutral elements @nes@,
-- commutativity @comm@), while @fold_lam@ is handed to
-- 'prepareStream' to build the kernel body.
streamRed ::
  (MonadFreshNames m, HasScope Kernels m) =>
  MkSegLevel Kernels m ->
  Pattern Kernels ->
  SubExp ->
  Commutativity ->
  Lambda Kernels ->
  Lambda Kernels ->
  [SubExp] ->
  [VName] ->
  m (Stms Kernels)
streamRed mk_lvl pat w comm red_lam fold_lam nes arrs = runBinderT'_ $ do
  -- The strategy here is to rephrase the stream reduction as a
  -- non-segmented SegRed that does explicit chunking within its body.
  -- First, figure out how many threads to use for this.
  size <- blockedKernelSize "stream_red" w

  -- Only the reduction outputs go through 'dummyDim'; it hands back a
  -- replacement pattern, an index space for the kernel, and an action
  -- that is run once the SegRed has been bound (presumably to recover
  -- the original reduction outputs -- see 'dummyDim').
  let (redout_pes, mapout_pes) = splitAt (length nes) $ patternElements pat
  (redout_pat, ispace, read_dummy) <- dummyDim $ Pattern [] redout_pes
  let pat' = Pattern [] $ patternElements redout_pat ++ mapout_pes

  -- The thread count returned by 'prepareStream' is not needed here.
  (_, kspace, ts, kbody) <- prepareStream size ispace w comm fold_lam nes arrs

  lvl <- mk_lvl [w] "stream_red" $ NoRecommendation SegNoVirt
  letBind pat' $
    Op $
      SegOp $
        SegRed
          lvl
          kspace
          [SegBinOp comm red_lam nes mempty]
          ts
          kbody

  read_dummy

-- | Similar to 'streamRed', but without the last reduction: the
-- per-thread results are bound by a 'SegMap' and left uncombined.
--
-- Fresh pattern elements are created for the first @length nes@ result
-- types, named after the entries of @out_desc@; each is given the
-- corresponding result type with an extra outer dimension equal to the
-- thread count returned by 'prepareStream'.  Note that @out_desc@ is
-- zipped with those types, so the shorter of the two determines how
-- many fresh elements are made.  The caller-supplied @mapout_pes@ are
-- appended to the pattern unchanged.
--
-- Returns the thread count together with the names of the fresh
-- pattern elements, and the statements that were generated.
streamMap ::
  (MonadFreshNames m, HasScope Kernels m) =>
  MkSegLevel Kernels m ->
  [String] ->
  [PatElem Kernels] ->
  SubExp ->
  Commutativity ->
  Lambda Kernels ->
  [SubExp] ->
  [VName] ->
  m ((SubExp, [VName]), Stms Kernels)
streamMap mk_lvl out_desc mapout_pes w comm fold_lam nes arrs = runBinderT' $ do
  size <- blockedKernelSize "stream_map" w

  -- No extra index space is passed here, unlike in 'streamRed'.
  (threads, kspace, ts, kbody) <- prepareStream size [] w comm fold_lam nes arrs

  let redout_ts = take (length nes) ts

  -- One fresh name per accumulator result, holding a row per thread.
  redout_pes <- forM (zip out_desc redout_ts) $ \(desc, t) ->
    PatElem <$> newVName desc <*> pure (t `arrayOfRow` threads)

  let pat = Pattern [] $ redout_pes ++ mapout_pes
  lvl <- mk_lvl [w] "stream_map" $ NoRecommendation SegNoVirt
  letBind pat $ Op $ SegOp $ SegMap lvl kspace ts kbody

  return (threads, map patElemName redout_pes)

-- | Like 'segThread', but cap the thread count to the input size.
-- This is more efficient for small kernels, e.g. summing a small
-- array.
--
-- The input size is the 64-bit product of the sizes in @ws@, and the
-- group size is obtained from a size named @desc ++ "_group_size"@.
-- The number of groups depends on the thread recommendation @r@.
segThreadCapped :: MonadFreshNames m => MkSegLevel Kernels m
segThreadCapped ws desc r = do
  -- Total size of the nest: multiply all widths, starting from 1.
  w <-
    letSubExp "nest_size"
      =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) ws
  group_size <- getSize (desc ++ "_group_size") SizeGroup

  case r of
    ManyThreads -> do
      -- Use just enough groups to cover the input: the nest size
      -- divided by the group size, rounded up.
      usable_groups <-
        letSubExp "segmap_usable_groups"
          =<< eBinOp
            (SDivUp Int64 Unsafe)
            (eSubExp w)
            (eSubExp =<< asIntS Int64 group_size)
      return $ SegThread (Count usable_groups) (Count group_size) SegNoVirt
    NoRecommendation v -> do
      -- Let 'numberOfGroups' pick the group count; its second result
      -- is not needed.  The requested virtualisation is passed along.
      (num_groups, _) <- numberOfGroups desc w group_size
      return $ SegThread (Count num_groups) (Count group_size) v