{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.GPU.Group
( sKernelGroup,
compileGroupResult,
groupOperations,
Precomputed,
precomputeConstants,
precomputedConstants,
atomicUpdateLocking,
)
where
import Control.Monad.Except
import Data.Bifunctor
import Data.List (partition, zip4)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.Construct (fullSliceNum)
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.MonadFreshNames
import Futhark.Transform.Rename
import Futhark.Util (chunks, mapAccumLM, takeLast)
import Futhark.Util.IntegralExp (divUp, rem)
import Prelude hiding (quot, rem)
flattenArray :: Int -> TV Int64 -> VName -> ImpM rep r op VName
flattenArray :: forall {k} (rep :: k) r op.
Int -> TV Int64 -> VName -> ImpM rep r op VName
flattenArray Int
k TV Int64
flat VName
arr = do
ArrayEntry MemLoc
arr_loc PrimType
pt <- forall {k} (rep :: k) r op. VName -> ImpM rep r op ArrayEntry
lookupArray VName
arr
let flat_shape :: ShapeBase SubExp
flat_shape = forall d. [d] -> ShapeBase d
Shape forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (forall {k} (t :: k). TV t -> VName
tvVar TV Int64
flat) forall a. a -> [a] -> [a]
: forall a. Int -> [a] -> [a]
drop Int
k (MemLoc -> [SubExp]
memLocShape MemLoc
arr_loc)
forall {k} (rep :: k) r op.
[Char]
-> PrimType
-> ShapeBase SubExp
-> VName
-> IxFun
-> ImpM rep r op VName
sArray (VName -> [Char]
baseString VName
arr forall a. [a] -> [a] -> [a]
++ [Char]
"_flat") PrimType
pt ShapeBase SubExp
flat_shape (MemLoc -> VName
memLocName MemLoc
arr_loc) forall a b. (a -> b) -> a -> b
$
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> Shape num -> IxFun num
IxFun.reshape (MemLoc -> IxFun
memLocIxFun MemLoc
arr_loc) forall a b. (a -> b) -> a -> b
$
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 forall a b. (a -> b) -> a -> b
$
forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
flat_shape
sliceArray :: Imp.TExp Int64 -> TV Int64 -> VName -> ImpM rep r op VName
sliceArray :: forall {k} (rep :: k) r op.
TPrimExp Int64 VName -> TV Int64 -> VName -> ImpM rep r op VName
sliceArray TPrimExp Int64 VName
start TV Int64
size VName
arr = do
MemLoc VName
mem [SubExp]
_ IxFun
ixfun <- ArrayEntry -> MemLoc
entryArrayLoc forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. VName -> ImpM rep r op ArrayEntry
lookupArray VName
arr
TypeBase (ShapeBase SubExp) NoUniqueness
arr_t <- forall {k} (rep :: k) (m :: * -> *).
HasScope rep m =>
VName -> m (TypeBase (ShapeBase SubExp) NoUniqueness)
lookupType VName
arr
let slice :: Slice (TPrimExp Int64 VName)
slice =
forall d. Num d => [d] -> [DimIndex d] -> Slice d
fullSliceNum
(forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
Imp.pe64 (forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims TypeBase (ShapeBase SubExp) NoUniqueness
arr_t))
[forall d. d -> d -> d -> DimIndex d
DimSlice TPrimExp Int64 VName
start (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
size) TPrimExp Int64 VName
1]
forall {k} (rep :: k) r op.
[Char]
-> PrimType
-> ShapeBase SubExp
-> VName
-> IxFun
-> ImpM rep r op VName
sArray
(VName -> [Char]
baseString VName
arr forall a. [a] -> [a] -> [a]
++ [Char]
"_chunk")
(forall shape u. TypeBase shape u -> PrimType
elemType TypeBase (ShapeBase SubExp) NoUniqueness
arr_t)
(forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape TypeBase (ShapeBase SubExp) NoUniqueness
arr_t forall d. ShapeBase d -> d -> ShapeBase d
`setOuterDim` VName -> SubExp
Var (forall {k} (t :: k). TV t -> VName
tvVar TV Int64
size))
VName
mem
forall a b. (a -> b) -> a -> b
$ forall num.
(Eq num, IntegralExp num) =>
IxFun num -> Slice num -> IxFun num
IxFun.slice IxFun
ixfun Slice (TPrimExp Int64 VName)
slice
applyLambda ::
Mem rep inner =>
Lambda rep ->
[(VName, [DimIndex (Imp.TExp Int64)])] ->
[(SubExp, [DimIndex (Imp.TExp Int64)])] ->
ImpM rep r op ()
applyLambda :: forall {k} (rep :: k) inner r op.
Mem rep inner =>
Lambda rep
-> [(VName, [DimIndex (TPrimExp Int64 VName)])]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
-> ImpM rep r op ()
applyLambda Lambda rep
lam [(VName, [DimIndex (TPrimExp Int64 VName)])]
dests [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
args = do
forall {k} (rep :: k) inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip (forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam) [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
args) forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, (SubExp
arg, [DimIndex (TPrimExp Int64 VName)]
arg_slice)) ->
forall {k} (rep :: k) r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] SubExp
arg [DimIndex (TPrimExp Int64 VName)]
arg_slice
forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda rep
lam) forall a b. (a -> b) -> a -> b
$ do
let res :: [SubExp]
res = forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda rep
lam
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [(VName, [DimIndex (TPrimExp Int64 VName)])]
dests [SubExp]
res) forall a b. (a -> b) -> a -> b
$ \((VName
dest, [DimIndex (TPrimExp Int64 VName)]
dest_slice), SubExp
se) ->
forall {k} (rep :: k) r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM VName
dest [DimIndex (TPrimExp Int64 VName)]
dest_slice SubExp
se []
applyRenamedLambda ::
Mem rep inner =>
Lambda rep ->
[(VName, [DimIndex (Imp.TExp Int64)])] ->
[(SubExp, [DimIndex (Imp.TExp Int64)])] ->
ImpM rep r op ()
applyRenamedLambda :: forall {k} (rep :: k) inner r op.
Mem rep inner =>
Lambda rep
-> [(VName, [DimIndex (TPrimExp Int64 VName)])]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
-> ImpM rep r op ()
applyRenamedLambda Lambda rep
lam [(VName, [DimIndex (TPrimExp Int64 VName)])]
dests [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
args = do
Lambda rep
lam_renamed <- forall {k} (rep :: k) (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda rep
lam
forall {k} (rep :: k) inner r op.
Mem rep inner =>
Lambda rep
-> [(VName, [DimIndex (TPrimExp Int64 VName)])]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
-> ImpM rep r op ()
applyLambda Lambda rep
lam_renamed [(VName, [DimIndex (TPrimExp Int64 VName)])]
dests [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
args
groupChunkLoop ::
Imp.TExp Int32 ->
(Imp.TExp Int32 -> TV Int64 -> InKernelGen ()) ->
InKernelGen ()
groupChunkLoop :: TExp Int32
-> (TExp Int32 -> TV Int64 -> InKernelGen ()) -> InKernelGen ()
groupChunkLoop TExp Int32
w TExp Int32 -> TV Int64 -> InKernelGen ()
m = do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. ImpM rep r op r
askEnv
let max_chunk_size :: TExp Int32
max_chunk_size = forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$ KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants
TExp Int32
num_chunks <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"num_chunks" forall a b. (a -> b) -> a -> b
$ TExp Int32
w forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32
max_chunk_size
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char]
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor [Char]
"chunk_i" TExp Int32
num_chunks forall a b. (a -> b) -> a -> b
$ \TExp Int32
chunk_i -> do
TExp Int32
chunk_start <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"chunk_start" forall a b. (a -> b) -> a -> b
$ TExp Int32
chunk_i forall a. Num a => a -> a -> a
* TExp Int32
max_chunk_size
TExp Int32
chunk_end <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"chunk_end" forall a b. (a -> b) -> a -> b
$ forall v. TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMin32 TExp Int32
w (TExp Int32
chunk_start forall a. Num a => a -> a -> a
+ TExp Int32
max_chunk_size)
TV Int64
chunk_size <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"chunk_size" forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ TExp Int32
chunk_end forall a. Num a => a -> a -> a
- TExp Int32
chunk_start
TExp Int32 -> TV Int64 -> InKernelGen ()
m TExp Int32
chunk_start TV Int64
chunk_size
virtualisedGroupScan ::
Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
Imp.TExp Int32 ->
Lambda GPUMem ->
[VName] ->
InKernelGen ()
virtualisedGroupScan :: Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Int32 -> Lambda GPUMem -> [VName] -> InKernelGen ()
virtualisedGroupScan Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag TExp Int32
w Lambda GPUMem
lam [VName]
arrs = do
TExp Int32
-> (TExp Int32 -> TV Int64 -> InKernelGen ()) -> InKernelGen ()
groupChunkLoop TExp Int32
w forall a b. (a -> b) -> a -> b
$ \TExp Int32
chunk_start TV Int64
chunk_size -> do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. ImpM rep r op r
askEnv
let ltid :: TExp Int32
ltid = KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
crosses_segment :: TExp Bool
crosses_segment =
case Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag of
Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
Nothing -> forall v. TPrimExp Bool v
false
Just TExp Int32 -> TExp Int32 -> TExp Bool
flag_true ->
TExp Int32 -> TExp Int32 -> TExp Bool
flag_true (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int32
chunk_start forall a. Num a => a -> a -> a
- TExp Int32
1)) (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int32
chunk_start)
forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"possibly incorporate carry" forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
chunk_start forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int32
0 forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int32
ltid forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0 forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TExp Bool
crosses_segment) forall a b. (a -> b) -> a -> b
$ do
TPrimExp Int64 VName
carry_idx <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"carry_idx" forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1
forall {k} (rep :: k) inner r op.
Mem rep inner =>
Lambda rep
-> [(VName, [DimIndex (TPrimExp Int64 VName)])]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
-> ImpM rep r op ()
applyRenamedLambda
Lambda GPUMem
lam
(forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
arrs forall a b. (a -> b) -> a -> b
$ forall a. a -> [a]
repeat [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start])
( forall a b. [a] -> [b] -> [(a, b)]
zip (forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
arrs) (forall a. a -> [a]
repeat [forall d. d -> DimIndex d
DimFix TPrimExp Int64 VName
carry_idx])
forall a. [a] -> [a] -> [a]
++ forall a b. [a] -> [b] -> [(a, b)]
zip (forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
arrs) (forall a. a -> [a]
repeat [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start])
)
[VName]
arrs_chunks <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall {k} (rep :: k) r op.
TPrimExp Int64 VName -> TV Int64 -> VName -> ImpM rep r op VName
sliceArray (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start) TV Int64
chunk_size) [VName]
arrs
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> Lambda GPUMem
-> [VName]
-> InKernelGen ()
groupScan Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
w) (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_size) Lambda GPUMem
lam [VName]
arrs_chunks
copyInGroup :: CopyCompiler GPUMem KernelEnv Imp.KernelOp
copyInGroup :: CopyCompiler GPUMem KernelEnv KernelOp
copyInGroup PrimType
pt MemLoc
destloc MemLoc
srcloc = do
Space
dest_space <- MemEntry -> Space
entryMemSpace forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. VName -> ImpM rep r op MemEntry
lookupMemory (MemLoc -> VName
memLocName MemLoc
destloc)
Space
src_space <- MemEntry -> Space
entryMemSpace forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. VName -> ImpM rep r op MemEntry
lookupMemory (MemLoc -> VName
memLocName MemLoc
srcloc)
let src_ixfun :: IxFun
src_ixfun = MemLoc -> IxFun
memLocIxFun MemLoc
srcloc
dims :: [TPrimExp Int64 VName]
dims = forall num. (Eq num, IntegralExp num) => IxFun num -> Shape num
IxFun.shape IxFun
src_ixfun
rank :: Int
rank = forall (t :: * -> *) a. Foldable t => t a -> Int
length [TPrimExp Int64 VName]
dims
case (Space
dest_space, Space
src_space) of
(ScalarSpace [SubExp]
destds PrimType
_, ScalarSpace [SubExp]
srcds PrimType
_) -> do
let fullDim :: d -> DimIndex d
fullDim d
d = forall d. d -> d -> d -> DimIndex d
DimSlice d
0 d
d d
1
destslice' :: Slice (TPrimExp Int64 VName)
destslice' =
forall d. [DimIndex d] -> Slice d
Slice forall a b. (a -> b) -> a -> b
$
forall a. Int -> a -> [a]
replicate (Int
rank forall a. Num a => a -> a -> a
- forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
destds) (forall d. d -> DimIndex d
DimFix TPrimExp Int64 VName
0)
forall a. [a] -> [a] -> [a]
++ forall a. Int -> [a] -> [a]
takeLast (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
destds) (forall a b. (a -> b) -> [a] -> [b]
map forall {d}. Num d => d -> DimIndex d
fullDim [TPrimExp Int64 VName]
dims)
srcslice' :: Slice (TPrimExp Int64 VName)
srcslice' =
forall d. [DimIndex d] -> Slice d
Slice forall a b. (a -> b) -> a -> b
$
forall a. Int -> a -> [a]
replicate (Int
rank forall a. Num a => a -> a -> a
- forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
srcds) (forall d. d -> DimIndex d
DimFix TPrimExp Int64 VName
0)
forall a. [a] -> [a] -> [a]
++ forall a. Int -> [a] -> [a]
takeLast (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
srcds) (forall a b. (a -> b) -> [a] -> [b]
map forall {d}. Num d => d -> DimIndex d
fullDim [TPrimExp Int64 VName]
dims)
forall {k} (rep :: k) r op. CopyCompiler rep r op
copyElementWise
PrimType
pt
(MemLoc -> Slice (TPrimExp Int64 VName) -> MemLoc
sliceMemLoc MemLoc
destloc Slice (TPrimExp Int64 VName)
destslice')
(MemLoc -> Slice (TPrimExp Int64 VName) -> MemLoc
sliceMemLoc MemLoc
srcloc Slice (TPrimExp Int64 VName)
srcslice')
(Space, Space)
_ -> do
forall {k} (t :: k).
IntExp t =>
[TExp t] -> ([TExp t] -> InKernelGen ()) -> InKernelGen ()
groupCoverSpace (forall a b. (a -> b) -> [a] -> [b]
map forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 [TPrimExp Int64 VName]
dims) forall a b. (a -> b) -> a -> b
$ \[TExp Int32]
is ->
forall {k} (rep :: k) r op. CopyCompiler rep r op
copyElementWise
PrimType
pt
(MemLoc -> Slice (TPrimExp Int64 VName) -> MemLoc
sliceMemLoc MemLoc
destloc (forall d. [DimIndex d] -> Slice d
Slice forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (forall d. d -> DimIndex d
DimFix forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64) [TExp Int32]
is))
(MemLoc -> Slice (TPrimExp Int64 VName) -> MemLoc
sliceMemLoc MemLoc
srcloc (forall d. [DimIndex d] -> Slice d
Slice forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (forall d. d -> DimIndex d
DimFix forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64) [TExp Int32]
is))
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
localThreadIDs :: [SubExp] -> InKernelGen [Imp.TExp Int64]
localThreadIDs :: [SubExp] -> InKernelGen [TPrimExp Int64 VName]
localThreadIDs [SubExp]
dims = do
TPrimExp Int64 VName
ltid <- forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelConstants -> TExp Int32
kernelLocalThreadId forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. ImpM rep r op r
askEnv
let dims' :: [TPrimExp Int64 VName]
dims' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
forall b a. b -> (a -> b) -> Maybe a -> b
maybe (forall {k} (rep :: k) r op.
[Char]
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> ImpM rep r op [TPrimExp Int64 VName]
dIndexSpace' [Char]
"ltid" [TPrimExp Int64 VName]
dims' TPrimExp Int64 VName
ltid) (forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a -> b) -> [a] -> [b]
map forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup [SubExp]
dims
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelConstants -> Map [SubExp] [TExp Int32]
kernelLocalIdMap
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall {k} (rep :: k) r op. ImpM rep r op r
askEnv
partitionSeqDims :: SegSeqDims -> SegSpace -> ([(VName, SubExp)], [(VName, SubExp)])
partitionSeqDims :: SegSeqDims -> SegSpace -> ([(VName, SubExp)], [(VName, SubExp)])
partitionSeqDims (SegSeqDims [Int]
seq_is) SegSpace
space =
forall (p :: * -> * -> *) a b c d.
Bifunctor p =>
(a -> b) -> (c -> d) -> p a c -> p b d
bimap (forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst) (forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst) forall a b. (a -> b) -> a -> b
$
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition ((forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [Int]
seq_is) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> b
snd) (forall a b. [a] -> [b] -> [(a, b)]
zip (SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space) [Int
0 ..])
compileFlatId :: SegSpace -> InKernelGen ()
compileFlatId :: SegSpace -> InKernelGen ()
compileFlatId SegSpace
space = do
TExp Int32
ltid <- KernelConstants -> TExp Int32
kernelLocalThreadId forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. ImpM rep r op r
askEnv
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ (SegSpace -> VName
segFlat SegSpace
space) TExp Int32
ltid
prepareIntraGroupSegHist ::
Count GroupSize SubExp ->
[HistOp GPUMem] ->
InKernelGen [[Imp.TExp Int64] -> InKernelGen ()]
prepareIntraGroupSegHist :: Count GroupSize SubExp
-> [HistOp GPUMem]
-> InKernelGen [[TPrimExp Int64 VName] -> InKernelGen ()]
prepareIntraGroupSegHist Count GroupSize SubExp
group_size =
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall a b. (a, b) -> b
snd forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *) acc x y.
Monad m =>
(acc -> x -> m (acc, y)) -> acc -> [x] -> m (acc, [y])
mapAccumLM Maybe Locking
-> HistOp GPUMem
-> ImpM
GPUMem
KernelEnv
KernelOp
(Maybe Locking, [TPrimExp Int64 VName] -> InKernelGen ())
onOp forall a. Maybe a
Nothing
where
onOp :: Maybe Locking
-> HistOp GPUMem
-> ImpM
GPUMem
KernelEnv
KernelOp
(Maybe Locking, [TPrimExp Int64 VName] -> InKernelGen ())
onOp Maybe Locking
l HistOp GPUMem
op = do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. ImpM rep r op r
askEnv
AtomicBinOp
atomicBinOp <- KernelEnv -> AtomicBinOp
kernelAtomics forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. ImpM rep r op r
askEnv
let local_subhistos :: [VName]
local_subhistos = forall {k} (rep :: k). HistOp rep -> [VName]
histDest HistOp GPUMem
op
case (Maybe Locking
l, AtomicBinOp -> Lambda GPUMem -> AtomicUpdate GPUMem KernelEnv
atomicUpdateLocking AtomicBinOp
atomicBinOp forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). HistOp rep -> Lambda rep
histOp HistOp GPUMem
op) of
(Maybe Locking
_, AtomicPrim DoAtomicUpdate GPUMem KernelEnv
f) -> forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe Locking
l, DoAtomicUpdate GPUMem KernelEnv
f ([Char] -> Space
Space [Char]
"local") [VName]
local_subhistos)
(Maybe Locking
_, AtomicCAS DoAtomicUpdate GPUMem KernelEnv
f) -> forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe Locking
l, DoAtomicUpdate GPUMem KernelEnv
f ([Char] -> Space
Space [Char]
"local") [VName]
local_subhistos)
(Just Locking
l', AtomicLocking Locking -> DoAtomicUpdate GPUMem KernelEnv
f) -> forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe Locking
l, Locking -> DoAtomicUpdate GPUMem KernelEnv
f Locking
l' ([Char] -> Space
Space [Char]
"local") [VName]
local_subhistos)
(Maybe Locking
Nothing, AtomicLocking Locking -> DoAtomicUpdate GPUMem KernelEnv
f) -> do
VName
locks <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"locks"
let num_locks :: TPrimExp Int64 VName
num_locks = SubExp -> TPrimExp Int64 VName
pe64 forall a b. (a -> b) -> a -> b
$ forall {k} (u :: k) e. Count u e -> e
unCount Count GroupSize SubExp
group_size
dims :: [TPrimExp Int64 VName]
dims = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 forall a b. (a -> b) -> a -> b
$ forall d. ShapeBase d -> [d]
shapeDims (forall {k} (rep :: k). HistOp rep -> ShapeBase SubExp
histOpShape HistOp GPUMem
op forall a. Semigroup a => a -> a -> a
<> forall {k} (rep :: k). HistOp rep -> ShapeBase SubExp
histShape HistOp GPUMem
op)
l' :: Locking
l' = VName
-> TExp Int32
-> TExp Int32
-> TExp Int32
-> ([TPrimExp Int64 VName] -> [TPrimExp Int64 VName])
-> Locking
Locking VName
locks TExp Int32
0 TExp Int32
1 TExp Int32
0 (forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. (forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 VName
num_locks) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall num. IntegralExp num => [num] -> [num] -> num
flattenIndex [TPrimExp Int64 VName]
dims)
locks_t :: TypeBase (ShapeBase SubExp) NoUniqueness
locks_t = forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array PrimType
int32 (forall d. [d] -> ShapeBase d
Shape [forall {k} (u :: k) e. Count u e -> e
unCount Count GroupSize SubExp
group_size]) NoUniqueness
NoUniqueness
VName
locks_mem <- forall {k} (rep :: k) r op.
[Char]
-> Count Bytes (TPrimExp Int64 VName)
-> Space
-> ImpM rep r op VName
sAlloc [Char]
"locks_mem" (TypeBase (ShapeBase SubExp) NoUniqueness
-> Count Bytes (TPrimExp Int64 VName)
typeSize TypeBase (ShapeBase SubExp) NoUniqueness
locks_t) forall a b. (a -> b) -> a -> b
$ [Char] -> Space
Space [Char]
"local"
forall {k} (rep :: k) r op.
VName
-> PrimType
-> ShapeBase SubExp
-> VName
-> IxFun
-> ImpM rep r op ()
dArray VName
locks PrimType
int32 (forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape TypeBase (ShapeBase SubExp) NoUniqueness
locks_t) VName
locks_mem forall a b. (a -> b) -> a -> b
$
forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims forall a b. (a -> b) -> a -> b
$
TypeBase (ShapeBase SubExp) NoUniqueness
locks_t
forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"All locks start out unlocked" forall a b. (a -> b) -> a -> b
$
forall {k} (t :: k).
IntExp t =>
[TExp t] -> ([TExp t] -> InKernelGen ()) -> InKernelGen ()
groupCoverSpace [KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants] forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
is ->
forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
locks [TPrimExp Int64 VName]
is (IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
0) []
forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall a. a -> Maybe a
Just Locking
l', Locking -> DoAtomicUpdate GPUMem KernelEnv
f Locking
l' ([Char] -> Space
Space [Char]
"local") [VName]
local_subhistos)
groupCoverSegSpace :: SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace :: SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace SegVirt
virt SegSpace
space InKernelGen ()
m = do
let ([VName]
ltids, [SubExp]
dims) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
dims' :: [TPrimExp Int64 VName]
dims' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. ImpM rep r op r
askEnv
let group_size :: TPrimExp Int64 VName
group_size = KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants
let virt' :: SegVirt
virt' = if [TPrimExp Int64 VName]
dims' forall a. Eq a => a -> a -> Bool
== [TPrimExp Int64 VName
group_size] then SegSeqDims -> SegVirt
SegNoVirtFull ([Int] -> SegSeqDims
SegSeqDims []) else SegVirt
virt
case SegVirt
virt' of
SegVirt
SegVirt -> do
Maybe (TExp Int32)
iters <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup [SubExp]
dims forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelConstants -> Map [SubExp] (TExp Int32)
kernelChunkItersMap forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. ImpM rep r op r
askEnv
case Maybe (TExp Int32)
iters of
Maybe (TExp Int32)
Nothing -> do
TExp Int32
iterations <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"iterations" forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 [TPrimExp Int64 VName]
dims'
forall {k} (t :: k).
IntExp t =>
TExp t -> (TExp t -> InKernelGen ()) -> InKernelGen ()
groupLoop TExp Int32
iterations forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
forall {k} (rep :: k) r op.
[(VName, TPrimExp Int64 VName)]
-> TPrimExp Int64 VName -> ImpM rep r op ()
dIndexSpace (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ltids [TPrimExp Int64 VName]
dims') forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
i
InKernelGen ()
m
Just TExp Int32
num_chunks -> do
let ltid :: TExp Int32
ltid = KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char]
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor [Char]
"chunk_i" TExp Int32
num_chunks forall a b. (a -> b) -> a -> b
$ \TExp Int32
chunk_i -> do
TExp Int32
i <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"i" forall a b. (a -> b) -> a -> b
$ TExp Int32
chunk_i forall a. Num a => a -> a -> a
* forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
group_size forall a. Num a => a -> a -> a
+ TExp Int32
ltid
forall {k} (rep :: k) r op.
[(VName, TPrimExp Int64 VName)]
-> TPrimExp Int64 VName -> ImpM rep r op ()
dIndexSpace (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ltids [TPrimExp Int64 VName]
dims') forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
i
forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (Slice (TPrimExp Int64 VName) -> [TPrimExp Int64 VName] -> TExp Bool
inBounds (forall d. [DimIndex d] -> Slice d
Slice (forall a b. (a -> b) -> [a] -> [b]
map (forall d. d -> DimIndex d
DimFix forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. a -> TPrimExp Int64 a
le64) [VName]
ltids)) [TPrimExp Int64 VName]
dims') InKernelGen ()
m
SegVirt
SegNoVirt -> forall {k} (rep :: k) r op a.
Operations rep r op -> ImpM rep r op a -> ImpM rep r op a
localOps Operations GPUMem KernelEnv KernelOp
threadOperations forall a b. (a -> b) -> a -> b
$ do
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
ltids forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [SubExp] -> InKernelGen [TPrimExp Int64 VName]
localThreadIDs [SubExp]
dims
forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen ([(VName, SubExp)] -> TExp Bool
isActive forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ltids [SubExp]
dims) InKernelGen ()
m
SegNoVirtFull SegSeqDims
seq_dims -> do
let (([VName]
ltids_seq, [SubExp]
dims_seq), ([VName]
ltids_par, [SubExp]
dims_par)) =
forall (p :: * -> * -> *) a b c d.
Bifunctor p =>
(a -> b) -> (c -> d) -> p a c -> p b d
bimap forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSeqDims -> SegSpace -> ([(VName, SubExp)], [(VName, SubExp)])
partitionSeqDims SegSeqDims
seq_dims SegSpace
space
forall {k} (rep :: k) r op.
ShapeBase SubExp
-> ([TPrimExp Int64 VName] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (forall d. [d] -> ShapeBase d
Shape [SubExp]
dims_seq) forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
is_seq -> do
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
ltids_seq [TPrimExp Int64 VName]
is_seq
forall {k} (rep :: k) r op a.
Operations rep r op -> ImpM rep r op a -> ImpM rep r op a
localOps Operations GPUMem KernelEnv KernelOp
threadOperations forall a b. (a -> b) -> a -> b
$ do
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
ltids_par forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [SubExp] -> InKernelGen [TPrimExp Int64 VName]
localThreadIDs [SubExp]
dims_par
InKernelGen ()
m
compileGroupExp :: ExpCompiler GPUMem KernelEnv Imp.KernelOp
compileGroupExp :: ExpCompiler GPUMem KernelEnv KernelOp
compileGroupExp (Pat [PatElem (LetDec GPUMem)
pe]) (BasicOp (Opaque OpaqueOp
_ SubExp
se)) =
forall {k} (rep :: k) r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec GPUMem)
pe) [] SubExp
se []
compileGroupExp (Pat [PatElem (LetDec GPUMem)
dest]) (BasicOp (ArrayLit [SubExp]
es TypeBase (ShapeBase SubExp) NoUniqueness
_)) =
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [Int64
0 ..] [SubExp]
es) forall a b. (a -> b) -> a -> b
$ \(Int64
i, SubExp
e) ->
forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec GPUMem)
dest) [forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int64
i :: Int64)] SubExp
e []
compileGroupExp Pat (LetDec GPUMem)
_ (BasicOp (UpdateAcc VName
acc [SubExp]
is [SubExp]
vs)) =
VName -> [SubExp] -> [SubExp] -> InKernelGen ()
updateAcc VName
acc [SubExp]
is [SubExp]
vs
compileGroupExp (Pat [PatElem (LetDec GPUMem)
dest]) (BasicOp (Replicate ShapeBase SubExp
ds SubExp
se)) = do
VName
flat <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"rep_flat"
[VName]
is <- forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (forall a. ArrayShape a => a -> Int
shapeRank ShapeBase SubExp
ds) (forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"rep_i")
let is' :: [TPrimExp Int64 VName]
is' = forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
le64 [VName]
is
SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace SegVirt
SegVirt (VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
flat forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
is forall a b. (a -> b) -> a -> b
$ forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
ds) forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec GPUMem)
dest) [TPrimExp Int64 VName]
is' SubExp
se []
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
compileGroupExp (Pat [PatElem (LetDec GPUMem)
dest]) (BasicOp (Rotate [SubExp]
rs VName
arr)) = do
[TPrimExp Int64 VName]
ds <- forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) (m :: * -> *).
HasScope rep m =>
VName -> m (TypeBase (ShapeBase SubExp) NoUniqueness)
lookupType VName
arr
forall {k} (t :: k).
IntExp t =>
[TExp t] -> ([TExp t] -> InKernelGen ()) -> InKernelGen ()
groupCoverSpace [TPrimExp Int64 VName]
ds forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
is -> do
[TPrimExp Int64 VName]
is' <- forall (t :: * -> *) (m :: * -> *) a.
(Traversable t, Monad m) =>
t (m a) -> m (t a)
sequence forall a b. (a -> b) -> a -> b
$ forall a b c d. (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
zipWith3 forall {k2} {rep :: k2} {r} {op}.
TPrimExp Int64 VName
-> SubExp
-> TPrimExp Int64 VName
-> ImpM rep r op (TPrimExp Int64 VName)
rotate [TPrimExp Int64 VName]
ds [SubExp]
rs [TPrimExp Int64 VName]
is
forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec GPUMem)
dest) [TPrimExp Int64 VName]
is (VName -> SubExp
Var VName
arr) [TPrimExp Int64 VName]
is'
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
where
rotate :: TPrimExp Int64 VName
-> SubExp
-> TPrimExp Int64 VName
-> ImpM rep r op (TPrimExp Int64 VName)
rotate TPrimExp Int64 VName
d SubExp
r TPrimExp Int64 VName
i = forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"rot_i" forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
rotateIndex TPrimExp Int64 VName
d (SubExp -> TPrimExp Int64 VName
pe64 SubExp
r) TPrimExp Int64 VName
i
compileGroupExp (Pat [PatElem (LetDec GPUMem)
dest]) (BasicOp (Iota SubExp
n SubExp
e SubExp
s IntType
it)) = do
Exp
n' <- forall a {k} (rep :: k) r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
n
Exp
e' <- forall a {k} (rep :: k) r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
e
Exp
s' <- forall a {k} (rep :: k) r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
s
forall {k} (t :: k).
IntExp t =>
TExp t -> (TExp t -> InKernelGen ()) -> InKernelGen ()
groupLoop (forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp Exp
n') forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
i' -> do
TV Any
x <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"x" forall a b. (a -> b) -> a -> b
$
forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp forall a b. (a -> b) -> a -> b
$
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (IntType -> Overflow -> BinOp
Add IntType
it Overflow
OverflowUndef) Exp
e' forall a b. (a -> b) -> a -> b
$
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (IntType -> Overflow -> BinOp
Mul IntType
it Overflow
OverflowUndef) (forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
i') Exp
s'
forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec GPUMem)
dest) [TPrimExp Int64 VName
i'] (VName -> SubExp
Var (forall {k} (t :: k). TV t -> VName
tvVar TV Any
x)) []
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
compileGroupExp (Pat [PatElem (LetDec GPUMem)
pe]) (BasicOp (Update Safety
safety VName
_ Slice SubExp
slice SubExp
se))
| forall (t :: * -> *) a. Foldable t => t a -> Bool
null forall a b. (a -> b) -> a -> b
$ forall d. Slice d -> [d]
sliceDims Slice SubExp
slice = do
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
TExp Int32
ltid <- KernelConstants -> TExp Int32
kernelLocalThreadId forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. ImpM rep r op r
askEnv
forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
ltid forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) forall a b. (a -> b) -> a -> b
$
case Safety
safety of
Safety
Unsafe -> InKernelGen ()
write
Safety
Safe -> forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (Slice (TPrimExp Int64 VName) -> [TPrimExp Int64 VName] -> TExp Bool
inBounds Slice (TPrimExp Int64 VName)
slice' [TPrimExp Int64 VName]
dims) InKernelGen ()
write
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
where
slice' :: Slice (TPrimExp Int64 VName)
slice' = forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64 Slice SubExp
slice
dims :: [TPrimExp Int64 VName]
dims = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 forall a b. (a -> b) -> a -> b
$ forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims forall a b. (a -> b) -> a -> b
$ forall dec.
Typed dec =>
PatElem dec -> TypeBase (ShapeBase SubExp) NoUniqueness
patElemType PatElem (LetDec GPUMem)
pe
write :: InKernelGen ()
write = forall {k} (rep :: k) r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec GPUMem)
pe) (forall d. Slice d -> [DimIndex d]
unSlice Slice (TPrimExp Int64 VName)
slice') SubExp
se []
compileGroupExp Pat (LetDec GPUMem)
dest Exp GPUMem
e =
forall {k} (rep :: k) inner r op.
Mem rep inner =>
Pat (LetDec rep) -> Exp rep -> ImpM rep r op ()
defCompileExp Pat (LetDec GPUMem)
dest Exp GPUMem
e
compileGroupOp :: OpCompiler GPUMem KernelEnv Imp.KernelOp
compileGroupOp :: OpCompiler GPUMem KernelEnv KernelOp
compileGroupOp Pat (LetDec GPUMem)
pat (Alloc SubExp
size Space
space) =
Pat LetDecMem -> SubExp -> Space -> InKernelGen ()
kernelAlloc Pat (LetDec GPUMem)
pat SubExp
size Space
space
compileGroupOp Pat (LetDec GPUMem)
pat (Inner (SegOp (SegMap SegLevel
lvl SegSpace
space [TypeBase (ShapeBase SubExp) NoUniqueness]
_ KernelBody GPUMem
body))) = do
SegSpace -> InKernelGen ()
compileFlatId SegSpace
space
SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace (SegLevel -> SegVirt
segVirt SegLevel
lvl) SegSpace
space forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
body) forall a b. (a -> b) -> a -> b
$
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (SegSpace -> PatElem LetDecMem -> KernelResult -> InKernelGen ()
compileThreadResult SegSpace
space) (forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec GPUMem)
pat) forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k). KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
body
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
compileGroupOp Pat (LetDec GPUMem)
pat (Inner (SegOp (SegScan SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
scans [TypeBase (ShapeBase SubExp) NoUniqueness]
_ KernelBody GPUMem
body))) = do
SegSpace -> InKernelGen ()
compileFlatId SegSpace
space
let ([VName]
ltids, [SubExp]
dims) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
dims' :: [TPrimExp Int64 VName]
dims' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace (SegLevel -> SegVirt
segVirt SegLevel
lvl) SegSpace
space forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
body) forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip (forall dec. Pat dec -> [VName]
patNames Pat (LetDec GPUMem)
pat) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
body) forall a b. (a -> b) -> a -> b
$ \(VName
dest, KernelResult
res) ->
forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
VName
dest
(forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
ltids)
(KernelResult -> SubExp
kernelResultSubExp KernelResult
res)
[]
Fence
fence <- [VName] -> InKernelGen Fence
fenceForArrays forall a b. (a -> b) -> a -> b
$ forall dec. Pat dec -> [VName]
patNames Pat (LetDec GPUMem)
pat
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
fence
let segment_size :: TPrimExp Int64 VName
segment_size = forall a. [a] -> a
last [TPrimExp Int64 VName]
dims'
crossesSegment :: TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment TExp Int32
from TExp Int32
to =
(forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to forall a. Num a => a -> a -> a
- forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
from) forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 VName
segment_size)
TV Int64
dims_flat <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"dims_flat" forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims'
let scan :: SegBinOp GPUMem
scan = forall a. [a] -> a
head [SegBinOp GPUMem]
scans
num_scan_results :: Int
num_scan_results = forall (t :: * -> *) a. Foldable t => t a -> Int
length forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scan
[VName]
arrs_flat <-
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall {k} (rep :: k) r op.
Int -> TV Int64 -> VName -> ImpM rep r op VName
flattenArray (forall (t :: * -> *) a. Foldable t => t a -> Int
length [TPrimExp Int64 VName]
dims') TV Int64
dims_flat) forall a b. (a -> b) -> a -> b
$
forall a. Int -> [a] -> [a]
take Int
num_scan_results forall a b. (a -> b) -> a -> b
$
forall dec. Pat dec -> [VName]
patNames Pat (LetDec GPUMem)
pat
case SegLevel -> SegVirt
segVirt SegLevel
lvl of
SegVirt
SegVirt ->
Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Int32 -> Lambda GPUMem -> [VName] -> InKernelGen ()
virtualisedGroupScan
(forall a. a -> Maybe a
Just TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment)
(forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dims_flat)
(forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan)
[VName]
arrs_flat
SegVirt
_ ->
Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> Lambda GPUMem
-> [VName]
-> InKernelGen ()
groupScan
(forall a. a -> Maybe a
Just TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment)
(forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims')
(forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims')
(forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan)
[VName]
arrs_flat
compileGroupOp Pat (LetDec GPUMem)
pat (Inner (SegOp (SegRed SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
ops [TypeBase (ShapeBase SubExp) NoUniqueness]
_ KernelBody GPUMem
body))) = do
SegSpace -> InKernelGen ()
compileFlatId SegSpace
space
let dims' :: [TPrimExp Int64 VName]
dims' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
mkTempArr :: TypeBase (ShapeBase SubExp) NoUniqueness
-> ImpM GPUMem KernelEnv KernelOp VName
mkTempArr TypeBase (ShapeBase SubExp) NoUniqueness
t =
forall {k} (rep :: k) r op.
[Char]
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray [Char]
"red_arr" (forall shape u. TypeBase shape u -> PrimType
elemType TypeBase (ShapeBase SubExp) NoUniqueness
t) (forall d. [d] -> ShapeBase d
Shape [SubExp]
dims forall a. Semigroup a => a -> a -> a
<> forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape TypeBase (ShapeBase SubExp) NoUniqueness
t) forall a b. (a -> b) -> a -> b
$ [Char] -> Space
Space [Char]
"local"
[VName]
tmp_arrs <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM TypeBase (ShapeBase SubExp) NoUniqueness
-> ImpM GPUMem KernelEnv KernelOp VName
mkTempArr forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (forall {k} (rep :: k).
Lambda rep -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda) [SegBinOp GPUMem]
ops
SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace (SegLevel -> SegVirt
segVirt SegLevel
lvl) SegSpace
space forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
body) forall a b. (a -> b) -> a -> b
$ do
let ([KernelResult]
red_res, [KernelResult]
map_res) =
forall a. Int -> [a] -> ([a], [a])
splitAt (forall {k} (rep :: k). [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem]
ops) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
body
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
tmp_arrs [KernelResult]
red_res) forall a b. (a -> b) -> a -> b
$ \(VName
dest, KernelResult
res) ->
forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
dest (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
ltids) (KernelResult -> SubExp
kernelResultSubExp KernelResult
res) []
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (SegSpace -> PatElem LetDecMem -> KernelResult -> InKernelGen ()
compileThreadResult SegSpace
space) [PatElem LetDecMem]
map_pes [KernelResult]
map_res
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
let tmps_for_ops :: [[VName]]
tmps_for_ops = forall a. [Int] -> [a] -> [[a]]
chunks (forall a b. (a -> b) -> [a] -> [b]
map (forall (t :: * -> *) a. Foldable t => t a -> Int
length forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). SegBinOp rep -> [SubExp]
segBinOpNeutral) [SegBinOp GPUMem]
ops) [VName]
tmp_arrs
case SegLevel -> SegVirt
segVirt SegLevel
lvl of
SegVirt
SegVirt -> [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
virtCase [TPrimExp Int64 VName]
dims' [[VName]]
tmps_for_ops
SegVirt
_ -> [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
nonvirtCase [TPrimExp Int64 VName]
dims' [[VName]]
tmps_for_ops
where
([VName]
ltids, [SubExp]
dims) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
([PatElem LetDecMem]
red_pes, [PatElem LetDecMem]
map_pes) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall {k} (rep :: k). [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem]
ops) forall a b. (a -> b) -> a -> b
$ forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec GPUMem)
pat
virtCase :: [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
virtCase [TPrimExp Int64 VName
dim'] [[VName]]
tmps_for_ops = do
TExp Int32
ltid <- KernelConstants -> TExp Int32
kernelLocalThreadId forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. ImpM rep r op r
askEnv
TExp Int32
-> (TExp Int32 -> TV Int64 -> InKernelGen ()) -> InKernelGen ()
groupChunkLoop (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
dim') forall a b. (a -> b) -> a -> b
$ \TExp Int32
chunk_start TV Int64
chunk_size -> do
forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"possibly incorporate carry" forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
chunk_start forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int32
0 forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int32
ltid forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp GPUMem]
ops [[VName]]
tmps_for_ops) forall a b. (a -> b) -> a -> b
$ \(SegBinOp GPUMem
op, [VName]
tmps) ->
forall {k} (rep :: k) inner r op.
Mem rep inner =>
Lambda rep
-> [(VName, [DimIndex (TPrimExp Int64 VName)])]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
-> ImpM rep r op ()
applyRenamedLambda
(forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
op)
(forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
tmps forall a b. (a -> b) -> a -> b
$ forall a. a -> [a]
repeat [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start])
( forall a b. [a] -> [b] -> [(a, b)]
zip (forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall dec. PatElem dec -> VName
patElemName) [PatElem LetDecMem]
red_pes) (forall a. a -> [a]
repeat [])
forall a. [a] -> [a] -> [a]
++ forall a b. [a] -> [b] -> [(a, b)]
zip (forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
tmps) (forall a. a -> [a]
repeat [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start])
)
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp GPUMem]
ops [[VName]]
tmps_for_ops) forall a b. (a -> b) -> a -> b
$ \(SegBinOp GPUMem
op, [VName]
tmps) -> do
[VName]
tmps_chunks <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall {k} (rep :: k) r op.
TPrimExp Int64 VName -> TV Int64 -> VName -> ImpM rep r op VName
sliceArray (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start) TV Int64
chunk_size) [VName]
tmps
TExp Int32 -> Lambda GPUMem -> [VName] -> InKernelGen ()
groupReduce (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_size)) (forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
op) [VName]
tmps_chunks
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LetDecMem]
red_pes forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
tmps_for_ops) forall a b. (a -> b) -> a -> b
$ \(PatElem LetDecMem
pe, VName
arr) ->
forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (forall dec. PatElem dec -> VName
patElemName PatElem LetDecMem
pe) [] (VName -> SubExp
Var VName
arr) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start]
virtCase [TPrimExp Int64 VName]
dims' [[VName]]
tmps_for_ops = do
TV Int64
dims_flat <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"dims_flat" forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims'
let segment_size :: TPrimExp Int64 VName
segment_size = forall a. [a] -> a
last [TPrimExp Int64 VName]
dims'
crossesSegment :: TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment TExp Int32
from TExp Int32
to =
(forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to forall a. Num a => a -> a -> a
- forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
from) forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to forall e. IntegralExp e => e -> e -> e
`rem` forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 VName
segment_size)
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp GPUMem]
ops [[VName]]
tmps_for_ops) forall a b. (a -> b) -> a -> b
$ \(SegBinOp GPUMem
op, [VName]
tmps) -> do
[VName]
tmps_flat <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall {k} (rep :: k) r op.
Int -> TV Int64 -> VName -> ImpM rep r op VName
flattenArray (forall (t :: * -> *) a. Foldable t => t a -> Int
length [TPrimExp Int64 VName]
dims') TV Int64
dims_flat) [VName]
tmps
Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Int32 -> Lambda GPUMem -> [VName] -> InKernelGen ()
virtualisedGroupScan
(forall a. a -> Maybe a
Just TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment)
(forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dims_flat)
(forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
op)
[VName]
tmps_flat
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LetDecMem]
red_pes forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
tmps_for_ops) forall a b. (a -> b) -> a -> b
$ \(PatElem LetDecMem
pe, VName
arr) ->
forall {k} (rep :: k) r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM
(forall dec. PatElem dec -> VName
patElemName PatElem LetDecMem
pe)
[]
(VName -> SubExp
Var VName
arr)
(forall a b. (a -> b) -> [a] -> [b]
map (forall d. Num d => d -> d -> DimIndex d
unitSlice TPrimExp Int64 VName
0) (forall a. [a] -> [a]
init [TPrimExp Int64 VName]
dims') forall a. [a] -> [a] -> [a]
++ [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
last [TPrimExp Int64 VName]
dims' forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1])
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
nonvirtCase :: [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
nonvirtCase [TPrimExp Int64 VName
dim'] [[VName]]
tmps_for_ops = do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp GPUMem]
ops [[VName]]
tmps_for_ops) forall a b. (a -> b) -> a -> b
$ \(SegBinOp GPUMem
op, [VName]
tmps) ->
TExp Int32 -> Lambda GPUMem -> [VName] -> InKernelGen ()
groupReduce (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
dim') (forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
op) [VName]
tmps
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LetDecMem]
red_pes forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
tmps_for_ops) forall a b. (a -> b) -> a -> b
$ \(PatElem LetDecMem
pe, VName
arr) ->
forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (forall dec. PatElem dec -> VName
patElemName PatElem LetDecMem
pe) [] (VName -> SubExp
Var VName
arr) [TPrimExp Int64 VName
0]
nonvirtCase [TPrimExp Int64 VName]
dims' [[VName]]
tmps_for_ops = do
TV Int64
dims_flat <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"dims_flat" forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims'
let segment_size :: TPrimExp Int64 VName
segment_size = forall a. [a] -> a
last [TPrimExp Int64 VName]
dims'
crossesSegment :: TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment TExp Int32
from TExp Int32
to =
(forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to forall a. Num a => a -> a -> a
- forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
from) forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to forall e. IntegralExp e => e -> e -> e
`rem` forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 VName
segment_size)
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp GPUMem]
ops [[VName]]
tmps_for_ops) forall a b. (a -> b) -> a -> b
$ \(SegBinOp GPUMem
op, [VName]
tmps) -> do
[VName]
tmps_flat <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall {k} (rep :: k) r op.
Int -> TV Int64 -> VName -> ImpM rep r op VName
flattenArray (forall (t :: * -> *) a. Foldable t => t a -> Int
length [TPrimExp Int64 VName]
dims') TV Int64
dims_flat) [VName]
tmps
Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> Lambda GPUMem
-> [VName]
-> InKernelGen ()
groupScan
(forall a. a -> Maybe a
Just TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment)
(forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims')
(forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims')
(forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
op)
[VName]
tmps_flat
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LetDecMem]
red_pes forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
tmps_for_ops) forall a b. (a -> b) -> a -> b
$ \(PatElem LetDecMem
pe, VName
arr) ->
forall {k} (rep :: k) r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM
(forall dec. PatElem dec -> VName
patElemName PatElem LetDecMem
pe)
[]
(VName -> SubExp
Var VName
arr)
(forall a b. (a -> b) -> [a] -> [b]
map (forall d. Num d => d -> d -> DimIndex d
unitSlice TPrimExp Int64 VName
0) (forall a. [a] -> [a]
init [TPrimExp Int64 VName]
dims') forall a. [a] -> [a] -> [a]
++ [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
last [TPrimExp Int64 VName]
dims' forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1])
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
-- Compile an intra-group SegHist: each group builds its histogram(s) in
-- local memory using atomic updates, one bin update per thread result.
compileGroupOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do
  compileFlatId space
  let (ltids, _dims) = unzip $ unSegSpace space

  -- One result per histogram operation, plus one per neutral element.
  let num_red_res = length ops + sum (map (length . histNeutral) ops)
      (_red_pes, map_pes) = splitAt num_red_res $ patElems pat

  group_size <- kernelGroupSizeCount . kernelConstants <$> askEnv
  ops' <- prepareIntraGroupSegHist group_size ops

  -- Make sure any locks/buffers set up above are visible to all threads.
  sOp $ Imp.Barrier Imp.FenceLocal

  groupCoverSegSpace (segVirt lvl) space $
    compileStms mempty (kernelBodyStms kbody) $ do
      let (red_res, map_res) = splitAt num_red_res $ kernelBodyResult kbody
          (red_is, red_vs) = splitAt (length ops) $ map kernelResultSubExp red_res
      zipWithM_ (compileThreadResult space) map_pes map_res

      -- Group the values by the histogram operation they feed.
      let vs_per_op = chunks (map (length . histDest) ops) red_vs

      forM_ (zip4 red_is vs_per_op ops' ops) $
        \(bin, op_vs, do_op, HistOp dest_shape _ _ _ shape lam) -> do
          let bin' = pe64 bin
              dest_shape' = map pe64 $ shapeDims dest_shape
              -- Out-of-bounds bins are silently dropped.
              bin_in_bounds = inBounds (Slice (map DimFix [bin'])) dest_shape'
              bin_is = map Imp.le64 (init ltids) ++ [bin']
              vs_params = takeLast (length op_vs) $ lambdaParams lam

          sComment "perform atomic updates" $
            sWhen bin_in_bounds $ do
              dLParams $ lambdaParams lam
              sLoopNest shape $ \is -> do
                forM_ (zip vs_params op_vs) $ \(p, v) ->
                  copyDWIMFix (paramName p) [] v is
                do_op (bin_is ++ is)

  -- Also catches deadlocks from the atomic-update locking scheme.
  sOp $ Imp.ErrorSync Imp.FenceLocal
-- Any other operation at group level indicates a compiler bug upstream.
compileGroupOp pat _ =
  compilerBugS $
    "compileGroupOp: cannot compile rhs of binding " ++ prettyString pat
-- | The operations to use when compiling code that runs inside a
-- workgroup: group-aware op/exp/copy compilers, and local-memory
-- allocation.
groupOperations :: Operations GPUMem KernelEnv Imp.KernelOp
groupOperations =
  (defaultOperations compileGroupOp)
    { opsCopyCompiler = copyInGroup,
      opsExpCompiler = compileGroupExp,
      opsStmsCompiler = \_ -> defCompileStms mempty,
      opsAllocCompilers =
        M.fromList [(Space "local", allocLocal)]
    }
-- | Is this 'SubExp' a variable bound to an array residing in the
-- "local" memory space?  Constants and non-array variables are never
-- in local memory.
arrayInLocalMemory :: SubExp -> InKernelGen Bool
arrayInLocalMemory (Var name) = do
  entry <- lookupVar name
  case entry of
    ArrayVar _ array_entry ->
      (Space "local" ==) . entryMemSpace
        <$> lookupMemory (memLocName (entryArrayLoc array_entry))
    _ -> pure False
arrayInLocalMemory Constant {} = pure False
-- | Emit a kernel in which the workgroup (not the individual thread)
-- is the unit of parallelism: virtualisation indexes by group ID.
sKernelGroup ::
  String ->
  VName ->
  KernelAttrs ->
  InKernelGen () ->
  CallKernelGen ()
sKernelGroup = sKernel groupOperations kernelGroupId
-- | Write the result produced by a group back to the kernel output.
compileGroupResult ::
  SegSpace ->
  PatElem LetDecMem ->
  KernelResult ->
  InKernelGen ()

-- One-dimensional tile: every thread of the group copies (one or more
-- of) its elements of the tile to the destination slab for this group.
compileGroupResult _ pe (TileReturns _ [(w, per_group_elems)] what) = do
  n <- pe64 . arraySize 0 <$> lookupType what
  constants <- kernelConstants <$> askEnv
  let ltid = sExt64 $ kernelLocalThreadId constants
      offset =
        pe64 per_group_elems
          * sExt64 (kernelGroupId constants)
  localOps threadOperations $
    if pe64 per_group_elems == kernelGroupSize constants
      then -- Exactly one element per thread; a single guarded copy.
        sWhen (ltid + offset .<. pe64 w) $
          copyDWIMFix (patElemName pe) [ltid + offset] (Var what) [ltid]
      else -- More elements than threads; stride by the group size.
        sFor "i" (n `divUp` kernelGroupSize constants) $ \i -> do
          j <- dPrimVE "j" $ kernelGroupSize constants * i + ltid
          sWhen (j + offset .<. pe64 w) $
            copyDWIMFix (patElemName pe) [j + offset] (Var what) [j]
-- Multi-dimensional tile: each thread owns one element of the tile,
-- and writes it to its global position if that position is in bounds.
compileGroupResult space pe (TileReturns _ dims what) = do
  let gids = map fst $ unSegSpace space
      out_tile_sizes = map (pe64 . snd) dims
      -- Where this group's tile starts in each dimension.
      group_is = zipWith (*) (map Imp.le64 gids) out_tile_sizes
  local_is <- localThreadIDs $ map snd dims
  is_for_thread <-
    mapM (dPrimV "thread_out_index") $
      zipWith (+) group_is local_is
  localOps threadOperations $
    sWhen (isActive $ zip (map tvVar is_for_thread) $ map fst dims) $
      copyDWIMFix
        (patElemName pe)
        (map tvExp is_for_thread)
        (Var what)
        local_is
-- Register tiling: each thread owns a small register tile inside the
-- group tile, and copies the in-bounds part of it to the destination.
compileGroupResult space pe (RegTileReturns _ dims_n_tiles what) = do
  constants <- kernelConstants <$> askEnv

  let gids = map fst $ unSegSpace space
      (dims, group_tiles, reg_tiles) = unzip3 dims_n_tiles
      group_tiles' = map pe64 group_tiles
      reg_tiles' = map pe64 reg_tiles

  -- Which group tile is this group responsible for?
  let group_tile_is = map Imp.le64 gids

  -- Within the group tile, which register tile belongs to this thread?
  reg_tile_is <-
    dIndexSpace' "reg_tile_i" group_tiles' $
      sExt64 $
        kernelLocalThreadId constants

  -- The output slice covered by this thread's register tile.
  let regTileSliceDim (group_tile, group_tile_i) (reg_tile, reg_tile_i) = do
        tile_dim_start <-
          dPrimVE "tile_dim_start" $
            reg_tile * (group_tile * group_tile_i + reg_tile_i)
        pure $ DimSlice tile_dim_start reg_tile 1
  reg_tile_slices <-
    Slice
      <$> zipWithM
        regTileSliceDim
        (zip group_tiles' group_tile_is)
        (zip reg_tiles' reg_tile_is)

  localOps threadOperations $
    sLoopNest (Shape reg_tiles) $ \is_in_reg_tile -> do
      let dest_is = fixSlice reg_tile_slices is_in_reg_tile
          src_is = reg_tile_is ++ is_in_reg_tile
      -- Guard against writing outside the logical output extent.
      sWhen (foldl1 (.&&.) $ zipWith (.<.) dest_is $ map pe64 dims) $
        copyDWIMFix (patElemName pe) dest_is (Var what) src_is
-- Plain per-group result.  If the value is not in local memory, only
-- thread 0 writes it; if it is a local-memory array, all threads
-- participate in a collective copy.
compileGroupResult space pe (Returns _ _ what) = do
  constants <- kernelConstants <$> askEnv
  in_local_memory <- arrayInLocalMemory what
  let gids = map (Imp.le64 . fst) $ unSegSpace space

  if not in_local_memory
    then
      localOps threadOperations $
        sWhen (kernelLocalThreadId constants .==. 0) $
          copyDWIMFix (patElemName pe) gids what []
    else -- Local-memory arrays are copied collectively by the whole
    -- group, via the group-level copy compiler.
      copyDWIMFix (patElemName pe) gids what []
-- WriteReturns at group level is a known, unimplemented limitation.
compileGroupResult _ _ WriteReturns {} =
  compilerLimitationS "compileGroupResult: WriteReturns not handled yet."
-- | The set of distinct iteration-space sizes of the SegOps (and
-- Replicates) occurring in a body of kernel code.
type SegOpSizes = S.Set [SubExp]
-- | Information precomputed on the host before entering the kernel:
-- the SegOp sizes occurring in the kernel body, and for each size the
-- number of chunks a group must iterate to cover it.
data Precomputed = Precomputed
  { pcSegOpSizes :: SegOpSizes,
    pcChunkItersMap :: M.Map [SubExp] (Imp.TExp Int32)
  }
-- | Collect the iteration-space sizes of all SegOps (and the shapes of
-- Replicates) in the given statements, recursing into matches and loops.
segOpSizes :: Stms GPUMem -> SegOpSizes
segOpSizes = onStms
  where
    onStms = foldMap (onExp . stmExp)

    onExp (Op (Inner (SegOp op))) =
      case segVirt $ segLevel op of
        -- For SegNoVirtFull, only the non-sequentialised dimensions
        -- contribute to the parallel size.
        SegNoVirtFull seq_dims ->
          S.singleton . map snd . snd $
            partitionSeqDims seq_dims (segSpace op)
        _ -> S.singleton $ map snd $ unSegSpace $ segSpace op
    onExp (BasicOp (Replicate shape _)) =
      S.singleton $ shapeDims shape
    onExp (Match _ cases defbody _) =
      foldMap (onStms . bodyStms . caseBody) cases
        <> onStms (bodyStms defbody)
    onExp (DoLoop _ _ body) =
      onStms (bodyStms body)
    onExp _ = mempty
-- | On the host, precompute for every SegOp size in the statements the
-- number of chunks of the given group size needed to cover it.
precomputeConstants :: Count GroupSize (Imp.TExp Int64) -> Stms GPUMem -> CallKernelGen Precomputed
precomputeConstants group_size stms = do
  let sizes = segOpSizes stms
  iters_map <- M.fromList <$> mapM mkMap (S.toList sizes)
  pure $ Precomputed sizes iters_map
  where
    -- Pair a size with its (dynamically computed) chunk count.
    mkMap dims = do
      let n = product $ map Imp.pe64 dims
      num_chunks <-
        dPrimVE "num_chunks" $ sExt32 $ n `divUp` unCount group_size
      pure (dims, num_chunks)
-- | Run an in-kernel action in an environment extended with the
-- precomputed information: per-size local thread indices and the
-- chunk-iteration counts computed on the host.
precomputedConstants :: Precomputed -> InKernelGen a -> InKernelGen a
precomputedConstants pre m = do
  ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
  new_ids <-
    M.fromList
      <$> mapM (mkMap ltid) (S.toList (pcSegOpSizes pre))
  let augment env =
        env
          { kernelConstants =
              (kernelConstants env)
                { kernelLocalIdMap = new_ids,
                  kernelChunkItersMap = pcChunkItersMap pre
                }
          }
  localEnv augment m
  where
    -- Decompose the flat local thread id into indices for these dims.
    mkMap ltid dims = do
      let dims' = map pe64 dims
      ids' <- dIndexSpace' "ltid_pre" dims' (sExt64 ltid)
      pure (dims, map sExt32 ids')