{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Pass.ExtractKernels.StreamKernel
  ( segThreadCapped
  , streamRed
  , streamMap
  )
  where

import Control.Monad
import Control.Monad.Writer
import Data.List ()

import Prelude hiding (quot)

import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.IR.Kernels
       hiding (Prog, Body, Stm, Pattern, PatElem,
               BasicOp, Exp, Lambda, FunDef, FParam, LParam, RetType)
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.ToKernels
import Futhark.MonadFreshNames
import Futhark.Tools

-- | The dimensions chosen for a blocked (chunked) kernel, as computed
-- by 'blockedKernelSize'.
data KernelSize = KernelSize { kernelElementsPerThread :: SubExp
                               -- ^ Int64: the number of input elements
                               -- assigned to each thread.
                             , kernelNumThreads :: SubExp
                               -- ^ Int32: the total number of threads
                               -- (number of groups times group size).
                             }
                deriving (Eq, Ord, Show)

-- | Bind the number of workgroups and the total number of threads
-- used to process @w64@ elements with groups of @group_size@ threads.
--
-- The group count is produced by a 'CalcNumGroups' size operation
-- whose maximum-group-count key is a name derived from a fresh
-- 'VName' built from @desc@.  The thread count is
-- @num_groups * group_size@, computed as a 32-bit multiplication
-- with undefined overflow behaviour.
--
-- Returns @(num_groups, num_threads)@.
numberOfGroups :: (MonadBinder m, Op (Lore m) ~ HostOp (Lore m) inner) =>
                  String -> SubExp -> SubExp -> m (SubExp, SubExp)
numberOfGroups desc w64 group_size = do
  -- Printing a fresh VName includes its unique tag, so the key differs
  -- between kernels even when they share the same @desc@.
  max_num_groups_key <- nameFromString . pretty <$> newVName (desc ++ "_num_groups")
  num_groups <- letSubExp "num_groups" $
                Op $ SizeOp $ CalcNumGroups w64 max_num_groups_key group_size
  num_threads <- letSubExp "num_threads" $
                 BasicOp $ BinOp (Mul Int32 OverflowUndef) num_groups group_size
  return (num_groups, num_threads)

-- | Choose a 'KernelSize' for processing @w@ elements in a blocked
-- fashion.
--
-- The steps are:
--
-- * the group size comes from 'getSize' with size class 'SizeGroup';
-- * the total thread count comes from 'numberOfGroups';
-- * the number of elements per thread is @w / num_threads@ rounded
--   up, computed in 64 bits.
--
-- @w@ is sign-extended from 32 to 64 bits, so it is expected to be an
-- Int32 value.
blockedKernelSize :: (MonadBinder m, Lore m ~ Kernels) =>
                     String -> SubExp -> m KernelSize
blockedKernelSize desc w = do
  group_size <- getSize (desc ++ "_group_size") SizeGroup

  w64 <- letSubExp "w64" $ BasicOp $ ConvOp (SExt Int32 Int64) w
  (_, num_threads) <- numberOfGroups desc w64 group_size

  -- Round up so that every element is assigned to some thread.
  per_thread_elements <-
    letSubExp "per_thread_elements" =<<
    eBinOp (SDivUp Int64 Unsafe) (eSubExp w64) (toExp =<< asIntS Int64 num_threads)

  return $ KernelSize per_thread_elements num_threads

-- | Bind this thread's portion of each input array.
--
-- First binds @chunk_size@ to the result of a 'SplitSpace' size
-- operation over @w@, @i@ and @elems_per_i@.  Then each name in
-- @split_bound@ is bound to a slice of the corresponding array in
-- @arrs@ (extra elements of the longer list are ignored).  The slice
-- holds @chunk_size@ elements of the outer dimension, with all inner
-- dimensions kept whole:
--
-- * 'SplitContiguous': starting at @i * elems_per_i@ with stride 1.
--
-- * @'SplitStrided' stride@: starting at @i@ with the given stride.
splitArrays :: (MonadBinder m, Lore m ~ Kernels) =>
               VName -> [VName]
            -> SplitOrdering -> SubExp -> SubExp -> SubExp -> [VName]
            -> m ()
splitArrays chunk_size split_bound ordering w i elems_per_i arrs = do
  letBindNames [chunk_size] $ Op $ SizeOp $ SplitSpace ordering w i elems_per_i
  case ordering of
    SplitContiguous -> do
      offset <- letSubExp "slice_offset" $
                BasicOp $ BinOp (Mul Int32 OverflowUndef) i elems_per_i
      zipWithM_ (sliceArray offset (constant (1::Int32))) split_bound arrs
    SplitStrided stride ->
      zipWithM_ (sliceArray i stride) split_bound arrs
  where -- Bind @slice_name@ to the @chunk_size@ outer elements of @arr@
        -- beginning at @start@ and spaced @stride@ apart.
        sliceArray start stride slice_name arr = do
          arr_t <- lookupType arr
          let slice = fullSlice arr_t [DimSlice start (Var chunk_size) stride]
          letBindNames [slice_name] $ BasicOp $ Index arr slice

-- | Split the parameters of a kernelised chunked fold lambda.
--
-- The result is @(thread index name, chunk size parameter,
-- accumulator parameters, array parameters)@.  After the thread index
-- and chunk size, the next @num_accs@ parameters are the accumulators
-- and the rest are array parameters.
--
-- Calls 'error' if the lambda has fewer than two parameters.
partitionChunkedKernelFoldParameters :: Int -> [Param dec]
                                     -> (VName, Param dec, [Param dec], [Param dec])
partitionChunkedKernelFoldParameters num_accs (i_param : chunk_param : params) =
  let (acc_params, arr_params) = splitAt num_accs params
  in (paramName i_param, chunk_param, acc_params, arr_params)
partitionChunkedKernelFoldParameters _ _ =
  error "partitionChunkedKernelFoldParameters: lambda takes too few parameters"

-- | Emit, in the current binder, the sequential per-thread part of a
-- blocked stream.
--
-- The steps are:
--
-- 1. Slice this thread's chunk out of each input array.
-- 2. Inline the fold lambda's body.
-- 3. Bind the lambda's results to fresh pattern elements.
--
-- Parameters:
--
-- * @thread_gtid@ is this thread's index; it selects the chunk.
-- * @w@ is the total number of input elements.
-- * @ordering@ selects contiguous chunks ('InOrder') or elements
--   strided by the kernel's total thread count ('Disorder').
-- * @lam@ must take parameters @thread index : chunk size : arrays@,
--   with no accumulator parameters (as produced by 'kerneliseLambda').
--   Its thread-index parameter is not bound here.
-- * The first @num_nonconcat@ results of @lam@ are reduction results;
--   the remaining ones are per-chunk map results.
--
-- Returns the pattern elements for the reduction results (with the
-- lambda's return types) and for the map results (arrays whose outer
-- dimension is the chunk size).
blockedPerThread :: (MonadBinder m, Lore m ~ Kernels) =>
                    VName -> SubExp -> KernelSize -> StreamOrd -> Lambda (Lore m)
                 -> Int -> [VName]
                 -> m ([PatElemT Type], [PatElemT Type])
blockedPerThread thread_gtid w kernel_size ordering lam num_nonconcat arrs = do
  -- Zero accumulators are requested, so the accumulator list is always
  -- empty and is ignored.
  let (_, chunk_size, _, arr_params) =
        partitionChunkedKernelFoldParameters 0 $ lambdaParams lam
      chunk_size_name = paramName chunk_size

      ordering' =
        case ordering of InOrder  -> SplitContiguous
                         Disorder -> SplitStrided $ kernelNumThreads kernel_size
      (red_ts, map_rets) = splitAt num_nonconcat $ lambdaReturnType lam
      map_ts = map rowType map_rets

      -- Bind each pattern element to the corresponding lambda result.
      bindResults :: [PatElemT Type] -> [SubExp] -> Stms Kernels
      bindResults pes ses =
        stmsFromList [ Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se
                     | (pe, se) <- zip pes ses ]

  per_thread <- asIntS Int32 $ kernelElementsPerThread kernel_size
  splitArrays chunk_size_name (map paramName arr_params) ordering' w
    (Var thread_gtid) per_thread arrs

  chunk_red_pes <- forM red_ts $ \red_t ->
    (`PatElem` red_t) <$> newVName "chunk_fold_red"
  chunk_map_pes <- forM map_ts $ \map_t ->
    (`PatElem` (map_t `arrayOfRow` Var chunk_size_name)) <$> newVName "chunk_fold_map"

  let (chunk_red_ses, chunk_map_ses) =
        splitAt num_nonconcat $ bodyResult $ lambdaBody lam

  addStms $
    bodyStms (lambdaBody lam) <>
    bindResults chunk_red_pes chunk_red_ses <>
    bindResults chunk_map_pes chunk_map_ses

  return (chunk_red_pes, chunk_map_pes)

-- | Given a chunked fold lambda that takes its initial accumulator
-- value as parameters, bind those parameters to the neutral element
-- instead.
--
-- The returned lambda has parameters @thread_index : chunk : inputs@:
--
-- * a fresh @int32@ thread-index parameter is prepended;
-- * the chunk and input parameters are kept;
-- * the accumulator parameters are removed, and each is bound at the
--   start of the body to the corresponding element of @nes@.
--
-- A neutral element that is a variable of non-primitive type is bound
-- via 'Copy' rather than directly.  This is presumably so that
-- in-place updates of the accumulator cannot consume the original
-- neutral value (TODO confirm).
kerneliseLambda :: MonadFreshNames m =>
                   [SubExp] -> Lambda Kernels -> m (Lambda Kernels)
kerneliseLambda nes lam = do
  thread_index <- newVName "thread_index"
  let thread_index_param = Param thread_index $ Prim int32
      (fold_chunk_param, fold_acc_params, fold_inp_params) =
        partitionChunkedFoldParameters (length nes) $ lambdaParams lam

      mkAccInit p (Var v)
        | not $ primType $ paramType p =
            mkLet [] [paramIdent p] $ BasicOp $ Copy v
      mkAccInit p x = mkLet [] [paramIdent p] $ BasicOp $ SubExp x
      acc_init_bnds = stmsFromList $ zipWith mkAccInit fold_acc_params nes
  return lam { lambdaBody = insertStms acc_init_bnds $ lambdaBody lam
             , lambdaParams = thread_index_param :
                              fold_chunk_param :
                              fold_inp_params
             }

-- | Build the body of a kernel that processes a stream in blocks.
--
-- Each of the @num_threads@ threads (taken from @size@) runs the
-- kernelised fold lambda sequentially on its own chunk of @arrs@ via
-- 'blockedPerThread'.  The split depends on @comm@:
--
-- * commutative operators use a strided split ('Disorder', stride
--   @num_threads@);
-- * non-commutative operators use contiguous chunks ('InOrder').
--
-- The segment space is @ispace@ extended with a fresh innermost index
-- @gtid@ ranging over @num_threads@.  The kernel body returns:
--
-- * each per-thread reduction result directly ('Returns');
-- * each map result as a 'ConcatReturns' over @w@ elements, with
--   @elems_per_thread@ elements per thread.
--
-- Returns @(num_threads, space, ts, kbody)@.  Here @ts@ is the
-- lambda's reduction result types followed by the row types of its
-- map result types.
prepareStream :: (MonadBinder m, Lore m ~ Kernels) =>
                 KernelSize
              -> [(VName, SubExp)]
              -> SubExp
              -> Commutativity
              -> Lambda Kernels
              -> [SubExp]
              -> [VName]
              -> m (SubExp, SegSpace, [Type], KernelBody Kernels)
prepareStream size ispace w comm fold_lam nes arrs = do
  let KernelSize elems_per_thread num_threads = size
      (ordering, split_ordering) =
        case comm of Commutative    -> (Disorder, SplitStrided num_threads)
                     Noncommutative -> (InOrder, SplitContiguous)

  fold_lam' <- kerneliseLambda nes fold_lam

  elems_per_thread_32 <- asIntS Int32 elems_per_thread

  gtid <- newVName "gtid"
  space <- mkSegSpace $ ispace ++ [(gtid, num_threads)]
  kbody <- (\(res, stms) -> KernelBody () stms res) <$> runBinder
           (localScope (scopeOfSegSpace space) $ do
    (chunk_red_pes, chunk_map_pes) <-
      blockedPerThread gtid w size ordering fold_lam' (length nes) arrs
    let concatReturns pe =
          ConcatReturns split_ordering w elems_per_thread_32 $ patElemName pe
    return (map (Returns ResultMaySimplify . Var . patElemName) chunk_red_pes ++
            map concatReturns chunk_map_pes))

  let (redout_ts, mapout_ts) = splitAt (length nes) $ lambdaReturnType fold_lam
      ts = redout_ts ++ map rowType mapout_ts

  return (num_threads, space, ts, kbody)

-- | Rephrase a stream reduction as a single non-segmented 'SegRed'
-- whose kernel body performs explicit chunking.
--
-- The first @length nes@ elements of the pattern receive the results
-- of the reduction with @red_lam@. The remaining elements receive the
-- map-out results of @fold_lam@.
streamRed :: (MonadFreshNames m, HasScope Kernels m) =>
             MkSegLevel Kernels m
          -> Pattern Kernels
          -> SubExp
          -> Commutativity
          -> Lambda Kernels -> Lambda Kernels
          -> [SubExp] -> [VName]
          -> m (Stms Kernels)
streamRed mkLevel res_pat width comm red_op_lam fold_op_lam neutral inputs =
  runBinderT'_ $ do
    -- First decide how many threads (and elements per thread) to use.
    ksize <- blockedKernelSize "stream_red" width

    -- Separate the reduction results from the map-out results.
    let num_red = length neutral
        (red_elems, map_elems) = splitAt num_red $ patternElements res_pat

    -- 'dummyDim' rewrites the reduction part of the pattern and also
    -- returns an index space plus an action that must run after the
    -- kernel has been bound.
    (red_pat, dummy_ispace, after_kernel) <- dummyDim $ Pattern [] red_elems

    (_, seg_space, body_ts, seg_body) <-
      prepareStream ksize dummy_ispace width comm fold_op_lam neutral inputs

    level <- mkLevel [width] "stream_red" $ NoRecommendation SegNoVirt

    let kernel_pat = Pattern [] $ patternElements red_pat ++ map_elems
        red_binop = SegBinOp comm red_op_lam neutral mempty
    letBind kernel_pat . Op . SegOp $
      SegRed level seg_space [red_binop] body_ts seg_body

    -- Finish whatever 'dummyDim' set up (presumably reading the
    -- results back out of the dummy dimension -- TODO confirm).
    after_kernel

-- | Similar to 'streamRed', but without the last reduction.
--
-- Emits a 'SegMap'. The first @length nes@ kernel results are bound to
-- fresh arrays, named after @out_desc@, with one row per thread. The
-- rest are bound to the given map-out pattern elements. Returns the
-- thread count and the names of those fresh arrays, along with the
-- generated statements.
streamMap :: (MonadFreshNames m, HasScope Kernels m) =>
              MkSegLevel Kernels m
          -> [String] -> [PatElem Kernels] -> SubExp
           -> Commutativity -> Lambda Kernels -> [SubExp] -> [VName]
           -> m ((SubExp, [VName]), Stms Kernels)
streamMap mkLevel descs map_elems width comm fold_op_lam neutral inputs =
  runBinderT' $ do
    ksize <- blockedKernelSize "stream_map" width

    (num_threads, seg_space, body_ts, seg_body) <-
      prepareStream ksize [] width comm fold_op_lam neutral inputs

    -- One fresh array per reduction result, each with a leading
    -- dimension of size num_threads.
    let freshRowArray desc row_t =
          PatElem <$> newVName desc <*> pure (row_t `arrayOfRow` num_threads)
    red_elems <- zipWithM freshRowArray descs $ take (length neutral) body_ts

    level <- mkLevel [width] "stream_map" $ NoRecommendation SegNoVirt
    letBind (Pattern [] (red_elems ++ map_elems)) . Op . SegOp $
      SegMap level seg_space body_ts seg_body

    pure (num_threads, map patElemName red_elems)

-- | Like 'segThread', but cap the thread count to the input size.
-- This is more efficient for small kernels, e.g. summing a small
-- array.
--
-- The total size is the product of all dimension sizes, computed in
-- Int64. With 'ManyThreads', the group count is that size divided by
-- the group size, rounded up, and the level is never virtualised.
-- With 'NoRecommendation', 'numberOfGroups' picks the group count and
-- the requested virtualisation is kept.
segThreadCapped :: MonadFreshNames m => MkSegLevel Kernels m
segThreadCapped dims desc recommendation = do
  dims64 <- mapM (asIntS Int64) dims
  total64 <-
    letSubExp "nest_size"
      =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) dims64
  gsize <- getSize (desc ++ "_group_size") SizeGroup

  let level groups = SegThread (Count groups) (Count gsize)

  case recommendation of
    NoRecommendation virt -> do
      (groups, _) <- numberOfGroups desc total64 gsize
      pure $ level groups virt
    ManyThreads -> do
      groups64 <-
        letSubExp "segmap_usable_groups_64"
          =<< eBinOp (SDivUp Int64 Unsafe)
                     (eSubExp total64)
                     (eSubExp =<< asIntS Int64 gsize)
      -- The group count is narrowed to Int32 via 'SExt'.
      -- NOTE(review): this relies on SExt to a smaller type truncating;
      -- confirm.
      groups32 <-
        letSubExp "segmap_usable_groups" $
          BasicOp $ ConvOp (SExt Int64 Int32) groups64
      pure $ level groups32 SegNoVirt