{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass (compileSegScan) where
import Control.Monad.Except
import Data.List (zip4)
import Data.Maybe
import qualified Futhark.CodeGen.ImpCode.GPU as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Transform.Rename
import Futhark.Util (takeLast)
import Futhark.Util.IntegralExp (IntegralExp (mod, rem), divUp, quot)
import Prelude hiding (mod, quot, rem)
-- | Split the parameters of a scan operator's lambda into two groups.
--
-- 'xParams' returns the first @n@ lambda parameters and 'yParams' returns
-- all remaining ones, where @n@ is the number of neutral elements of the
-- operator ('segBinOpNeutral').
--
-- NOTE(review): presumably these are the left- and right-hand operands of
-- the binary operator, respectively -- confirm against the callers.
xParams, yParams :: SegBinOp GPUMem -> [LParam GPUMem]
xParams scan =
  take (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
yParams scan =
  drop (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
-- | @alignTo x a@ computes @(x `divUp` a) * a@.
--
-- Given that 'divUp' is division rounding upwards, this is @x@ rounded up
-- to the nearest multiple of @a@.  It is used below to align byte offsets
-- to the size of the element type stored at that offset.
alignTo :: IntegralExp a => a -> a -> a
alignTo x a = (x `divUp` a) * a
-- | Allocate the local memory block used by the single-pass scan kernel
-- and declare the arrays that live inside it.
--
-- Arguments:
--
-- * the group size,
--
-- * @m@, which is multiplied by the group size to give the length of each
--   transposed array (NOTE(review): presumably the number of elements
--   processed per thread -- confirm against the caller),
--
-- * the primitive element types, one per scanned value.
--
-- A single allocation named @local_mem@ is made in @Space "local"@.  Its
-- size in bytes is the maximum of three quantities: the space for the
-- per-type prefix arrays (group-size elements each, aligned to the element
-- size), the largest transposed array, and the warp-exchange arrays plus
-- 'warpSize' extra bytes.  Every array returned is declared inside that
-- same memory block.
--
-- NOTE(review): since everything shares one allocation, the arrays
-- presumably overlap and must not be live at the same time -- verify how
-- 'sArrayInMem' places its arrays before relying on this.
--
-- The result is
-- @(sharedId, transposedArrays, prefixArrays, sharedReadOffset, warpscan, warpExchanges)@.
createLocalArrays ::
  Count GroupSize SubExp ->
  SubExp ->
  [PrimType] ->
  InKernelGen (VName, [VName], [VName], VName, VName, [VName])
createLocalArrays (Count groupSize) m types = do
  let groupSizeE = toInt64Exp groupSize
      workSize = toInt64Exp m * groupSizeE

      -- Bytes needed to lay out one group-size array per type back to
      -- back, each start aligned to its element size.
      prefixArraysSize =
        foldl (\acc tySize -> alignTo acc tySize + tySize * groupSizeE) 0 $
          map primByteSize types

      -- Bytes needed by the largest of the transposed arrays.  'foldl1'
      -- requires 'types' to be non-empty.
      maxTransposedArraySize =
        foldl1 sMax64 $ map (\ty -> workSize * primByteSize ty) types

      warpSize :: Num a => a
      warpSize = 32

      -- Bytes needed for one warp-sized array per type, aligned as above.
      maxWarpExchangeSize =
        foldl (\acc tySize -> alignTo acc tySize + tySize * fromInteger warpSize) 0 $
          map primByteSize types

      -- The warp-exchange arrays are laid out starting at byte offset
      -- 'warpSize' (see 'warpByteOffsets'), hence the extra term.
      maxLookbackSize = maxWarpExchangeSize + warpSize

      size =
        Imp.bytes $
          maxLookbackSize `sMax64` prefixArraysSize `sMax64` maxTransposedArraySize

      varTE :: TV Int64 -> TPrimExp Int64 VName
      varTE = le64 . tvVar

  -- Byte offset of each prefix array within the local memory block.
  -- 'scanl' yields one more offset than there are types; the surplus one
  -- is dropped by the 'zip' further down.
  byteOffsets <-
    mapM (fmap varTE . dPrimV "byte_offsets") $
      scanl (\off tySize -> alignTo off tySize + toInt64Exp groupSize * tySize) 0 $
        map primByteSize types

  -- Byte offset of each warp-exchange array, starting at 'warpSize'.
  warpByteOffsets <-
    mapM (fmap varTE . dPrimV "warp_byte_offset") $
      scanl (\off tySize -> alignTo off tySize + warpSize * tySize) warpSize $
        map primByteSize types

  sComment "Allocate reused shared memeory" $ return ()

  localMem <- sAlloc "local_mem" size (Space "local")
  transposeArrayLength <- dPrimV "trans_arr_len" workSize

  -- Two single-element int32 arrays.
  sharedId <-
    sArrayInMem "shared_id" int32 (Shape [constant (1 :: Int32)]) localMem
  sharedReadOffset <-
    sArrayInMem "shared_read_offset" int32 (Shape [constant (1 :: Int32)]) localMem

  -- One array of 'workSize' elements per type.
  transposedArrays <-
    forM types $ \ty ->
      sArrayInMem
        "local_transpose_arr"
        ty
        (Shape [tvSize transposeArrayLength])
        localMem

  -- One group-size array per type, placed at its own offset.  The index
  -- function offset is in elements, so the byte offset is divided by the
  -- element size.
  prefixArrays <-
    forM (zip byteOffsets types) $ \(off, ty) -> do
      let off' = off `quot` primByteSize ty
      sArray
        "local_prefix_arr"
        ty
        (Shape [groupSize])
        $ ArrayIn localMem $ IxFun.iotaOffset off' [pe64 groupSize]

  -- A warp-sized int8 array, followed by one warp-sized array per type.
  warpscan <-
    sArrayInMem "warpscan" int8 (Shape [constant (warpSize :: Int64)]) localMem
  warpExchanges <-
    forM (zip warpByteOffsets types) $ \(off, ty) -> do
      let off' = off `quot` primByteSize ty
      sArray
        "warp_exchange"
        ty
        (Shape [constant (warpSize :: Int64)])
        $ ArrayIn localMem $ IxFun.iotaOffset off' [warpSize]

  return
    ( sharedId,
      transposedArrays,
      prefixArrays,
      sharedReadOffset,
      warpscan,
      warpExchanges
    )
compileSegScan ::
Pat GPUMem ->
SegLevel ->
SegSpace ->
SegBinOp GPUMem ->
KernelBody GPUMem ->
CallKernelGen ()
compileSegScan :: Pat GPUMem
-> SegLevel
-> SegSpace
-> SegBinOp GPUMem
-> KernelBody GPUMem
-> CallKernelGen ()
compileSegScan Pat GPUMem
pat SegLevel
lvl SegSpace
space SegBinOp GPUMem
scanOp KernelBody GPUMem
kbody = do
let Pat [PatElemT LParamMem]
all_pes = Pat GPUMem
PatT LParamMem
pat
group_size :: Count GroupSize (TExp Int64)
group_size = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> Count GroupSize SubExp -> Count GroupSize (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl
n :: TExp Int64
n = [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp ([SubExp] -> [TExp Int64]) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space
num_groups :: Count NumGroups (TExp Int64)
num_groups = TExp Int64 -> Count NumGroups (TExp Int64)
forall u e. e -> Count u e
Count (TExp Int64
n TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` (Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
forall a. Num a => a
m))
num_threads :: TExp Int64
num_threads = Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size
([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
dims' :: [TExp Int64]
dims' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
dims
segmented :: Bool
segmented = [TExp Int64] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TExp Int64]
dims' Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1
not_segmented_e :: TPrimExp Bool ExpLeaf
not_segmented_e = if Bool
segmented then TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v
false else TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v
true
segment_size :: TExp Int64
segment_size = [TExp Int64] -> TExp Int64
forall a. [a] -> a
last [TExp Int64]
dims'
scanOpNe :: [SubExp]
scanOpNe = SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scanOp
tys :: [PrimType]
tys = (TypeBase (ShapeBase SubExp) NoUniqueness -> PrimType)
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> [PrimType]
forall a b. (a -> b) -> [a] -> [b]
map (\(Prim PrimType
pt) -> PrimType
pt) ([TypeBase (ShapeBase SubExp) NoUniqueness] -> [PrimType])
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> [PrimType]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall rep.
LambdaT rep -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType (LambdaT GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> LambdaT GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> LambdaT GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scanOp
statusX, statusA, statusP :: Num a => a
statusX :: a
statusX = a
0
statusA :: a
statusA = a
1
statusP :: a
statusP = a
2
makeStatusUsed :: TV t -> TV t -> TPrimExp t ExpLeaf
makeStatusUsed TV t
flag TV t
used = TV t -> TPrimExp t ExpLeaf
forall t. TV t -> TExp t
tvExp TV t
flag TPrimExp t ExpLeaf -> TPrimExp t ExpLeaf -> TPrimExp t ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.|. (TV t -> TPrimExp t ExpLeaf
forall t. TV t -> TExp t
tvExp TV t
used TPrimExp t ExpLeaf -> TPrimExp t ExpLeaf -> TPrimExp t ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.<<. TPrimExp t ExpLeaf
2)
unmakeStatusUsed :: TV Int8 -> TV Int8 -> TV Int8 -> InKernelGen ()
unmakeStatusUsed :: TV Int8 -> TV Int8 -> TV Int8 -> ImpM GPUMem KernelEnv KernelOp ()
unmakeStatusUsed TV Int8
flagUsed TV Int8
flag TV Int8
used = do
TV Int8
used TV Int8 -> TExp Int8 -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flagUsed TExp Int8 -> TExp Int8 -> TExp Int8
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.>>. TExp Int8
2
TV Int8
flag TV Int8 -> TExp Int8 -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flagUsed TExp Int8 -> TExp Int8 -> TExp Int8
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.&. TExp Int8
3
sumT :: Integer
maxT :: Integer
sumT :: Integer
sumT = (Integer -> PrimType -> Integer)
-> Integer -> [PrimType] -> Integer
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\Integer
bytes PrimType
typ -> Integer
bytes Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
+ PrimType -> Integer
forall a. Num a => PrimType -> a
primByteSize PrimType
typ) Integer
0 [PrimType]
tys
primByteSize' :: PrimType -> Integer
primByteSize' = Integer -> Integer -> Integer
forall a. Ord a => a -> a -> a
max Integer
4 (Integer -> Integer)
-> (PrimType -> Integer) -> PrimType -> Integer
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> Integer
forall a. Num a => PrimType -> a
primByteSize
sumT' :: Integer
sumT' = (Integer -> PrimType -> Integer)
-> Integer -> [PrimType] -> Integer
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\Integer
bytes PrimType
typ -> Integer
bytes Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
+ PrimType -> Integer
primByteSize' PrimType
typ) Integer
0 [PrimType]
tys Integer -> Integer -> Integer
forall a. Integral a => a -> a -> a
`div` Integer
4
maxT :: Integer
maxT = [Integer] -> Integer
forall (t :: * -> *) a. (Foldable t, Ord a) => t a -> a
maximum ((PrimType -> Integer) -> [PrimType] -> [Integer]
forall a b. (a -> b) -> [a] -> [b]
map PrimType -> Integer
forall a. Num a => PrimType -> a
primByteSize [PrimType]
tys)
k_reg :: Integer
k_reg = Integer
64
k_mem :: Integer
k_mem = Integer
48
mem_constraint :: Integer
mem_constraint = Integer -> Integer -> Integer
forall a. Ord a => a -> a -> a
max Integer
k_mem Integer
sumT Integer -> Integer -> Integer
forall a. Integral a => a -> a -> a
`div` Integer
maxT
reg_constraint :: Integer
reg_constraint = (Integer
k_reg Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
-Integer
1 Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
- Integer
sumT') Integer -> Integer -> Integer
forall a. Integral a => a -> a -> a
`div` (Integer
2 Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
* Integer
sumT' Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
+ Integer
3)
m :: Num a => a
m :: a
m = Integer -> a
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Integer -> a) -> Integer -> a
forall a b. (a -> b) -> a -> b
$ Integer -> Integer -> Integer
forall a. Ord a => a -> a -> a
max Integer
1 (Integer -> Integer) -> Integer -> Integer
forall a b. (a -> b) -> a -> b
$ Integer -> Integer -> Integer
forall a. Ord a => a -> a -> a
min Integer
mem_constraint Integer
reg_constraint
TV Int64
numThreads <- String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"numThreads" TExp Int64
num_threads
TV Int64
numGroups <- String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"numGroups" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups
VName
globalId <- String
-> Space
-> PrimType
-> ArrayContents
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String -> Space -> PrimType -> ArrayContents -> ImpM rep r op VName
sStaticArray String
"id_counter" (String -> Space
Space String
"device") PrimType
int32 (ArrayContents -> ImpM GPUMem HostEnv HostOp VName)
-> ArrayContents -> ImpM GPUMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ Int -> ArrayContents
Imp.ArrayZeros Int
1
VName
statusFlags <- String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"status_flags" PrimType
int8 ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
numGroups]) (String -> Space
Space String
"device")
([VName]
aggregateArrays, [VName]
incprefixArrays) <-
([(VName, VName)] -> ([VName], [VName]))
-> ImpM GPUMem HostEnv HostOp [(VName, VName)]
-> ImpM GPUMem HostEnv HostOp ([VName], [VName])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [(VName, VName)] -> ([VName], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip (ImpM GPUMem HostEnv HostOp [(VName, VName)]
-> ImpM GPUMem HostEnv HostOp ([VName], [VName]))
-> ImpM GPUMem HostEnv HostOp [(VName, VName)]
-> ImpM GPUMem HostEnv HostOp ([VName], [VName])
forall a b. (a -> b) -> a -> b
$
[PrimType]
-> (PrimType -> ImpM GPUMem HostEnv HostOp (VName, VName))
-> ImpM GPUMem HostEnv HostOp [(VName, VName)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [PrimType]
tys ((PrimType -> ImpM GPUMem HostEnv HostOp (VName, VName))
-> ImpM GPUMem HostEnv HostOp [(VName, VName)])
-> (PrimType -> ImpM GPUMem HostEnv HostOp (VName, VName))
-> ImpM GPUMem HostEnv HostOp [(VName, VName)]
forall a b. (a -> b) -> a -> b
$ \PrimType
ty ->
(,) (VName -> VName -> (VName, VName))
-> ImpM GPUMem HostEnv HostOp VName
-> ImpM GPUMem HostEnv HostOp (VName -> (VName, VName))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"aggregates" PrimType
ty ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
numGroups]) (String -> Space
Space String
"device")
ImpM GPUMem HostEnv HostOp (VName -> (VName, VName))
-> ImpM GPUMem HostEnv HostOp VName
-> ImpM GPUMem HostEnv HostOp (VName, VName)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"incprefixes" PrimType
ty ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
numGroups]) (String -> Space
Space String
"device")
VName -> SubExp -> CallKernelGen ()
sReplicate VName
statusFlags (SubExp -> CallKernelGen ()) -> SubExp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusX
String
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> CallKernelGen ()
sKernelThread String
"segscan" Count NumGroups (TExp Int64)
num_groups Count GroupSize (TExp Int64)
group_size (SegSpace -> VName
segFlat SegSpace
space) (ImpM GPUMem KernelEnv KernelOp () -> CallKernelGen ())
-> ImpM GPUMem KernelEnv KernelOp () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
(VName
sharedId, [VName]
transposedArrays, [VName]
prefixArrays, VName
sharedReadOffset, VName
warpscan, [VName]
exchanges) <-
Count GroupSize SubExp
-> SubExp
-> [PrimType]
-> InKernelGen (VName, [VName], [VName], VName, VName, [VName])
createLocalArrays (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
forall a. Num a => a
m) [PrimType]
tys
TV Int64
dynamicId <- String -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"dynamic_id" PrimType
int32
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
(VName
globalIdMem, Space
_, Count Elements (TExp Int64)
globalIdOff) <- VName
-> [TExp Int64]
-> ImpM
GPUMem
KernelEnv
KernelOp
(VName, Space, Count Elements (TExp Int64))
forall rep r op.
VName
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray VName
globalId [TExp Int64
0]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> ImpM GPUMem KernelEnv KernelOp ())
-> KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
Space -> AtomicOp -> KernelOp
Imp.Atomic Space
DefaultSpace (AtomicOp -> KernelOp) -> AtomicOp -> KernelOp
forall a b. (a -> b) -> a -> b
$
IntType
-> VName -> VName -> Count Elements (TExp Int64) -> Exp -> AtomicOp
Imp.AtomicAdd
IntType
Int32
(TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
dynamicId)
VName
globalIdMem
(TExp Int64 -> Count Elements (TExp Int64)
forall u e. e -> Count u e
Count (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count Elements (TExp Int64)
globalIdOff)
(TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32
1 :: Imp.TExp Int32))
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
sharedId [TExp Int64
0] (TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
dynamicId) []
let localBarrier :: KernelOp
localBarrier = Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
localFence :: KernelOp
localFence = Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceLocal
globalFence :: KernelOp
globalFence = Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceGlobal
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
dynamicId) [] (VName -> SubExp
Var VName
sharedId) [TExp Int64
0]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
TV Int64
blockOff <-
String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"blockOff" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
forall a. Num a => a
m TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants
TExp Int64
sgmIdx <- String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"sgm_idx" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
blockOff TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`mod` TExp Int64
segment_size
TExp Int32
boundary <-
String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"boundary" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TExp Int64
forall a. Num a => a
m TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size) (TExp Int64
segment_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
sgmIdx)
TExp Int32
segsize_compact <-
String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"segsize_compact" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TExp Int64
forall a. Num a => a
m TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size) TExp Int64
segment_size
[VName]
privateArrays <-
[PrimType]
-> (PrimType -> ImpM GPUMem KernelEnv KernelOp VName)
-> ImpM GPUMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [PrimType]
tys ((PrimType -> ImpM GPUMem KernelEnv KernelOp VName)
-> ImpM GPUMem KernelEnv KernelOp [VName])
-> (PrimType -> ImpM GPUMem KernelEnv KernelOp VName)
-> ImpM GPUMem KernelEnv KernelOp [VName]
forall a b. (a -> b) -> a -> b
$ \PrimType
ty ->
String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray
String
"private"
PrimType
ty
([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
forall a. Num a => a
m])
([SubExp] -> PrimType -> Space
ScalarSpace [IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
forall a. Num a => a
m] PrimType
ty)
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Load and map" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
String
-> TExp Int64
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
forall a. Num a => a
m ((TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
TExp Int64
phys_tid <-
String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"phys_tid" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
blockOff TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants)
TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants
[(VName, TExp Int64)]
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
[(VName, TExp Int64)] -> TExp Int64 -> ImpM rep r op ()
dIndexSpace ([VName] -> [TExp Int64] -> [(VName, TExp Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [TExp Int64]
dims') TExp Int64
phys_tid
let in_bounds :: ImpM GPUMem KernelEnv KernelOp ()
in_bounds =
Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
let ([KernelResult]
all_scan_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp GPUMem] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem
scanOp]) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody
[(PatElemT LParamMem, KernelResult)]
-> ((PatElemT LParamMem, KernelResult)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LParamMem]
-> [KernelResult] -> [(PatElemT LParamMem, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int -> [PatElemT LParamMem] -> [PatElemT LParamMem]
forall a. Int -> [a] -> [a]
takeLast ([KernelResult] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
map_res) [PatElemT LParamMem]
all_pes) [KernelResult]
map_res) (((PatElemT LParamMem, KernelResult)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((PatElemT LParamMem, KernelResult)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LParamMem
dest, KernelResult
src) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
dest) ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
gtids) (KernelResult -> SubExp
kernelResultSubExp KernelResult
src) []
[(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ (KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp [KernelResult]
all_scan_res) (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
src) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
src []
out_of_bounds :: ImpM GPUMem KernelEnv KernelOp ()
out_of_bounds =
[(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays [SubExp]
scanOpNe) (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
ne) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
ne []
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf (TExp Int64
phys_tid TExp Int64 -> TExp Int64 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
n) ImpM GPUMem KernelEnv KernelOp ()
in_bounds ImpM GPUMem KernelEnv KernelOp ()
out_of_bounds
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Transpose scan inputs" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
[(VName, VName)]
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
transposedArrays [VName]
privateArrays) (((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
trans, VName
priv) -> do
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
String
-> TExp Int64
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
forall a. Num a => a
m ((TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
TExp Int64
sharedIdx <-
String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"sharedIdx" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants)
TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
trans [TExp Int64
sharedIdx] (VName -> SubExp
Var VName
priv) [TExp Int64
i]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
String
-> TExp Int32
-> (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int32
forall a. Num a => a
m ((TExp Int32 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
TV Int32
sharedIdx <- String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"sharedIdx" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
forall a. Num a => a
m TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
i
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
priv [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
i] (VName -> SubExp
Var VName
trans) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
sharedIdx]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Per thread scan" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TExp Int32
globalIdx <-
String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"gidx" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
(KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
forall a. Num a => a
m) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1
String
-> TExp Int64
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" (TExp Int64
forall a. Num a => a
m TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
-TExp Int64
1) ((TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
let xs :: [VName]
xs = (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> [LParam GPUMem]
xParams SegBinOp GPUMem
scanOp
ys :: [VName]
ys = (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> [LParam GPUMem]
yParams SegBinOp GPUMem
scanOp
TPrimExp Bool ExpLeaf
new_sgm <-
if Bool
segmented
then String
-> TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool ExpLeaf)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"new_sgm" (TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool ExpLeaf))
-> TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool ExpLeaf)
forall a b. (a -> b) -> a -> b
$ (TExp Int32
globalIdx TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
i TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
boundary) TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`mod` TExp Int32
segsize_compact TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0
else TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool ExpLeaf)
forall (f :: * -> *) a. Applicative f => a -> f a
pure TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v
false
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sUnless TPrimExp Bool ExpLeaf
new_sgm (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
[(VName, (VName, VName, PrimType))]
-> ((VName, (VName, VName, PrimType))
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName]
-> [(VName, VName, PrimType)]
-> [(VName, (VName, VName, PrimType))]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays ([(VName, VName, PrimType)] -> [(VName, (VName, VName, PrimType))])
-> [(VName, VName, PrimType)]
-> [(VName, (VName, VName, PrimType))]
forall a b. (a -> b) -> a -> b
$ [VName] -> [VName] -> [PrimType] -> [(VName, VName, PrimType)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
xs [VName]
ys [PrimType]
tys) (((VName, (VName, VName, PrimType))
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, (VName, VName, PrimType))
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
src, (VName
x, VName
y, PrimType
ty)) -> do
VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
x PrimType
ty
VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
y PrimType
ty
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
x [] (VName -> SubExp
Var VName
src) [TExp Int64
i]
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
y [] (VName -> SubExp
Var VName
src) [TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1]
Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT GPUMem -> Stms GPUMem
forall rep. BodyT rep -> Stms rep
bodyStms (BodyT GPUMem -> Stms GPUMem) -> BodyT GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody (LambdaT GPUMem -> BodyT GPUMem) -> LambdaT GPUMem -> BodyT GPUMem
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> LambdaT GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scanOp) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp ([SubExpRes] -> [SubExp]) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ BodyT GPUMem -> [SubExpRes]
forall rep. BodyT rep -> [SubExpRes]
bodyResult (BodyT GPUMem -> [SubExpRes]) -> BodyT GPUMem -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody (LambdaT GPUMem -> BodyT GPUMem) -> LambdaT GPUMem -> BodyT GPUMem
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> LambdaT GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scanOp) (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
res) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1] SubExp
res []
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Publish results in shared memory" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
[(VName, VName)]
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
prefixArrays [VName]
privateArrays) (((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, VName
src) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants] (VName -> SubExp
Var VName
src) [TExp Int64
forall a. Num a => a
m TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
let crossesSegment :: Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf)
crossesSegment = do
Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard Bool
segmented
(TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf)
-> Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf)
forall a. a -> Maybe a
Just ((TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf)
-> Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf))
-> (TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf)
-> Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf)
forall a b. (a -> b) -> a -> b
$ \TExp Int32
from TExp Int32
to ->
let from' :: TExp Int32
from' = (TExp Int32
from TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
forall a. Num a => a
m TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1
to' :: TExp Int32
to' = (TExp Int32
to TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
forall a. Num a => a
m TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1
in (TExp Int32
to' TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
from') TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TExp Int32
to' TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
segsize_compact TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
boundary) TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`mod` TExp Int32
segsize_compact
LambdaT GPUMem
scanOp' <- LambdaT GPUMem -> ImpM GPUMem KernelEnv KernelOp (LambdaT GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda (LambdaT GPUMem -> ImpM GPUMem KernelEnv KernelOp (LambdaT GPUMem))
-> LambdaT GPUMem
-> ImpM GPUMem KernelEnv KernelOp (LambdaT GPUMem)
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> LambdaT GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scanOp
[TV Any]
accs <- (PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> [PrimType] -> ImpM GPUMem KernelEnv KernelOp [TV Any]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"acc") [PrimType]
tys
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Scan results (with warp scan)" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf)
-> TExp Int64
-> TExp Int64
-> LambdaT GPUMem
-> [VName]
-> ImpM GPUMem KernelEnv KernelOp ()
groupScan
Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf)
crossesSegment
(TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
numThreads)
(KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)
LambdaT GPUMem
scanOp'
[VName]
prefixArrays
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
-- Two per-accumulator loaders, chosen by the 'sIf' on the local thread id
-- that follows this 'let'.
--
-- 'firstThread' sets the scalar @acc@ from the element of @prefixes@ at
-- index @kernelGroupSize constants - 1@.
let firstThread :: TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
firstThread TV Any
acc VName
prefixes =
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc) [] (VName -> SubExp
Var VName
prefixes) [TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
-- 'notFirstThread' sets the scalar @acc@ from the element of @prefixes@ at
-- index @kernelLocalThreadId constants - 1@, i.e. the slot just before this
-- thread's own. The index would be -1 for local thread 0, which is why that
-- thread is routed to 'firstThread' instead.
notFirstThread :: TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
notFirstThread TV Any
acc VName
prefixes =
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc) [] (VName -> SubExp
Var VName
prefixes) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0)
((TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ())
-> [TV Any] -> [VName] -> ImpM GPUMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
firstThread [TV Any]
accs [VName]
prefixArrays)
((TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ())
-> [TV Any] -> [VName] -> ImpM GPUMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
notFirstThread [TV Any]
accs [VName]
prefixArrays)
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
[TV Any]
prefixes <- [(SubExp, PrimType)]
-> ((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> ImpM GPUMem KernelEnv KernelOp [TV Any]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([SubExp] -> [PrimType] -> [(SubExp, PrimType)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanOpNe [PrimType]
tys) (((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> ImpM GPUMem KernelEnv KernelOp [TV Any])
-> ((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> ImpM GPUMem KernelEnv KernelOp [TV Any]
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, PrimType
ty) ->
String -> TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"prefix" (TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall a b. (a -> b) -> a -> b
$ Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Any) -> Exp -> TExp Any
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
TPrimExp Bool ExpLeaf
blockNewSgm <- String
-> TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool ExpLeaf)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"block_new_sgm" (TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool ExpLeaf))
-> TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool ExpLeaf)
forall a b. (a -> b) -> a -> b
$ TExp Int64
sgmIdx TExp Int64 -> TExp Int64 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Perform lookback" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Bool ExpLeaf
blockNewSgm TPrimExp Bool ExpLeaf
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays [TV Any]
accs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
incprefixArray, TV Any
acc) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
acc) []
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
globalFence
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
statusFlags [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusP) []
[(SubExp, TV Any)]
-> ((SubExp, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SubExp] -> [TV Any] -> [(SubExp, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanOpNe [TV Any]
accs) (((SubExp, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((SubExp, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, TV Any
acc) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc) [] SubExp
ne []
let warpSize :: TExp Int32
warpSize = KernelConstants -> TExp Int32
kernelWaveSize KernelConstants
constants
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TPrimExp Bool ExpLeaf
blockNewSgm TPrimExp Bool ExpLeaf
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
warpSize) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TPrimExp Bool ExpLeaf
not_segmented_e TPrimExp Bool ExpLeaf
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. TExp Int32
boundary TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
forall a. Num a => a
m))
( do
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
aggregateArrays [TV Any]
accs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
aggregateArray, TV Any
acc) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
aggregateArray [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
acc) []
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
globalFence
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
statusFlags [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusA) []
)
( do
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays [TV Any]
accs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
incprefixArray, TV Any
acc) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
acc) []
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
globalFence
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
statusFlags [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusP) []
)
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
warpscan [TExp Int64
0] (VName -> SubExp
Var VName
statusFlags) [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localFence
TV Int8
status <- String -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"status" PrimType
int8 :: InKernelGen (TV Int8)
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
status) [] (VName -> SubExp
Var VName
warpscan) [TExp Int64
0]
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
status TExp Int8 -> TExp Int8 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int8
forall a. Num a => a
statusP)
( TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(TV Any, VName)]
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [VName] -> [(TV Any, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
prefixes [VName]
incprefixArrays) (((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
prefix, VName
incprefixArray) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
prefix) [] (VName -> SubExp
Var VName
incprefixArray) [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
)
( do
TV Int32
readOffset <-
String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"readOffset" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelWaveSize KernelConstants
constants)
-- 'loopStop' is the negated warp size (@warpSize * (-1)@). The 'sWhile'
-- below keeps iterating only while @readOffset@ is strictly greater than it.
let loopStop :: TExp Int32
loopStop = TExp Int32
warpSize TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* (-TExp Int32
1)
-- 'sameSegment' builds the predicate that guards the read of the status
-- flag at index @readIdx@ in the loop below.
--
-- When @segmented@ holds it computes
-- @startIdx = sExt64 (readIdx + 1) * kernelGroupSize constants * m - 1@
-- and yields @blockOff - startIdx <= sgmIdx@. Otherwise it is the constant
-- @true@.
-- NOTE(review): @blockOff@, @sgmIdx@ and @m@ are bound outside this view;
-- presumably block offset, index within the segment and elements per
-- thread. Confirm against their definitions.
sameSegment :: TV Int32 -> TPrimExp Bool ExpLeaf
sameSegment TV Int32
readIdx
| Bool
segmented =
let startIdx :: TExp Int64
startIdx = TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readIdx TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
forall a. Num a => a
m TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1
in TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
blockOff TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
startIdx TExp Int64 -> TExp Int64 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
sgmIdx
| Bool
otherwise = TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v
true
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhile (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readOffset TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int32
loopStop) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TV Int32
readI <- String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"read_i" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readOffset TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
[TV Any]
aggrs <- [(SubExp, PrimType)]
-> ((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> ImpM GPUMem KernelEnv KernelOp [TV Any]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([SubExp] -> [PrimType] -> [(SubExp, PrimType)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanOpNe [PrimType]
tys) (((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> ImpM GPUMem KernelEnv KernelOp [TV Any])
-> ((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> ImpM GPUMem KernelEnv KernelOp [TV Any]
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, PrimType
ty) ->
String -> TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"aggr" (TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall a b. (a -> b) -> a -> b
$ Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Any) -> Exp -> TExp Any
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
TV Int8
flag <- String -> TExp Int8 -> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"flag" TExp Int8
forall a. Num a => a
statusX
TV Int8
used <- String -> TExp Int8 -> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"used" TExp Int8
0
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readI TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>=. TExp Int32
0) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TV Int32 -> TPrimExp Bool ExpLeaf
sameSegment TV Int32
readI)
( do
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
flag) [] (VName -> SubExp
Var VName
statusFlags) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readI]
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flag TExp Int8 -> TExp Int8 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int8
forall a. Num a => a
statusP)
( [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays [TV Any]
aggrs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
incprefix, TV Any
aggr) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
incprefix) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readI]
)
( TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flag TExp Int8 -> TExp Int8 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int8
forall a. Num a => a
statusA) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
[(TV Any, VName)]
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [VName] -> [(TV Any, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
aggrs [VName]
aggregateArrays) (((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
aggr, VName
aggregate) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
aggregate) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readI]
TV Int8
used TV Int8 -> TExp Int8 -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TExp Int8
1
)
)
(VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
flag) [] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusP) [])
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
aggrs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
aggr) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
exchange [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
aggr) []
TV Int8
tmp <- String -> TExp Int8 -> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"tmp" (TExp Int8 -> ImpM GPUMem KernelEnv KernelOp (TV Int8))
-> TExp Int8 -> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall a b. (a -> b) -> a -> b
$ TV Int8 -> TV Int8 -> TExp Int8
forall t. NumExp t => TV t -> TV t -> TPrimExp t ExpLeaf
makeStatusUsed TV Int8
flag TV Int8
used
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
warpscan [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants] (TV Int8 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int8
tmp) []
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localFence
(VName
warpscanMem, Space
warpscanSpace, Count Elements (TExp Int64)
warpscanOff) <-
VName
-> [TExp Int64]
-> ImpM
GPUMem
KernelEnv
KernelOp
(VName, Space, Count Elements (TExp Int64))
forall rep r op.
VName
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray VName
warpscan [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
warpSize TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
TV Int8
flag TV Int8 -> TExp Int8 -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- Exp -> TExp Int8
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
Imp.index VName
warpscanMem Count Elements (TExp Int64)
warpscanOff PrimType
int8 Space
warpscanSpace Volatility
Imp.Volatile)
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot (TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf)
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall a b. (a -> b) -> a -> b
$ TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flag TExp Int8 -> TExp Int8 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int8
forall a. Num a => a
statusP)
( do
LambdaT GPUMem
scanOp'' <- LambdaT GPUMem -> ImpM GPUMem KernelEnv KernelOp (LambdaT GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda LambdaT GPUMem
scanOp'
let ([VName]
agg1s, [VName]
agg2s) = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([VName] -> ([VName], [VName])) -> [VName] -> ([VName], [VName])
forall a b. (a -> b) -> a -> b
$ (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams LambdaT GPUMem
scanOp''
[(VName, SubExp, PrimType)]
-> ((VName, SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [PrimType] -> [(VName, SubExp, PrimType)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
agg1s [SubExp]
scanOpNe [PrimType]
tys) (((VName, SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
agg1, SubExp
ne, PrimType
ty) ->
VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
agg1 (TExp Any -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Any) -> Exp -> TExp Any
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
(VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ())
-> [VName] -> [PrimType] -> ImpM GPUMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ [VName]
agg2s [PrimType]
tys
TV Int8
flag1 <- String -> TExp Int8 -> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"flag1" TExp Int8
forall a. Num a => a
statusX
TV Int8
flag2 <- String -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"flag2" PrimType
int8
TV Int8
used1 <- String -> TExp Int8 -> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"used1" TExp Int8
0
TV Int8
used2 <- String -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"used2" PrimType
int8
String
-> TExp Int32
-> (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int32
warpSize ((TExp Int32 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
flag2) [] (VName -> SubExp
Var VName
warpscan) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
i]
TV Int8 -> TV Int8 -> TV Int8 -> ImpM GPUMem KernelEnv KernelOp ()
unmakeStatusUsed TV Int8
flag2 TV Int8
flag2 TV Int8
used2
[(VName, VName)]
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
agg2s [VName]
exchanges) (((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
agg2, VName
exchange) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
agg2 [] (VName -> SubExp
Var VName
exchange) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
i]
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot (TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf)
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall a b. (a -> b) -> a -> b
$ TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flag2 TExp Int8 -> TExp Int8 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int8
forall a. Num a => a
statusA)
( do
TV Int8
flag1 TV Int8 -> TExp Int8 -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flag2
TV Int8
used1 TV Int8 -> TExp Int8 -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
used2
[(VName, PrimType, VName)]
-> ((VName, PrimType, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [PrimType] -> [VName] -> [(VName, PrimType, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
agg1s [PrimType]
tys [VName]
agg2s) (((VName, PrimType, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, PrimType, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
agg1, PrimType
ty, VName
agg2) ->
VName
agg1 VName -> Exp -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty (VName -> SubExp
Var VName
agg2)
)
( do
TV Int8
used1 TV Int8 -> TExp Int8 -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
used1 TExp Int8 -> TExp Int8 -> TExp Int8
forall a. Num a => a -> a -> a
+ TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
used2
Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT GPUMem -> Stms GPUMem
forall rep. BodyT rep -> Stms rep
bodyStms (BodyT GPUMem -> Stms GPUMem) -> BodyT GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp'') (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(VName, PrimType, SubExp)]
-> ((VName, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [PrimType] -> [SubExp] -> [(VName, PrimType, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
agg1s [PrimType]
tys ([SubExp] -> [(VName, PrimType, SubExp)])
-> [SubExp] -> [(VName, PrimType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp ([SubExpRes] -> [SubExp]) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ BodyT GPUMem -> [SubExpRes]
forall rep. BodyT rep -> [SubExpRes]
bodyResult (BodyT GPUMem -> [SubExpRes]) -> BodyT GPUMem -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp'') (((VName, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
\(VName
agg1, PrimType
ty, SubExp
res) -> VName
agg1 VName -> Exp -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
res
)
TV Int8
flag TV Int8 -> TExp Int8 -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flag1
TV Int8
used TV Int8 -> TExp Int8 -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
used1
[(TV Any, PrimType, VName)]
-> ((TV Any, PrimType, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [PrimType] -> [VName] -> [(TV Any, PrimType, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TV Any]
aggrs [PrimType]
tys [VName]
agg1s) (((TV Any, PrimType, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, PrimType, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
aggr, PrimType
ty, VName
agg1) ->
TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
aggr VName -> Exp -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty (VName -> SubExp
Var VName
agg1)
)
( [(TV Any, VName)]
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [VName] -> [(TV Any, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
aggrs [VName]
exchanges) (((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
aggr, VName
exchange) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
exchange) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
warpSize TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
)
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flag TExp Int8 -> TExp Int8 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int8
forall a. Num a => a
statusP)
(TV Int32
readOffset TV Int32 -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TExp Int32
loopStop)
(TV Int32
readOffset TV Int32 -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readOffset TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int8 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
zExt32 (TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
used))
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
sharedReadOffset [TExp Int64
0] (TV Int32 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int32
readOffset) []
LambdaT GPUMem
scanOp''' <- LambdaT GPUMem -> ImpM GPUMem KernelEnv KernelOp (LambdaT GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda LambdaT GPUMem
scanOp'
let ([VName]
xs, [VName]
ys) = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([VName] -> ([VName], [VName])) -> [VName] -> ([VName], [VName])
forall a b. (a -> b) -> a -> b
$ (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams LambdaT GPUMem
scanOp'''
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [TV Any]
aggrs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
x, TV Any
aggr) -> VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x (TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
aggr)
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys [TV Any]
prefixes) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
y, TV Any
prefix) -> VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y (TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
prefix)
Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT GPUMem -> Stms GPUMem
forall rep. BodyT rep -> Stms rep
bodyStms (BodyT GPUMem -> Stms GPUMem) -> BodyT GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp''') (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(TV Any, PrimType, SubExp)]
-> ((TV Any, PrimType, SubExp)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [PrimType] -> [SubExp] -> [(TV Any, PrimType, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TV Any]
prefixes [PrimType]
tys ([SubExp] -> [(TV Any, PrimType, SubExp)])
-> [SubExp] -> [(TV Any, PrimType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp ([SubExpRes] -> [SubExp]) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ BodyT GPUMem -> [SubExpRes]
forall rep. BodyT rep -> [SubExpRes]
bodyResult (BodyT GPUMem -> [SubExpRes]) -> BodyT GPUMem -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp''') (((TV Any, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, PrimType, SubExp)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
\(TV Any
prefix, PrimType
ty, SubExp
res) -> TV Any
prefix TV Any -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
res)
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localFence
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV Int32 -> VName
forall t. TV t -> VName
tvVar TV Int32
readOffset) [] (VName -> SubExp
Var VName
sharedReadOffset) [TExp Int64
0]
)
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
LambdaT GPUMem
scanOp'''' <- LambdaT GPUMem -> ImpM GPUMem KernelEnv KernelOp (LambdaT GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda LambdaT GPUMem
scanOp'
let xs :: [VName]
xs = (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ Int -> [Param LParamMem] -> [Param LParamMem]
forall a. Int -> [a] -> [a]
take ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([Param LParamMem] -> [Param LParamMem])
-> [Param LParamMem] -> [Param LParamMem]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams LambdaT GPUMem
scanOp''''
ys :: [VName]
ys = (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ Int -> [Param LParamMem] -> [Param LParamMem]
forall a. Int -> [a] -> [a]
drop ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([Param LParamMem] -> [Param LParamMem])
-> [Param LParamMem] -> [Param LParamMem]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams LambdaT GPUMem
scanOp''''
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
boundary TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
forall a. Num a => a
m)) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [TV Any]
prefixes) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
x, TV Any
prefix) -> VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x (TExp Any -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
prefix
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys [TV Any]
accs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
y, TV Any
acc) -> VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y (TExp Any -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
acc
Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT GPUMem -> Stms GPUMem
forall rep. BodyT rep -> Stms rep
bodyStms (BodyT GPUMem -> Stms GPUMem) -> BodyT GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp'''') (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp ([SubExpRes] -> [SubExp]) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ BodyT GPUMem -> [SubExpRes]
forall rep. BodyT rep -> [SubExpRes]
bodyResult (BodyT GPUMem -> [SubExpRes]) -> BodyT GPUMem -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp'''') (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
\(VName
incprefixArray, SubExp
res) -> VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] SubExp
res []
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
globalFence
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
statusFlags [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusP) []
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
prefixes) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
prefix) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
exchange [TExp Int64
0] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
prefix) []
[(TV Any, PrimType, SubExp)]
-> ((TV Any, PrimType, SubExp)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [PrimType] -> [SubExp] -> [(TV Any, PrimType, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TV Any]
accs [PrimType]
tys [SubExp]
scanOpNe) (((TV Any, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, PrimType, SubExp)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
acc, PrimType
ty, SubExp
ne) ->
TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc VName -> Exp -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot (TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf)
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TExp Int64 -> TExp Int64 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
prefixes) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
prefix) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
prefix) [] (VName -> SubExp
Var VName
exchange) [TExp Int64
0]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
LambdaT GPUMem
scanOp''''' <- LambdaT GPUMem -> ImpM GPUMem KernelEnv KernelOp (LambdaT GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda LambdaT GPUMem
scanOp'
LambdaT GPUMem
scanOp'''''' <- LambdaT GPUMem -> ImpM GPUMem KernelEnv KernelOp (LambdaT GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda LambdaT GPUMem
scanOp'
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Distribute results" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
let ([VName]
xs, [VName]
ys) = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([VName] -> ([VName], [VName])) -> [VName] -> ([VName], [VName])
forall a b. (a -> b) -> a -> b
$ (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams LambdaT GPUMem
scanOp'''''
([VName]
xs', [VName]
ys') = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([VName] -> ([VName], [VName])) -> [VName] -> ([VName], [VName])
forall a b. (a -> b) -> a -> b
$ (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams LambdaT GPUMem
scanOp''''''
[((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)]
-> (((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(TV Any, TV Any)]
-> [(VName, VName)]
-> [(VName, VName)]
-> [PrimType]
-> [((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 ([TV Any] -> [TV Any] -> [(TV Any, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
prefixes [TV Any]
accs) ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [VName]
xs') ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys [VName]
ys') [PrimType]
tys) ((((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
\((TV Any
prefix, TV Any
acc), (VName
x, VName
x'), (VName
y, VName
y'), PrimType
ty) -> do
VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
x PrimType
ty
VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
y PrimType
ty
VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x' (TExp Any -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
prefix
VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y' (TExp Any -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
acc
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
forall a. Num a => a
m TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
boundary TPrimExp Bool ExpLeaf
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TPrimExp Bool ExpLeaf
blockNewSgm)
( Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT GPUMem -> Stms GPUMem
forall rep. BodyT rep -> Stms rep
bodyStms (BodyT GPUMem -> Stms GPUMem) -> BodyT GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp'''''') (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(VName, PrimType, SubExp)]
-> ((VName, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [PrimType] -> [SubExp] -> [(VName, PrimType, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
xs [PrimType]
tys ([SubExp] -> [(VName, PrimType, SubExp)])
-> [SubExp] -> [(VName, PrimType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp ([SubExpRes] -> [SubExp]) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ BodyT GPUMem -> [SubExpRes]
forall rep. BodyT rep -> [SubExpRes]
bodyResult (BodyT GPUMem -> [SubExpRes]) -> BodyT GPUMem -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp'''''') (((VName, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
\(VName
x, PrimType
ty, SubExp
res) -> VName
x VName -> Exp -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
res
)
([(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [TV Any]
accs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
x, TV Any
acc) -> VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
x [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc) [])
TExp Int32
stop <-
String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"stopping_point" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32
segsize_compact TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
forall a. Num a => a
m TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
segsize_compact TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
boundary) TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int32
segsize_compact
String
-> TExp Int64
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
forall a. Num a => a
m ((TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
i TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
stop TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
[(VName, VName)]
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays [VName]
ys) (((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
src, VName
y) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
y [] (VName -> SubExp
Var VName
src) [TExp Int64
i]
Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT GPUMem -> Stms GPUMem
forall rep. BodyT rep -> Stms rep
bodyStms (BodyT GPUMem -> Stms GPUMem) -> BodyT GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp''''') (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp ([SubExpRes] -> [SubExp]) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ BodyT GPUMem -> [SubExpRes]
forall rep. BodyT rep -> [SubExpRes]
bodyResult (BodyT GPUMem -> [SubExpRes]) -> BodyT GPUMem -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp''''') (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
\(VName
dest, SubExp
res) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
res []
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Transpose scan output" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
[(VName, VName)]
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
transposedArrays [VName]
privateArrays) (((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
trans, VName
priv) -> do
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
String
-> TExp Int64
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
forall a. Num a => a
m ((TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
TV Int64
sharedIdx <-
String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"sharedIdx" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
forall a. Num a => a
m) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
trans [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
sharedIdx] (VName -> SubExp
Var VName
priv) [TExp Int64
i]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
String
-> TExp Int64
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
forall a. Num a => a
m ((TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
TV Int32
sharedIdx <-
String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"sharedIdx" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall a b. (a -> b) -> a -> b
$
KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
i)
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
priv [TExp Int64
i] (VName -> SubExp
Var VName
trans) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
sharedIdx]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Write block scan results to global memory" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
String
-> TExp Int64
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
forall a. Num a => a
m ((TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
TExp Int64
flat_idx <-
String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"flat_idx" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
blockOff TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
i
TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants)
[(VName, TExp Int64)]
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
[(VName, TExp Int64)] -> TExp Int64 -> ImpM rep r op ()
dIndexSpace ([VName] -> [TExp Int64] -> [(VName, TExp Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [TExp Int64]
dims') TExp Int64
flat_idx
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int64
flat_idx TExp Int64 -> TExp Int64 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
n) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
[(VName, VName)]
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip ((PatElemT LParamMem -> VName) -> [PatElemT LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName [PatElemT LParamMem]
all_pes) [VName]
privateArrays) (((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, VName
src) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
gtids) (VName -> SubExp
Var VName
src) [TExp Int64
i]
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"If this is the last block, reset the dynamicId" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TExp Int64 -> TExp Int64 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
globalId [TExp Int64
0] (Int32 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int32
0 :: Int32)) []