{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Pass.ExtractKernels.StreamKernel
  ( segThreadCapped,
    streamRed,
    streamMap,
  )
where

import Control.Monad
import Control.Monad.Writer
import Data.List ()
import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.IR.GPU hiding
  ( BasicOp,
    Body,
    Exp,
    FParam,
    FunDef,
    LParam,
    Lambda,
    Pat,
    PatElem,
    Prog,
    RetType,
    Stm,
  )
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import Prelude hiding (quot)

-- | The sizing of a chunked stream kernel: how many elements each
-- thread handles and how many threads there are in total.
data KernelSize = KernelSize
  { -- | Int64
    kernelElementsPerThread :: SubExp,
    -- | Int64 (as produced by 'blockedKernelSize', this is the result
    -- of an Int64 multiplication and is used as an Int64 divisor).
    kernelNumThreads :: SubExp
  }
  deriving (Eq, Ord, Show)

-- | Bind the number of groups and the total number of threads to use
-- for a kernel.
--
-- The arguments are a description string (used to derive the name of
-- the tuning key), the value @w@ that is handed to 'CalcNumGroups'
-- (presumably the amount of work; confirm against 'CalcNumGroups'),
-- and the group size.
--
-- Returns @(num_groups, num_threads)@, where @num_threads@ is the
-- Int64 product of @num_groups@ and the group size.
numberOfGroups ::
  (MonadBuilder m, Op (Rep m) ~ HostOp (Rep m) inner) =>
  String ->
  SubExp ->
  SubExp ->
  m (SubExp, SubExp)
numberOfGroups desc w group_size = do
  -- The key is the pretty-printed form of a fresh name, so it is
  -- unique for every call even when 'desc' is reused.
  max_num_groups_key <- nameFromString . pretty <$> newVName (desc ++ "_num_groups")
  num_groups <-
    letSubExp "num_groups" $
      Op $
        SizeOp $
          CalcNumGroups w max_num_groups_key group_size
  num_threads <-
    letSubExp "num_threads" $
      BasicOp $
        BinOp (Mul Int64 OverflowUndef) num_groups group_size
  pure (num_groups, num_threads)

-- | Compute a 'KernelSize' for processing @w@ elements.
--
-- A group size is obtained with 'getSize' under the key
-- @desc ++ "_group_size"@, the thread count comes from
-- 'numberOfGroups', and the number of elements per thread is @w@
-- divided by the thread count, rounded up (@SDivUp Int64@).
blockedKernelSize ::
  (MonadBuilder m, Rep m ~ GPU) =>
  String ->
  SubExp ->
  m KernelSize
blockedKernelSize desc w = do
  group_size <- getSize (desc ++ "_group_size") SizeGroup

  -- Only the total thread count is needed here; the group count is
  -- discarded.
  (_, num_threads) <- numberOfGroups desc w group_size

  per_thread_elements <-
    letSubExp "per_thread_elements"
      =<< eBinOp (SDivUp Int64 Unsafe) (eSubExp w) (eSubExp num_threads)

  pure $ KernelSize per_thread_elements num_threads

-- | Bind the chunk size for index @i@ and slice out the corresponding
-- chunk of each input array.
--
-- The arguments are, in order:
--
-- * the name to bind to the chunk size (the result of 'SplitSpace');
--
-- * the names to bind to the array slices, paired positionally with
--   the final argument (surplus names or arrays are ignored, as with
--   any 'zipWithM_');
--
-- * the split ordering;
--
-- * @w@, @i@ and @elems_per_i@, which are passed on to 'SplitSpace';
--
-- * the arrays to slice.
--
-- For 'SplitContiguous', each slice starts at @i * elems_per_i@ and
-- has stride 1.  For 'SplitStrided', each slice starts at @i@ and
-- uses the given stride.  In both cases the slice length is the
-- chunk size, and only the outermost dimension is sliced.
splitArrays ::
  (MonadBuilder m, Rep m ~ GPU) =>
  VName ->
  [VName] ->
  SplitOrdering ->
  SubExp ->
  SubExp ->
  SubExp ->
  [VName] ->
  m ()
splitArrays chunk_size split_bound ordering w i elems_per_i arrs = do
  letBindNames [chunk_size] $ Op $ SizeOp $ SplitSpace ordering w i elems_per_i
  case ordering of
    SplitContiguous -> do
      offset <- letSubExp "slice_offset" $ BasicOp $ BinOp (Mul Int64 OverflowUndef) i elems_per_i
      zipWithM_ (contiguousSlice offset) split_bound arrs
    SplitStrided stride -> zipWithM_ (stridedSlice stride) split_bound arrs
  where
    -- Bind 'slice_name' to the unit-stride chunk of 'arr' that starts
    -- at 'offset'.
    contiguousSlice offset slice_name arr = do
      arr_t <- lookupType arr
      let slice = fullSlice arr_t [DimSlice offset (Var chunk_size) (constant (1 :: Int64))]
      letBindNames [slice_name] $ BasicOp $ Index arr slice

    -- Bind 'slice_name' to the chunk of 'arr' that starts at 'i' and
    -- advances by 'stride'.
    stridedSlice stride slice_name arr = do
      arr_t <- lookupType arr
      let slice = fullSlice arr_t [DimSlice i (Var chunk_size) stride]
      letBindNames [slice_name] $ BasicOp $ Index arr slice

-- | Split the parameters of a kernelised chunked fold lambda.
--
-- The parameter list is expected to be laid out as: an index
-- parameter, a chunk size parameter, @num_accs@ accumulator
-- parameters, and then the array parameters.  The result is the name
-- of the index parameter, the chunk size parameter, the accumulator
-- parameters, and the array parameters.
--
-- Calls 'error' if fewer than two parameters are given.  If fewer
-- than @num_accs@ parameters follow the first two, the accumulator
-- list is simply shorter and the array list is empty ('splitAt'
-- semantics).
partitionChunkedKernelFoldParameters ::
  Int ->
  [Param dec] ->
  (VName, Param dec, [Param dec], [Param dec])
partitionChunkedKernelFoldParameters num_accs (i_param : chunk_param : params) =
  let (acc_params, arr_params) = splitAt num_accs params
   in (paramName i_param, chunk_param, acc_params, arr_params)
partitionChunkedKernelFoldParameters _ _ =
  error "partitionChunkedKernelFoldParameters: lambda takes too few parameters"

-- | Emit the per-thread part of a chunked stream: split the input
-- arrays for the given thread, inline the body of the lambda, and
-- bind its results to fresh pattern elements.
--
-- The arguments are the thread index variable, the value @w@ passed
-- on to 'splitArrays', the kernel size, the stream ordering, the
-- (kernelised) fold lambda, the number of leading non-concatenated
-- results of the lambda, and the input arrays.
--
-- Returns the pattern elements bound to the first @num_nonconcat@
-- results (named @chunk_fold_red@) and those bound to the remaining
-- results (named @chunk_fold_map@).  The latter are typed as arrays
-- whose outer size is the chunk size.
blockedPerThread ::
  (MonadBuilder m, Rep m ~ GPU) =>
  VName ->
  SubExp ->
  KernelSize ->
  StreamOrd ->
  Lambda (Rep m) ->
  Int ->
  [VName] ->
  m ([PatElem Type], [PatElem Type])
blockedPerThread thread_gtid w kernel_size ordering lam num_nonconcat arrs = do
  -- Partitioning with zero accumulators always yields an empty
  -- accumulator list, so the @[]@ pattern cannot fail on that
  -- component.  The lambda's own index parameter is not used.
  let (_, chunk_size, [], arr_params) =
        partitionChunkedKernelFoldParameters 0 $ lambdaParams lam

      ordering' =
        case ordering of
          InOrder -> SplitContiguous
          Disorder -> SplitStrided $ kernelNumThreads kernel_size
      red_ts = take num_nonconcat $ lambdaReturnType lam
      map_ts = map rowType $ drop num_nonconcat $ lambdaReturnType lam

  per_thread <- asIntS Int64 $ kernelElementsPerThread kernel_size
  -- The lambda's chunk size and array parameters become ordinary
  -- bindings produced by the split.
  splitArrays
    (paramName chunk_size)
    (map paramName arr_params)
    ordering'
    w
    (Var thread_gtid)
    per_thread
    arrs

  chunk_red_pes <- forM red_ts $ \red_t -> do
    pe_name <- newVName "chunk_fold_red"
    pure $ PatElem pe_name red_t
  chunk_map_pes <- forM map_ts $ \map_t -> do
    pe_name <- newVName "chunk_fold_map"
    pure $ PatElem pe_name $ map_t `arrayOfRow` Var (paramName chunk_size)

  let (chunk_red_ses, chunk_map_ses) =
        splitAt num_nonconcat $ bodyResult $ lambdaBody lam

  -- Inline the lambda body, then bind each of its results (keeping
  -- the result's certificates) to the corresponding fresh name.
  addStms $
    bodyStms (lambdaBody lam)
      <> stmsFromList
        [ certify cs $ Let (Pat [pe]) (defAux ()) $ BasicOp $ SubExp se
          | (pe, SubExpRes cs se) <- zip chunk_red_pes chunk_red_ses
        ]
      <> stmsFromList
        [ certify cs $ Let (Pat [pe]) (defAux ()) $ BasicOp $ SubExp se
          | (pe, SubExpRes cs se) <- zip chunk_map_pes chunk_map_ses
        ]

  pure (chunk_red_pes, chunk_map_pes)

-- | Given a chunked fold lambda that takes its initial accumulator
-- value as parameters, bind those parameters to the neutral element
-- instead.
--
-- The first argument is the list of neutral elements; its length
-- determines how many accumulator parameters the lambda is taken to
-- have.  In the resulting lambda the accumulator parameters are no
-- longer parameters: they are bound by statements inserted at the
-- start of the body.  A fresh @thread_index@ parameter of type
-- @int64@ is added as the first parameter, followed by the chunk
-- parameter and the input parameters.
kerneliseLambda ::
  MonadFreshNames m =>
  [SubExp] ->
  Lambda GPU ->
  m (Lambda GPU)
kerneliseLambda nes lam = do
  thread_index_param <- newParam "thread_index" $ Prim int64
  let (fold_chunk_param, fold_acc_params, fold_inp_params) =
        partitionChunkedFoldParameters (length nes) $ lambdaParams lam

      -- A neutral element that is a variable of non-primitive type is
      -- bound via 'Copy'; anything else is bound directly.
      mkAccInit p (Var v)
        | not $ primType $ paramType p =
            mkLet [paramIdent p] $ BasicOp $ Copy v
      mkAccInit p x = mkLet [paramIdent p] $ BasicOp $ SubExp x
      acc_init_stms = stmsFromList $ zipWith mkAccInit fold_acc_params nes
  pure
    lam
      { lambdaBody = insertStms acc_init_stms $ lambdaBody lam,
        lambdaParams = thread_index_param : fold_chunk_param : fold_inp_params
      }

-- | Build the pieces of a segmented operation that performs a chunked
-- stream with explicit per-thread chunking in its body.
--
-- The arguments are the kernel size, the outer index space (a fresh
-- @gtid@ dimension of size @num_threads@ is appended to it), the
-- value @w@ passed on to 'blockedPerThread' and 'ConcatReturns', the
-- commutativity, the fold lambda, the neutral elements, and the
-- input arrays.
--
-- A 'Commutative' stream is split with a stride of @num_threads@; a
-- 'Noncommutative' one is split contiguously.
--
-- Returns the number of threads, the constructed space, the result
-- types (reduction result types followed by the row types of the map
-- result types), and the kernel body.  In the body, the first
-- @length nes@ results are returned with 'Returns' and the remaining
-- ones with 'ConcatReturns'.
prepareStream ::
  (MonadBuilder m, Rep m ~ GPU) =>
  KernelSize ->
  [(VName, SubExp)] ->
  SubExp ->
  Commutativity ->
  Lambda GPU ->
  [SubExp] ->
  [VName] ->
  m (SubExp, SegSpace, [Type], KernelBody GPU)
prepareStream size ispace w comm fold_lam nes arrs = do
  let (KernelSize elems_per_thread num_threads) = size
  let (ordering, split_ordering) =
        case comm of
          Commutative -> (Disorder, SplitStrided num_threads)
          Noncommutative -> (InOrder, SplitContiguous)

  fold_lam' <- kerneliseLambda nes fold_lam

  gtid <- newVName "gtid"
  space <- mkSegSpace $ ispace ++ [(gtid, num_threads)]
  -- 'runBuilder' yields (results, statements); flip and uncurry feed
  -- them to 'KernelBody' in the order it expects.
  kbody <- fmap (uncurry (flip (KernelBody ()))) $
    runBuilder $
      localScope (scopeOfSegSpace space) $ do
        (chunk_red_pes, chunk_map_pes) <-
          blockedPerThread gtid w size ordering fold_lam' (length nes) arrs
        let concatReturns pe =
              ConcatReturns mempty split_ordering w elems_per_thread $ patElemName pe
        pure
          ( map (Returns ResultMaySimplify mempty . Var . patElemName) chunk_red_pes
              ++ map concatReturns chunk_map_pes
          )

  let (redout_ts, mapout_ts) = splitAt (length nes) $ lambdaReturnType fold_lam
      ts = redout_ts ++ map rowType mapout_ts

  pure (num_threads, space, ts, kbody)

streamRed ::
  (MonadFreshNames m, HasScope GPU m) =>
  MkSegLevel GPU m ->
  Pat Type ->
  SubExp ->
  Commutativity ->
  Lambda GPU ->
  Lambda GPU ->
  [SubExp] ->
  [VName] ->
  m (Stms GPU)
-- Emit the statements for a stream reduction as a single SegRed.
--
-- The first @length nes@ elements of 'pat' are the reduction results;
-- any remaining elements are bound as additional results of the same
-- SegRed.
streamRed mk_lvl pat w comm red_lam fold_lam nes arrs = runBuilderT'_ $ do
  -- The strategy here is to rephrase the stream reduction as a
  -- non-segmented SegRed that does explicit chunking within its body.
  -- First, figure out how many threads to use for this.
  size <- blockedKernelSize "stream_red" w

  -- Only the reduction part of the pattern is passed through
  -- 'dummyDim'; the other pattern elements are appended unchanged.
  -- NOTE(review): presumably 'dummyDim' adds a unit outer dimension
  -- and 'read_dummy' reads the real results back out of it -- confirm
  -- against the definition of 'dummyDim'.
  let (redout_pes, mapout_pes) = splitAt (length nes) $ patElems pat
  (redout_pat, ispace, read_dummy) <- dummyDim $ Pat redout_pes
  let pat' = Pat $ patElems redout_pat ++ mapout_pes

  -- The first component of the result is not needed here.
  (_, kspace, ts, kbody) <- prepareStream size ispace w comm fold_lam nes arrs

  lvl <- mk_lvl [w] "stream_red" $ NoRecommendation SegNoVirt
  letBind pat' . Op . SegOp $
    SegRed lvl kspace [SegBinOp comm red_lam nes mempty] ts kbody

  -- Must come after the SegRed has been bound to pat'.
  read_dummy

-- | Similar to 'streamRed', but without the last reduction: the
-- per-thread results are bound as arrays by a single SegMap instead
-- of being combined.
--
-- Returns the thread count obtained from 'prepareStream' together
-- with the names of the freshly created arrays that hold the first
-- @length nes@ results (one name per element of @out_desc@ that has
-- a corresponding result type), plus the statements that were built.
streamMap ::
  (MonadFreshNames m, HasScope GPU m) =>
  MkSegLevel GPU m ->
  [String] ->
  [PatElem Type] ->
  SubExp ->
  Commutativity ->
  Lambda GPU ->
  [SubExp] ->
  [VName] ->
  m ((SubExp, [VName]), Stms GPU)
streamMap mk_lvl out_desc mapout_pes w comm fold_lam nes arrs = runBuilderT' $ do
  size <- blockedKernelSize "stream_map" w

  -- No extra index space is passed here (unlike in 'streamRed').
  (threads, kspace, ts, kbody) <- prepareStream size [] w comm fold_lam nes arrs

  -- The leading @length nes@ result types belong to the accumulators.
  let redout_ts = take (length nes) ts

  -- Each accumulator result becomes an array with outer size
  -- 'threads', named after the corresponding entry of 'out_desc'.
  redout_pes <- forM (zip out_desc redout_ts) $ \(desc, t) ->
    PatElem <$> newVName desc <*> pure (t `arrayOfRow` threads)

  let pat = Pat $ redout_pes ++ mapout_pes
  lvl <- mk_lvl [w] "stream_map" $ NoRecommendation SegNoVirt
  letBind pat $ Op $ SegOp $ SegMap lvl kspace ts kbody

  pure (threads, map patElemName redout_pes)

-- | Like 'segThread', but cap the thread count to the input size.
-- This is more efficient for small kernels, e.g. summing a small
-- array.
--
-- The input size is the product of all the given widths.  The group
-- size is obtained via 'getSize' under the name
-- @desc ++ "_group_size"@.
segThreadCapped :: MonadFreshNames m => MkSegLevel GPU m
segThreadCapped ws desc r = do
  -- Total number of elements: the product of the widths (1 if empty).
  w <-
    letSubExp "nest_size"
      =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) ws
  group_size <- getSize (desc ++ "_group_size") SizeGroup

  case r of
    ManyThreads -> do
      -- Use just enough groups to cover all of w: w divided by the
      -- group size, rounded up.  No virtualisation is requested.
      usable_groups <-
        letSubExp "segmap_usable_groups"
          =<< eBinOp
            (SDivUp Int64 Unsafe)
            (eSubExp w)
            (eSubExp =<< asIntS Int64 group_size)
      pure $ SegThread (Count usable_groups) (Count group_size) SegNoVirt
    NoRecommendation v -> do
      -- Let 'numberOfGroups' pick the group count (its second result
      -- is unused) and pass on the caller's virtualisation setting.
      (num_groups, _) <- numberOfGroups desc w group_size
      pure $ SegThread (Count num_groups) (Count group_size) v