{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels.BlockedKernel
       ( MkSegLevel
       , ThreadRecommendation(..)
       , segRed
       , nonSegRed
       , segScan
       , segHist
       , segMap

       , streamRed
       , streamMap

       , mapKernel
       , KernelInput(..)
       , readKernelInput

       , soacsLambdaToKernels
       , soacsStmToKernels
       , scopeForKernels
       , scopeForSOACs

       , getSize
       , segThread
       , segThreadCapped
       , mkSegSpace
       )
       where

import Control.Monad
import Control.Monad.Writer
import Control.Monad.Identity
import Data.List ()

import Prelude hiding (quot)

import Futhark.Analysis.PrimExp
import Futhark.Analysis.Rephrase
import Futhark.Representation.AST
import Futhark.Representation.SOACS (SOACS)
import qualified Futhark.Representation.SOACS.SOAC as SOAC
import Futhark.Representation.Kernels
       hiding (Prog, Body, Stm, Pattern, PatElem,
               BasicOp, Exp, Lambda, FunDef, FParam, LParam, RetType)
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Transform.Rename

-- | Bind a tunable size and return the 'SubExp' that holds it.
--
-- A fresh 'VName' is generated from @desc@; its pretty-printed form
-- becomes the key of the size.  The result of the 'GetSize'
-- operation for that key and the given 'SizeClass' is then bound,
-- again using @desc@ as the name hint.
getSize :: (MonadBinder m, Op (Lore m) ~ HostOp (Lore m) inner) =>
           String -> SizeClass -> m SubExp
getSize desc size_class = do
  size_key <- nameFromString . pretty <$> newVName desc
  letSubExp desc $ Op $ SizeOp $ GetSize size_key size_class

-- | Bind the number of groups and the total number of threads for a
-- kernel.
--
-- The number of groups is produced by a 'CalcNumGroups' size
-- operation applied to @w64@, a fresh key derived from
-- @desc ++ "_num_groups"@, and @group_size@.  The number of threads
-- is the 'Int32' product of that number of groups and @group_size@.
--
-- Returns @(num_groups, num_threads)@.
--
-- NOTE(review): @w64@ is presumably a 64-bit width, as suggested by
-- its name and by the caller in this module - confirm against
-- 'CalcNumGroups'.
numberOfGroups :: (MonadBinder m, Op (Lore m) ~ HostOp (Lore m) inner) =>
                  String -> SubExp -> SubExp -> m (SubExp, SubExp)
numberOfGroups desc w64 group_size = do
  max_num_groups_key <-
    nameFromString . pretty <$> newVName (desc ++ "_num_groups")
  num_groups <- letSubExp "num_groups" $
                Op $ SizeOp $ CalcNumGroups w64 max_num_groups_key group_size
  num_threads <- letSubExp "num_threads" $
                 BasicOp $ BinOp (Mul Int32) num_groups group_size
  return (num_groups, num_threads)

-- | Construct a thread-level 'SegLevel' whose number of groups and
-- group size are both tunable sizes obtained with 'getSize' (named
-- @desc ++ "_num_groups"@ and @desc ++ "_group_size"@ respectively).
-- The resulting level always uses 'SegVirt'.
segThread :: (MonadBinder m, Op (Lore m) ~ HostOp (Lore m) inner) =>
             String -> m SegLevel
segThread desc =
  SegThread
    <$> (Count <$> getSize (desc ++ "_num_groups") SizeNumGroups)
    <*> (Count <$> getSize (desc ++ "_group_size") SizeGroup)
    <*> pure SegVirt

-- | A hint given to a 'MkSegLevel' function about how the level
-- should be sized.
--
-- In 'segThreadCapped', 'ManyThreads' makes the number of groups the
-- nest size divided by the group size (via 'eDivRoundingUp') and uses
-- 'SegNoVirt'; 'NoRecommendation' obtains the number of groups from
-- 'numberOfGroups' and passes the carried 'SegVirt' on to the
-- resulting 'SegThread'.
data ThreadRecommendation = ManyThreads | NoRecommendation SegVirt

-- | The type of functions that construct a 'SegLevel' inside a
-- @'BinderT' 'Kernels' m@ computation.  The arguments are a list of
-- 'SubExp's (in 'segThreadCapped' these are the widths that are
-- multiplied together to form the nest size), a description string
-- used when naming sizes, and a 'ThreadRecommendation'.
type MkSegLevel m =
  [SubExp] -> String -> ThreadRecommendation -> BinderT Kernels m SegLevel

-- | Like 'segThread', but cap the thread count to the input size.
-- This is more efficient for small kernels, e.g. summing a small
-- array.
--
-- The widths @ws@ are each converted with @'asIntS' 'Int64'@ and
-- multiplied together (starting from the constant 1) to give the
-- total nest size.  The group size is a tunable size named
-- @desc ++ "_group_size"@.  How the number of groups is chosen
-- depends on the 'ThreadRecommendation':
--
-- * 'ManyThreads': the nest size divided by the group size using
--   'eDivRoundingUp' at 'Int64', then converted with
--   @'SExt' 'Int64' 'Int32'@.  The level uses 'SegNoVirt'.
--
-- * @'NoRecommendation' v@: the number of groups comes from
--   'numberOfGroups', and @v@ is used as the virtualisation setting.
segThreadCapped :: MonadFreshNames m => MkSegLevel m
segThreadCapped ws desc r = do
  w64 <- letSubExp "nest_size" =<<
         foldBinOp (Mul Int64) (intConst Int64 1) =<<
         mapM (asIntS Int64) ws
  group_size <- getSize (desc ++ "_group_size") SizeGroup

  case r of
    ManyThreads -> do
      -- The division is carried out in 64 bits and the quotient is
      -- bound separately before being converted to 32 bits.
      usable_groups <- letSubExp "segmap_usable_groups" .
                       BasicOp . ConvOp (SExt Int64 Int32) =<<
                       letSubExp "segmap_usable_groups_64" =<<
                       eDivRoundingUp Int64 (eSubExp w64)
                       (eSubExp =<< asIntS Int64 group_size)
      return $ SegThread (Count usable_groups) (Count group_size) SegNoVirt
    NoRecommendation v -> do
      -- The thread count also returned by 'numberOfGroups' is not
      -- needed here.
      (num_groups, _) <- numberOfGroups desc w64 group_size
      return $ SegThread (Count num_groups) (Count group_size) v

-- | Construct a 'SegSpace' from the given (index variable, size)
-- pairs.  The first field of the 'SegSpace' is a freshly generated
-- name based on @"phys_tid"@.
mkSegSpace :: MonadFreshNames m => [(VName, SubExp)] -> m SegSpace
mkSegSpace dims = SegSpace <$> newVName "phys_tid" <*> pure dims

-- | Given a chunked fold lambda that takes its initial accumulator
-- value as parameters, bind those parameters to the neutral element
-- instead.
--
-- The parameters of the lambda are split with
-- 'partitionChunkedFoldParameters' using @length nes@ as the number
-- of accumulators.  The resulting lambda has a fresh 'int32'
-- @thread_index@ parameter first, followed by the chunk parameter and
-- the input parameters; the accumulator parameters are removed from
-- the parameter list and are instead bound at the start of the body,
-- one per element of @nes@.
kerneliseLambda :: MonadFreshNames m =>
                   [SubExp] -> Lambda Kernels -> m (Lambda Kernels)
kerneliseLambda nes lam = do
  thread_index <- newVName "thread_index"
  let thread_index_param = Param thread_index $ Prim int32
      (fold_chunk_param, fold_acc_params, fold_inp_params) =
        partitionChunkedFoldParameters (length nes) $ lambdaParams lam

      -- A neutral element that is a variable of non-primitive type is
      -- bound through 'Copy'; anything else is bound directly as a
      -- 'SubExp'.
      mkAccInit p (Var v)
        | not $ primType $ paramType p =
            mkLet [] [paramIdent p] $ BasicOp $ Copy v
      mkAccInit p x = mkLet [] [paramIdent p] $ BasicOp $ SubExp x
      acc_init_bnds = stmsFromList $ zipWith mkAccInit fold_acc_params nes
  return lam { lambdaBody = insertStms acc_init_bnds $
                            lambdaBody lam
             , lambdaParams = thread_index_param :
                              fold_chunk_param :
                              fold_inp_params
             }

-- | Build the 'SegSpace' and 'KernelBody' shared by 'segRed',
-- 'segScan' and 'segMap'.
--
-- The space is @ispace@ extended with one extra innermost dimension:
-- a fresh @gtid@ of size @w@.  The kernel body, constructed with that
-- space in scope, does the following in order:
--
-- 1. reads every 'KernelInput' in @inps@ with 'readKernelInput';
--
-- 2. binds each parameter of @map_lam@ to the corresponding array in
--    @arrs@ indexed at @gtid@ in its outermost dimension (parameters
--    and arrays are paired with 'zip', so surplus elements on either
--    side are ignored);
--
-- 3. inlines the body of @map_lam@ and returns each of its results as
--    @'Returns' 'ResultMaySimplify'@.
prepareRedOrScan :: (MonadBinder m, Lore m ~ Kernels) =>
                    SubExp
                 -> Lambda Kernels
                 -> [VName] -> [(VName, SubExp)] -> [KernelInput]
                 -> m (SegSpace, KernelBody Kernels)
prepareRedOrScan w map_lam arrs ispace inps = do
  gtid <- newVName "gtid"
  space <- mkSegSpace $ ispace ++ [(gtid, w)]
  (results, stms) <- runBinder $ localScope (scopeOfSegSpace space) $ do
    mapM_ readKernelInput inps
    forM_ (zip (lambdaParams map_lam) arrs) $ \(p, arr) -> do
      arr_t <- lookupType arr
      letBindNames_ [paramName p] $
        BasicOp $ Index arr $ fullSlice arr_t [DimFix $ Var gtid]
    map (Returns ResultMaySimplify) <$> bodyBind (lambdaBody map_lam)

  return (space, KernelBody () stms results)

-- | Produce the statements that bind @pat@ to a 'SegRed' at level
-- @lvl@.  The space and kernel body come from 'prepareRedOrScan';
-- the result types of the 'SegRed' are the return types of @map_lam@.
segRed :: (MonadFreshNames m, HasScope Kernels m) =>
          SegLevel
       -> Pattern Kernels
       -> SubExp -- segment size
       -> [SegRedOp Kernels]
       -> Lambda Kernels
       -> [VName]
       -> [(VName, SubExp)] -- ispace = pair of (gtid, size) for the maps on "top" of this reduction
       -> [KernelInput]     -- inps = inputs that can be looked up by using the gtids from ispace
       -> m (Stms Kernels)
segRed lvl pat w ops map_lam arrs ispace inps = runBinder_ $ do
  (kspace, kbody) <- prepareRedOrScan w map_lam arrs ispace inps
  letBind_ pat $ Op $ SegOp $
    SegRed lvl kspace ops (lambdaReturnType map_lam) kbody

-- | Produce the statements that bind @pat@ to a 'SegScan' at level
-- @lvl@, with scan operator @scan_lam@ and neutral elements @nes@.
-- The space and kernel body come from 'prepareRedOrScan'; the result
-- types of the 'SegScan' are the return types of @map_lam@.
segScan :: (MonadFreshNames m, HasScope Kernels m) =>
           SegLevel
        -> Pattern Kernels
        -> SubExp -- segment size
        -> Lambda Kernels -> Lambda Kernels
        -> [SubExp] -> [VName]
        -> [(VName, SubExp)] -- ispace = pair of (gtid, size) for the maps on "top" of this scan
        -> [KernelInput]     -- inps = inputs that can be looked up by using the gtids from ispace
        -> m (Stms Kernels)
segScan lvl pat w scan_lam map_lam nes arrs ispace inps = runBinder_ $ do
  (kspace, kbody) <- prepareRedOrScan w map_lam arrs ispace inps
  letBind_ pat $ Op $ SegOp $
    SegScan lvl kspace scan_lam nes (lambdaReturnType map_lam) kbody

-- | Produce the statements that bind @pat@ to a 'SegMap' at level
-- @lvl@.  The space and kernel body come from 'prepareRedOrScan';
-- the result types of the 'SegMap' are the return types of @map_lam@.
segMap :: (MonadFreshNames m, HasScope Kernels m) =>
          SegLevel
       -> Pattern Kernels
       -> SubExp -- segment size
       -> Lambda Kernels
       -> [VName]
       -> [(VName, SubExp)] -- ispace = pair of (gtid, size) for the maps on "top" of this map
       -> [KernelInput]     -- inps = inputs that can be looked up by using the gtids from ispace
       -> m (Stms Kernels)
segMap lvl pat w map_lam arrs ispace inps = runBinder_ $ do
  (kspace, kbody) <- prepareRedOrScan w map_lam arrs ispace inps
  letBind_ pat $ Op $ SegOp $
    SegMap lvl kspace (lambdaReturnType map_lam) kbody

-- | Wrap a pattern in a unit-size outer dimension.
--
-- Returns a triple consisting of:
--
-- * a renamed copy of the pattern in which every element type has
--   been given an extra outer dimension of size 1 (via 'arrayOfRow');
--
-- * a one-element index space pairing a fresh name (@"dummy"@) with
--   the constant size 1;
--
-- * an action that, for every (renamed, original) pair of pattern
--   names, binds the original name to the renamed array indexed at 0
--   in its outermost dimension.
dummyDim :: (MonadFreshNames m, MonadBinder m) =>
            Pattern Kernels
         -> m (Pattern Kernels, [(VName, SubExp)], m ())
dummyDim pat = do
  -- We add a unit-size segment on top to ensure that the result
  -- of the SegRed is an array, which we then immediately index.
  -- This is useful in the case that the value is used on the
  -- device afterwards, as this may save an expensive
  -- host-device copy (scalars are kept on the host, but arrays
  -- may be on the device).
  let addDummyDim t = t `arrayOfRow` intConst Int32 1
  pat' <- fmap addDummyDim <$> renamePattern pat
  dummy <- newVName "dummy"
  let ispace = [(dummy, intConst Int32 1)]

  return (pat', ispace,
          -- Read element 0 of each widened result back into the name
          -- the caller originally asked for.
          forM_ (zip (patternNames pat') (patternNames pat)) $ \(from, to) -> do
             from_t <- lookupType from
             letBindNames_ [to] $ BasicOp $ Index from $
               fullSlice from_t [DimFix $ intConst Int32 0])

-- | Build the statements for a non-segmented reduction.
--
-- The pattern is first widened with 'dummyDim'.  The statements
-- produced by 'segRed' for the widened pattern (using the one-element
-- index space from 'dummyDim' and no kernel inputs) are added, and
-- finally the read-back action from 'dummyDim' is run so that the
-- names of the original pattern are bound.
nonSegRed :: (MonadFreshNames m, HasScope Kernels m) =>
             SegLevel
          -> Pattern Kernels
          -> SubExp
          -> [SegRedOp Kernels]
          -> Lambda Kernels
          -> [VName]
          -> m (Stms Kernels)
nonSegRed lvl pat w ops map_lam arrs = runBinder_ $ do
  (pat', ispace, read_dummy) <- dummyDim pat
  addStms =<<
    segRed lvl pat' w ops map_lam arrs ispace []
  read_dummy

-- | Shared set-up for 'streamRed' and 'streamMap'.
--
-- Given a kernel size, an enclosing index space, the input width, the
-- commutativity of the operator, the fold lambda, the neutral
-- elements and the input arrays, this returns a tuple of:
--
-- * the number of threads taken from the 'KernelSize';
--
-- * a segment space made of @ispace@ followed by a fresh @gtid@
--   ranging over the number of threads;
--
-- * the result types: the first @length nes@ return types of
--   @fold_lam@ unchanged, followed by the row types of the remaining
--   ones;
--
-- * a kernel body whose statements come from 'blockedPerThread'.  The
--   reduction results are returned with 'Returns' and the map results
--   with 'ConcatReturns'.
prepareStream :: (MonadBinder m, Lore m ~ Kernels) =>
                 KernelSize
              -> [(VName, SubExp)]
              -> SubExp
              -> Commutativity
              -> Lambda Kernels
              -> [SubExp]
              -> [VName]
              -> m (SubExp, SegSpace, [Type], KernelBody Kernels)
prepareStream size ispace w comm fold_lam nes arrs = do
  let (KernelSize elems_per_thread num_threads) = size
  -- The commutativity selects both the stream ordering passed to
  -- 'blockedPerThread' and the split ordering used by 'ConcatReturns'.
  let (ordering, split_ordering) =
        case comm of
          Commutative    -> (Disorder, SplitStrided num_threads)
          Noncommutative -> (InOrder, SplitContiguous)

  fold_lam' <- kerneliseLambda nes fold_lam

  elems_per_thread_32 <- asIntS Int32 elems_per_thread

  gtid <- newVName "gtid"
  space <- mkSegSpace $ ispace ++ [(gtid, num_threads)]
  -- 'runBinder' yields (results, statements); 'KernelBody' wants the
  -- statements first, hence the flip/uncurry.
  kbody <- fmap (uncurry (flip (KernelBody ()))) $ runBinder $
           localScope (scopeOfSegSpace space) $ do
    (chunk_red_pes, chunk_map_pes) <-
      blockedPerThread gtid w size ordering fold_lam' (length nes) arrs
    let concatReturns pe =
          ConcatReturns split_ordering w elems_per_thread_32 $ patElemName pe
    return (map (Returns ResultMaySimplify . Var . patElemName) chunk_red_pes ++
            map concatReturns chunk_map_pes)

  let (redout_ts, mapout_ts) = splitAt (length nes) $ lambdaReturnType fold_lam
      ts = redout_ts ++ map rowType mapout_ts

  return (num_threads, space, ts, kbody)

-- | Build the statements for a stream reduction.
--
-- The first @length nes@ elements of the pattern are treated as the
-- reduction results and the rest as the map results.  The reduction
-- part of the pattern is widened with 'dummyDim'; the kernel space,
-- result types and body come from 'prepareStream'; and the whole
-- thing is bound as a single 'SegRed' with one 'SegRedOp' built from
-- @comm@, @red_lam@ and @nes@.  Afterwards the read-back action from
-- 'dummyDim' binds the original reduction result names.
streamRed :: (MonadFreshNames m, HasScope Kernels m) =>
             Pattern Kernels
          -> SubExp
          -> Commutativity
          -> Lambda Kernels -> Lambda Kernels
          -> [SubExp] -> [VName]
          -> m (Stms Kernels)
streamRed pat w comm red_lam fold_lam nes arrs = runBinder_ $ do
  -- The strategy here is to rephrase the stream reduction as a
  -- non-segmented SegRed that does explicit chunking within its body.
  -- First, figure out how many threads to use for this.
  size <- blockedKernelSize "stream_red" w

  let (redout_pes, mapout_pes) = splitAt (length nes) $ patternElements pat
  (redout_pat, ispace, read_dummy) <- dummyDim $ Pattern [] redout_pes
  let pat' = Pattern [] $ patternElements redout_pat ++ mapout_pes

  -- The thread count returned by 'prepareStream' is not needed here.
  (_, kspace, ts, kbody) <- prepareStream size ispace w comm fold_lam nes arrs

  lvl <- segThreadCapped [w] "stream_red" $ NoRecommendation SegNoVirt
  letBind_ pat' $ Op $ SegOp $ SegRed lvl kspace
    [SegRedOp comm red_lam nes mempty] ts kbody

  read_dummy

-- | Similar to 'streamRed', but without the last reduction.
--
-- The kernel space, result types and body come from 'prepareStream'
-- (with an empty enclosing index space).  For each description in
-- @out_desc@ paired with one of the first @length nes@ result types,
-- a fresh pattern element is created whose type is that result type
-- with an extra outer dimension of size @threads@ (the first
-- component returned by 'prepareStream').  These elements, followed
-- by @mapout_pes@, form the pattern that the 'SegMap' is bound to.
--
-- Returns @threads@ together with the names of the freshly created
-- pattern elements, plus the statements produced.
streamMap :: (MonadFreshNames m, HasScope Kernels m) =>
              [String] -> [PatElem Kernels] -> SubExp
           -> Commutativity -> Lambda Kernels -> [SubExp] -> [VName]
           -> m ((SubExp, [VName]), Stms Kernels)
streamMap out_desc mapout_pes w comm fold_lam nes arrs = runBinder $ do
  size <- blockedKernelSize "stream_map" w

  (threads, kspace, ts, kbody) <- prepareStream size [] w comm fold_lam nes arrs

  let redout_ts = take (length nes) ts

  redout_pes <- forM (zip out_desc redout_ts) $ \(desc, t) ->
    PatElem <$> newVName desc <*> pure (t `arrayOfRow` threads)

  let pat = Pattern [] $ redout_pes ++ mapout_pes
  lvl <- segThreadCapped [w] "stream_map" $ NoRecommendation SegNoVirt
  letBind_ pat $ Op $ SegOp $ SegMap lvl kspace ts kbody

  return (threads, map patElemName redout_pes)

segHist :: (MonadFreshNames m, HasScope Kernels m) =>
             SegLevel
          -> Pattern Kernels
          -> SubExp
          -> [(VName,SubExp)] -- ^ Segment indexes and sizes.
          -> [KernelInput]
          -> [HistOp Kernels]
          -> Lambda Kernels -> [VName]
          -> m (Stms Kernels)
segHist :: SegLevel
-> Pattern Kernels
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp Kernels]
-> Lambda Kernels
-> [VName]
-> m (Stms Kernels)
segHist SegLevel
lvl Pattern Kernels
pat SubExp
arr_w [(VName, SubExp)]
ispace [KernelInput]
inps [HistOp Kernels]
ops Lambda Kernels
lam [VName]
arrs = BinderT Kernels (State VNameSource) () -> m (Stms Kernels)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore a -> m (Stms lore)
runBinder_ (BinderT Kernels (State VNameSource) () -> m (Stms Kernels))
-> BinderT Kernels (State VNameSource) () -> m (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
  VName
gtid <- String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"gtid"
  SegSpace
space <- [(VName, SubExp)] -> BinderT Kernels (State VNameSource) SegSpace
forall (m :: * -> *).
MonadFreshNames m =>
[(VName, SubExp)] -> m SegSpace
mkSegSpace ([(VName, SubExp)] -> BinderT Kernels (State VNameSource) SegSpace)
-> [(VName, SubExp)]
-> BinderT Kernels (State VNameSource) SegSpace
forall a b. (a -> b) -> a -> b
$ [(VName, SubExp)]
ispace [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
gtid, SubExp
arr_w)]

  KernelBody Kernels
kbody <- (([KernelResult], Stms Kernels) -> KernelBody Kernels)
-> BinderT
     Kernels (State VNameSource) ([KernelResult], Stms Kernels)
-> BinderT Kernels (State VNameSource) (KernelBody Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (([KernelResult] -> Stms Kernels -> KernelBody Kernels)
-> ([KernelResult], Stms Kernels) -> KernelBody Kernels
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry ((Stms Kernels -> [KernelResult] -> KernelBody Kernels)
-> [KernelResult] -> Stms Kernels -> KernelBody Kernels
forall a b c. (a -> b -> c) -> b -> a -> c
flip ((Stms Kernels -> [KernelResult] -> KernelBody Kernels)
 -> [KernelResult] -> Stms Kernels -> KernelBody Kernels)
-> (Stms Kernels -> [KernelResult] -> KernelBody Kernels)
-> [KernelResult]
-> Stms Kernels
-> KernelBody Kernels
forall a b. (a -> b) -> a -> b
$ BodyAttr Kernels
-> Stms Kernels -> [KernelResult] -> KernelBody Kernels
forall lore.
BodyAttr lore -> Stms lore -> [KernelResult] -> KernelBody lore
KernelBody ())) (BinderT Kernels (State VNameSource) ([KernelResult], Stms Kernels)
 -> BinderT Kernels (State VNameSource) (KernelBody Kernels))
-> BinderT
     Kernels (State VNameSource) ([KernelResult], Stms Kernels)
-> BinderT Kernels (State VNameSource) (KernelBody Kernels)
forall a b. (a -> b) -> a -> b
$ Binder Kernels [KernelResult]
-> BinderT
     Kernels (State VNameSource) ([KernelResult], Stms Kernels)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore a -> m (a, Stms lore)
runBinder (Binder Kernels [KernelResult]
 -> BinderT
      Kernels (State VNameSource) ([KernelResult], Stms Kernels))
-> Binder Kernels [KernelResult]
-> BinderT
     Kernels (State VNameSource) ([KernelResult], Stms Kernels)
forall a b. (a -> b) -> a -> b
$
           Scope Kernels
-> Binder Kernels [KernelResult] -> Binder Kernels [KernelResult]
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (SegSpace -> Scope Kernels
forall lore. SegSpace -> Scope lore
scopeOfSegSpace SegSpace
space) (Binder Kernels [KernelResult] -> Binder Kernels [KernelResult])
-> Binder Kernels [KernelResult] -> Binder Kernels [KernelResult]
forall a b. (a -> b) -> a -> b
$ do
    (KernelInput -> BinderT Kernels (State VNameSource) ())
-> [KernelInput] -> BinderT Kernels (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ KernelInput -> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
(MonadBinder m, Lore m ~ Kernels) =>
KernelInput -> m ()
readKernelInput [KernelInput]
inps
    [(Param (TypeBase Shape NoUniqueness), VName)]
-> ((Param (TypeBase Shape NoUniqueness), VName)
    -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (TypeBase Shape NoUniqueness)]
-> [VName] -> [(Param (TypeBase Shape NoUniqueness), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda Kernels -> [LParam Kernels]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda Kernels
lam) [VName]
arrs) (((Param (TypeBase Shape NoUniqueness), VName)
  -> BinderT Kernels (State VNameSource) ())
 -> BinderT Kernels (State VNameSource) ())
-> ((Param (TypeBase Shape NoUniqueness), VName)
    -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(Param (TypeBase Shape NoUniqueness)
p, VName
arr) -> do
      TypeBase Shape NoUniqueness
arr_t <- VName
-> BinderT
     Kernels (State VNameSource) (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
arr
      [VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames_ [Param (TypeBase Shape NoUniqueness) -> VName
forall attr. Param attr -> VName
paramName Param (TypeBase Shape NoUniqueness)
p] (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT Kernels (State VNameSource) ())
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
        BasicOp Kernels -> ExpT Kernels
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp Kernels -> ExpT Kernels)
-> BasicOp Kernels -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp Kernels
forall lore. VName -> Slice SubExp -> BasicOp lore
Index VName
arr (Slice SubExp -> BasicOp Kernels)
-> Slice SubExp -> BasicOp Kernels
forall a b. (a -> b) -> a -> b
$ TypeBase Shape NoUniqueness -> Slice SubExp -> Slice SubExp
fullSlice TypeBase Shape NoUniqueness
arr_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
gtid]
    (SubExp -> KernelResult) -> [SubExp] -> [KernelResult]
forall a b. (a -> b) -> [a] -> [b]
map (ResultManifest -> SubExp -> KernelResult
Returns ResultManifest
ResultMaySimplify) ([SubExp] -> [KernelResult])
-> BinderT Kernels (State VNameSource) [SubExp]
-> Binder Kernels [KernelResult]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Body (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) [SubExp]
forall (m :: * -> *). MonadBinder m => Body (Lore m) -> m [SubExp]
bodyBind (Lambda Kernels -> BodyT Kernels
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda Kernels
lam)

  Pattern (Lore (BinderT Kernels (State VNameSource)))
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ Pattern (Lore (BinderT Kernels (State VNameSource)))
Pattern Kernels
pat (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT Kernels (State VNameSource) ())
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Op Kernels -> ExpT Kernels
forall lore. Op lore -> ExpT lore
Op (Op Kernels -> ExpT Kernels) -> Op Kernels -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ SegOp Kernels -> HostOp Kernels (SOAC Kernels)
forall lore op. SegOp lore -> HostOp lore op
SegOp (SegOp Kernels -> HostOp Kernels (SOAC Kernels))
-> SegOp Kernels -> HostOp Kernels (SOAC Kernels)
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [HistOp Kernels]
-> [TypeBase Shape NoUniqueness]
-> KernelBody Kernels
-> SegOp Kernels
forall lore.
SegLevel
-> SegSpace
-> [HistOp lore]
-> [TypeBase Shape NoUniqueness]
-> KernelBody lore
-> SegOp lore
SegHist SegLevel
lvl SegSpace
space [HistOp Kernels]
ops (Lambda Kernels -> [TypeBase Shape NoUniqueness]
forall lore. LambdaT lore -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda Kernels
lam) KernelBody Kernels
kbody

-- | Emit the per-thread part of a chunked kernel.  The arrays @arrs@
-- are split so that the array parameters of @lam@ are bound to the
-- chunk belonging to thread @thread_gtid@, after which the statements
-- of the body of @lam@ are inserted, and fresh pattern elements are
-- bound to its results.
--
-- The first @num_nonconcat@ results of @lam@ are returned as the
-- first component of the pair, the remaining results as the second.
-- The latter are given array types whose outer size is the chunk
-- size parameter of @lam@.
blockedPerThread :: (MonadBinder m, Lore m ~ Kernels) =>
                    VName -> SubExp -> KernelSize -> StreamOrd -> Lambda Kernels
                 -> Int -> [VName]
                 -> m ([PatElem Kernels], [PatElem Kernels])
blockedPerThread thread_gtid w kernel_size ordering lam num_nonconcat arrs = do
  -- We ask for zero accumulator parameters, and the pattern requires
  -- that none are produced.
  let (_, chunk_size, [], arr_params) =
        partitionChunkedKernelFoldParameters 0 $ lambdaParams lam

      ordering' =
        case ordering of
          InOrder  -> SplitContiguous
          Disorder -> SplitStrided $ kernelNumThreads kernel_size

      -- Split the return types exactly as the body results are split
      -- further down.
      (red_ts, map_arr_ts) = splitAt num_nonconcat $ lambdaReturnType lam
      map_ts = map rowType map_arr_ts

  per_thread <- asIntS Int32 $ kernelElementsPerThread kernel_size
  splitArrays (paramName chunk_size) (map paramName arr_params) ordering' w
    (Var thread_gtid) per_thread arrs

  chunk_red_pes <- forM red_ts $ \red_t -> do
    pe_name <- newVName "chunk_fold_red"
    return $ PatElem pe_name red_t
  chunk_map_pes <- forM map_ts $ \map_t -> do
    pe_name <- newVName "chunk_fold_map"
    return $ PatElem pe_name $ map_t `arrayOfRow` Var (paramName chunk_size)

  let (chunk_red_ses, chunk_map_ses) =
        splitAt num_nonconcat $ bodyResult $ lambdaBody lam

  -- First the body of the lambda, then one copy statement per result.
  addStms $
    bodyStms (lambdaBody lam) <>
    stmsFromList
    [ Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se
    | (pe, se) <- zip chunk_red_pes chunk_red_ses ] <>
    stmsFromList
    [ Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se
    | (pe, se) <- zip chunk_map_pes chunk_map_ses ]

  return (chunk_red_pes, chunk_map_pes)

-- | @splitArrays chunk_size split_bound ordering w i elems_per_i arrs@
-- binds @chunk_size@ to the result of a 'SplitSpace' size operation,
-- and then binds each name in @split_bound@ to a slice of the
-- corresponding array in @arrs@.  The slice covers @chunk_size@
-- elements of the outermost dimension:
--
-- * For 'SplitContiguous', it starts at @i * elems_per_i@ and has
--   stride 1.
--
-- * For @'SplitStrided' stride@, it starts at @i@ and has the given
--   stride.
splitArrays :: (MonadBinder m, Lore m ~ Kernels) =>
               VName -> [VName]
            -> SplitOrdering -> SubExp -> SubExp -> SubExp -> [VName]
            -> m ()
splitArrays chunk_size split_bound ordering w i elems_per_i arrs = do
  letBindNames_ [chunk_size] $ Op $ SizeOp $ SplitSpace ordering w i elems_per_i
  case ordering of
    SplitContiguous     -> do
      offset <- letSubExp "slice_offset" $ BasicOp $ BinOp (Mul Int32) i elems_per_i
      zipWithM_ (chunkSlice offset (constant (1::Int32))) split_bound arrs
    SplitStrided stride ->
      zipWithM_ (chunkSlice i stride) split_bound arrs
  where -- Bind 'slice_name' to the 'chunk_size' elements of 'arr' that
        -- begin at 'start' and are 'stride' apart in the outermost
        -- dimension.  Deliberately left without a type signature, as
        -- it is monomorphic in the monad of the enclosing function.
        chunkSlice start stride slice_name arr = do
          arr_t <- lookupType arr
          letBindNames_ [slice_name] $ BasicOp $ Index arr $
            fullSlice arr_t [DimSlice start (Var chunk_size) stride]

-- | The sizes describing how the input of a chunked kernel is divided
-- among its threads, as computed by 'blockedKernelSize'.
data KernelSize = KernelSize { kernelElementsPerThread :: SubExp
                               -- ^ Int64
                             , kernelNumThreads :: SubExp
                               -- ^ Int32
                             }
                deriving (Eq, Ord, Show)

-- | Compute the 'KernelSize' for a chunked kernel over @w@ elements.
-- The string @desc@ is used to name the group size that is requested
-- (as @desc ++ "_group_size"@) and is also passed on to
-- 'numberOfGroups'.
--
-- The number of elements per thread is @w@ divided by the number of
-- threads, rounding up, computed in 64-bit arithmetic.
blockedKernelSize :: (MonadBinder m, Lore m ~ Kernels) =>
                     String -> SubExp -> m KernelSize
blockedKernelSize desc w = do
  group_size <- getSize (desc ++ "_group_size") SizeGroup

  -- Sign-extend the 32-bit width to 64 bits before using it below.
  w64 <- letSubExp "w64" $ BasicOp $ ConvOp (SExt Int32 Int64) w
  (_, num_threads) <- numberOfGroups desc w64 group_size

  per_thread_elements <-
    letSubExp "per_thread_elements" =<<
    eDivRoundingUp Int64 (eSubExp w64) (toExp =<< asIntS Int64 num_threads)

  return $ KernelSize per_thread_elements num_threads

-- | Construct the parts shared by map-like kernels: the statements
-- that read the given kernel inputs (via 'readKernelInput'), and a
-- 'SegSpace' for the given iteration space.
mapKernelSkeleton :: (HasScope Kernels m, MonadFreshNames m) =>
                     [(VName, SubExp)] -> [KernelInput]
                  -> m (SegSpace, Stms Kernels)
mapKernelSkeleton ispace inputs = do
  -- Only the statements are of interest, so do not collect the unit
  -- results of the individual reads.
  read_input_bnds <- runBinder_ $ mapM_ readKernelInput inputs

  space <- mkSegSpace ispace
  return (space, read_input_bnds)

-- | Construct a 'SegMap' over the given iteration space.  The
-- statements reading the kernel inputs are placed in front of the
-- statements of the provided kernel body.  The 'SegLevel' is obtained
-- from the given 'MkSegLevel' function, and any statements produced
-- along the way are returned alongside the 'SegOp'.
mapKernel :: (HasScope Kernels m, MonadFreshNames m) =>
             MkSegLevel m
          -> [(VName, SubExp)] -> [KernelInput]
          -> [Type] -> KernelBody Kernels
          -> m (SegOp Kernels, Stms Kernels)
mapKernel mk_lvl ispace inputs rts (KernelBody () kstms krets) = runBinderT' $ do
  (space, read_input_stms) <- mapKernelSkeleton ispace inputs

  let kbody' = KernelBody () (read_input_stms <> kstms) krets

      -- If the kernel creates arrays (meaning it will require memory
      -- expansion), we want to truncate the number of threads.
      -- Otherwise, have at it!  This is a bit of a hack - in
      -- principle, we should make this decision later, when we have a
      -- clearer idea of what is happening inside the kernel.
      r = if all primType rts
          then ManyThreads
          else NoRecommendation SegVirt

  lvl <- mk_lvl (map snd ispace) "segmap" r

  return $ SegMap lvl space rts kbody'

-- | A value that a kernel body obtains by reading from an array; see
-- 'readKernelInput' for how it is turned into a statement.
data KernelInput = KernelInput { kernelInputName :: VName
                                 -- ^ The name bound to the value that
                                 -- is read.
                               , kernelInputType :: Type
                                 -- ^ The type given to the bound name.
                               , kernelInputArray :: VName
                                 -- ^ The array that is read from.
                               , kernelInputIndices :: [SubExp]
                                 -- ^ Presumably the indices at which
                                 -- the array is read - confirm
                                 -- against 'readKernelInput'.
                               }
                 deriving (Show)

-- | Add a statement that binds 'kernelInputName' (annotated with
-- 'kernelInputType') to the result of indexing 'kernelInputArray' at
-- 'kernelInputIndices'.
--
-- The type of the array is looked up in the current scope with
-- 'lookupType', so the array must be in scope when this is run.
readKernelInput :: (MonadBinder m, Lore m ~ Kernels) =>
                   KernelInput -> m ()
readKernelInput inp = do
  let pe = PatElem (kernelInputName inp) $ kernelInputType inp
  arr_t <- lookupType $ kernelInputArray inp
  -- Every index becomes a 'DimFix'.  NOTE(review): 'fullSlice' is
  -- presumably what pads the slice to cover any remaining dimensions
  -- of the array - confirm against its definition.
  letBind_ (Pattern [] [pe]) $
    BasicOp $ Index (kernelInputArray inp) $
    fullSlice arr_t $ map DimFix $ kernelInputIndices inp

-- | Construct a 'Rephraser' from the lore @from@ (whose operations are
-- SOACs) to the lore @to@.  All annotations and types are passed
-- through unchanged with 'return'.  Operations are converted by
-- traversing the SOAC with 'SOAC.mapSOACM', rephrasing any lambdas
-- inside it with this same rephraser, and then applying the given
-- function to embed the resulting @SOAC to@ as an @Op to@.
injectSOACS :: (Monad m,
                SameScope from to,
                ExpAttr from ~ ExpAttr to,
                BodyAttr from ~ BodyAttr to,
                RetType from ~ RetType to,
                BranchType from ~ BranchType to,
                Op from ~ SOAC from) =>
               (SOAC to -> Op to) -> Rephraser m from to
injectSOACS f = Rephraser { rephraseExpLore = return
                          , rephraseBodyLore = return
                          , rephraseLetBoundLore = return
                          , rephraseFParamLore = return
                          , rephraseLParamLore = return
                          , rephraseOp = fmap f . onSOAC
                          , rephraseRetType = return
                          , rephraseBranchType = return
                          }
  where -- Convert a SOAC of the source lore into one of the target lore.
        onSOAC = SOAC.mapSOACM mapper
        -- Sub-expressions and names are left alone; only lambdas need
        -- work, and they are handled by recursing with the same 'f'.
        mapper = SOAC.SOACMapper { SOAC.mapOnSOACSubExp = return
                                 , SOAC.mapOnSOACVName = return
                                 , SOAC.mapOnSOACLambda = rephraseLambda $ injectSOACS f
                                 }

-- | Convert a statement in the 'SOACS' representation to the 'Kernels'
-- representation.  The statement is rephrased with 'injectSOACS',
-- which wraps every SOAC operation in 'OtherOp'.  The conversion is
-- pure; it is run in the 'Identity' monad.
soacsStmToKernels :: Stm SOACS -> Stm Kernels
soacsStmToKernels = runIdentity . rephraseStm (injectSOACS OtherOp)

-- | Convert a lambda in the 'SOACS' representation to the 'Kernels'
-- representation.  The lambda is rephrased with 'injectSOACS', which
-- wraps every SOAC operation in 'OtherOp'.  The conversion is pure; it
-- is run in the 'Identity' monad.
soacsLambdaToKernels :: Lambda SOACS -> Lambda Kernels
soacsLambdaToKernels = runIdentity . rephraseLambda (injectSOACS OtherOp)

-- | Reinterpret a 'Kernels' scope as a 'SOACS' scope, using
-- 'castScope'.
scopeForSOACs :: Scope Kernels -> Scope SOACS
scopeForSOACs = castScope

-- | Reinterpret a 'SOACS' scope as a 'Kernels' scope, using
-- 'castScope'.
scopeForKernels :: Scope SOACS -> Scope Kernels
scopeForKernels = castScope