{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.LLVM.PTX.CodeGen.FoldSeg
where
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Elt
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Loop as Loop
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Base
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Fold ( reduceBlockSMem, reduceWarpSMem, imapFromTo )
import Data.Array.Accelerate.LLVM.PTX.Target
import LLVM.AST.Type.Representation
import qualified Foreign.CUDA.Analysis as CUDA
import Control.Monad ( void )
import Control.Monad.State ( gets )
import Data.String ( fromString )
import Prelude as P
-- | Code generator for a segmented fold on the PTX target.
--
-- Two kernels are generated from exactly the same arguments — one by
-- 'mkFoldSegP_block' and one by 'mkFoldSegP_warp' — and the two results are
-- joined with '+++' into a single 'IROpenAcc'.
--
-- Arguments, in order (all are passed through unchanged to both generators):
--
--   * the array environment
--   * the representation of the array being produced
--   * the integral type of the elements of the segments array
--   * the binary combination function
--   * an optional seed element ('Nothing' when there is none)
--   * the (possibly delayed) input array
--   * the (possibly delayed) segments array
--
mkFoldSeg
    :: forall aenv sh i e.
       Gamma aenv
    -> ArrayR (Array (sh, Int) e)
    -> IntegralType i
    -> IRFun2 PTX aenv (e -> e -> e)
    -> Maybe (IRExp PTX aenv e)
    -> MIRDelayed PTX aenv (Array (sh, Int) e)
    -> MIRDelayed PTX aenv (Segments i)
    -> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e))
mkFoldSeg aenv repr intTp combine seed arr seg =
  (+++) <$> mkFoldSegP_block aenv repr intTp combine seed arr seg
        <*> mkFoldSegP_warp  aenv repr intTp combine seed arr seg
-- | Generate the @foldSeg_block@ kernel of the segmented fold.
--
-- For every linear index @s@ of the output array the generated code:
--
--   1. has the threads with index 0 and 1 read two consecutive entries of the
--      segments array and publish them through a two-element array obtained
--      from 'staticSharedMem';
--
--   2. after '__syncthreads', has every thread read both entries back and
--      derive the half-open range @[inf, sup)@ of linear input indices to
--      reduce;
--
--   3. if @inf == sup@, lets thread 0 write the seed (when there is one) to
--      output index @s@; otherwise reduces the range in strides of 'blockDim'
--      elements using 'reduceBlockSMem', after which thread 0 writes the
--      result, combined with the seed on the left when there is one.
--
-- Arguments, in order: array environment, output array representation,
-- integral type of the segment entries, combination function, optional seed,
-- (possibly delayed) input array, (possibly delayed) segments array.
--
-- NOTE(review): the segments array is read as pairs of adjacent entries
-- (start, end), i.e. it looks like an array of segment /offsets/ rather than
-- lengths — confirm against the caller.
--
mkFoldSegP_block
    :: forall aenv sh i e.
       Gamma aenv
    -> ArrayR (Array (sh, Int) e)
    -> IntegralType i
    -> IRFun2 PTX aenv (e -> e -> e)
    -> MIRExp PTX aenv e
    -> MIRDelayed PTX aenv (Array (sh, Int) e)
    -> MIRDelayed PTX aenv (Segments i)
    -> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e))
mkFoldSegP_block aenv repr@(ArrayR shr tp) intTp combine mseed marr mseg = do
  dev <- liftCodeGen $ gets ptxDeviceProperties
  --
  let
      (arrOut, paramOut)  = mutableArray repr "out"
      (arrIn,  paramIn)   = delayedArray "in"  marr
      (arrSeg, paramSeg)  = delayedArray "seg" mseg
      paramEnv            = envParam aenv
      --
      config              = launchConfig dev (CUDA.decWarp dev) dsmem const [|| const ||]

      -- Handed to 'launchConfig' as a function of a thread count @n@.
      -- Presumably the number of bytes of dynamic shared memory the block
      -- reduction needs for @n@ threads — confirm against 'launchConfig' and
      -- 'reduceBlockSMem'.
      dsmem n             = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = bytesElt tp
  --
  makeOpenAccWith config "foldSeg_block" (paramOut ++ paramIn ++ paramSeg ++ paramEnv) $ do

    -- Two-element 'Int' array used to hand the start and end entry of the
    -- current segment from threads 0 and 1 to every other thread.
    smem <- staticSharedMem (TupRsingle scalarTypeInt) 2

    -- 'sz' is the head of the input extent; 'ss' is the head of the segments
    -- array extent minus one. Both are only used below when the array has
    -- rank two or higher.
    sz <- indexHead <$> delayedExtent arrIn
    ss <- do n <- indexHead <$> delayedExtent arrSeg
             A.sub numType n (liftInt 1)

    -- One iteration per element of the output array. How 'imapFromTo'
    -- distributes the range [start, end) is defined in the Fold module and is
    -- not visible here.
    let start = liftInt 0
    end <- shapeSize shr (irArrayShape arrOut)

    imapFromTo start end $ \s -> do

      -- Threads 0 and 1 fetch segment entries @i@ and @i + 1@ respectively,
      -- convert them to 'Int' and store them in slot @tid@ of 'smem'.
      tid <- threadIdx
      when (A.lt singleType tid (liftInt32 2)) $ do
        i <- case shr of
               ShapeRsnoc ShapeRz -> return s
               _                  -> A.rem integralType s ss
        j <- A.add numType i =<< int tid
        v <- app1 (delayedLinearIndex arrSeg) j
        writeArray TypeInt32 smem tid =<< A.fromIntegral intTp numType v

      -- Make the two entries visible before anybody reads them back.
      __syncthreads

      u <- readArray TypeInt32 smem (liftInt32 0)
      v <- readArray TypeInt32 smem (liftInt32 1)

      -- Range of linear input indices to reduce. For a rank-1 input this is
      -- just (u, v); otherwise both bounds are shifted by
      -- @(s `quot` ss) * sz@.
      (inf, sup) <- A.unpair <$> case shr of
                      ShapeRsnoc ShapeRz -> return (A.pair u v)
                      _ -> do q <- A.quot integralType s ss
                              a <- A.mul numType q sz
                              A.pair <$> A.add numType u a <*> A.add numType v a

      void $
        if (TupRunit, A.eq singleType inf sup)
          -- Empty segment: if a seed was supplied, thread 0 stores it as the
          -- result for this segment; without a seed nothing is written.
          then do
            case mseed of
              Nothing -> return ()
              Just z  -> when (A.eq singleType tid (liftInt32 0)) $
                           writeArray TypeInt arrOut s =<< z
            return (lift TupRunit ())

          -- Non-empty segment.
          else do
            let
                -- Read the input at linear index @ix@ if it lies below 'sup',
                -- otherwise produce an all-'undef' value of the element type.
                loadOrUndef ix =
                  if (tp, A.lt singleType ix sup)
                    then app1 (delayedLinearIndex arrIn) ix
                    else return (undefOperands tp)

            -- Step 1: the first stripe of 'blockDim' elements.
            --
            -- Every thread takes part in the reduction, including those that
            -- have no valid element and contribute 'undef'. Presumably this
            -- keeps all threads convergent on the synchronisation points
            -- inside 'reduceBlockSMem' — confirm before restructuring.
            i0 <- A.add numType inf =<< int tid
            x0 <- loadOrUndef i0

            bd  <- int =<< blockDim
            v0  <- A.sub numType sup inf
            v0' <- i32 v0

            -- 'Nothing' when the segment holds at least 'blockDim' elements,
            -- otherwise 'Just' the number of elements in the segment.
            r0  <- if (tp, A.gte singleType v0 bd)
                     then reduceBlockSMem dev tp combine Nothing    x0
                     else reduceBlockSMem dev tp combine (Just v0') x0

            -- Step 2: the remaining stripes, from @inf + bd@ in steps of @bd@
            -- up to 'sup', threading the running result through.
            nxt <- A.add numType inf bd
            r   <- iterFromStepTo tp nxt bd sup r0 $ \offset acc -> do

                     __syncthreads

                     i' <- A.add numType offset =<< int tid
                     v' <- A.sub numType sup offset
                     r' <- if (tp, A.gte singleType v' bd)
                             -- At least 'blockDim' elements remain, so every
                             -- thread's index is below 'sup' and no bounds
                             -- check is needed.
                             then reduceBlockSMem dev tp combine Nothing
                                    =<< app1 (delayedLinearIndex arrIn) i'

                             -- Partial stripe: out-of-range threads supply
                             -- 'undef' and the number of remaining elements
                             -- is passed to the reduction.
                             else do
                               x <- loadOrUndef i'
                               z <- i32 v'
                               reduceBlockSMem dev tp combine (Just z) x

                     -- Only thread 0 folds the previous running result into
                     -- the result of this stripe.
                     if (tp, A.eq singleType tid (liftInt32 0))
                       then app2 combine acc r'
                       else return r'

            -- Step 3: thread 0 writes the result for this segment. When a
            -- seed is present it is the left argument of 'combine'.
            when (A.eq singleType tid (liftInt32 0)) $
              writeArray TypeInt arrOut s =<<
                case mseed of
                  Nothing -> return r
                  Just z  -> flip (app2 combine) r =<< z

            return (lift TupRunit ())

    return_
  where
    -- Build an 'Operands' value of the given type in which every scalar
    -- component is 'undef'.
    undefOperands :: TypeR a -> Operands a
    undefOperands TupRunit       = OP_Unit
    undefOperands (TupRpair a b) = OP_Pair (undefOperands a) (undefOperands b)
    undefOperands (TupRsingle t) = ir t (undef t)
-- | Generate the @foldSeg_warp@ kernel: a segmented reduction along the
-- innermost dimension in which every warp reduces whole segments on its own.
--
-- Arguments, in order: the array environment, the representation of the
-- result array, the integral type of the segment-offset elements, the binary
-- combination function, an optional seed expression ('Nothing' for the
-- seedless variant), the delayed input array and the delayed segment-offset
-- array.
--
-- Per segment handled by a warp:
--
--   1. lanes 0 and 1 read two adjacent entries of the segment array and
--      publish them (converted to 'Int') through the two-element array @lim@;
--
--   2. all lanes read those two values back as the range @[inf, sup)@ of
--      linear indices into the input (shifted by @(s `quot` ss) * sz@ for
--      anything other than a rank-1 array);
--
--   3. if the range is empty, lane 0 writes the seed when there is one;
--      otherwise the warp reduces the range in strides of @ws@ elements with
--      'reduceWarpSMem' and lane 0 writes the result, combined with the seed
--      when present.
--
mkFoldSegP_warp
    :: forall aenv sh i e.
       Gamma aenv
    -> ArrayR (Array (sh, Int) e)
    -> IntegralType i
    -> IRFun2 PTX aenv (e -> e -> e)
    -> MIRExp PTX aenv e
    -> MIRDelayed PTX aenv (Array (sh, Int) e)
    -> MIRDelayed PTX aenv (Segments i)
    -> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e))
mkFoldSegP_warp aenv repr@(ArrayR shr tp) intTp combine mseed marr mseg = do
  dev <- liftCodeGen $ gets ptxDeviceProperties
  --
  let
      (arrOut, paramOut)  = mutableArray repr "out"
      (arrIn,  paramIn)   = delayedArray "in"  marr
      (arrSeg, paramSeg)  = delayedArray "seg" mseg
      paramEnv            = envParam aenv
      --
      config              = launchConfig dev (CUDA.decWarp dev) dsmem grid gridQ

      -- Dynamic shared memory (in bytes) for a block of @n@ threads: one
      -- @per_warp_bytes@ region for every warp, rounding the warp count up.
      dsmem n             = warps * per_warp_bytes
        where
          warps = (n + ws - 1) `P.quot` ws
      --
      grid n m            = multipleOf n (m `P.quot` ws)
      gridQ               = [|| \n m -> $$multipleOfQ n (m `P.quot` ws) ||]
      --
      -- NOTE(review): for any ws >= 2 the first operand of 'P.max' is never
      -- smaller than the second, so the 'P.max' looks redundant. It may have
      -- been meant to guarantee room for the two 'Int' bounds held in @lim@
      -- below -- confirm before changing.
      per_warp_bytes      = (per_warp_elems * bytesElt tp) `P.max` (2 * bytesElt tp)
      per_warp_elems      = ws + (ws `P.quot` 2)
      ws                  = CUDA.warpSize dev

      int32 :: Integral a => a -> Operands Int32
      int32 = liftInt32 . P.fromIntegral

      -- Operands of the given type whose scalar components are all 'undef'.
      -- Used as the value contributed by lanes that fall past the end of the
      -- segment.
      undefOperands :: TypeR a -> Operands a
      undefOperands TupRunit       = OP_Unit
      undefOperands (TupRpair a b) = OP_Pair (undefOperands a) (undefOperands b)
      undefOperands (TupRsingle t) = ir t (undef t)
  --
  makeOpenAccWith config "foldSeg_warp" (paramOut ++ paramIn ++ paramSeg ++ paramEnv) $ do

    -- Index of this warp within its thread block
    tid   <- threadIdx
    wid   <- A.quot integralType tid (int32 ws)

    -- Number of warps per thread block
    bd    <- blockDim
    wpb   <- A.quot integralType bd (int32 ws)

    -- Index of this warp within the whole grid: bid * wpb + wid
    bid   <- blockIdx
    gwid  <- A.add numType wid =<< A.mul numType bid wpb

    -- Two-element 'Int' array in dynamic shared memory through which the
    -- first two lanes publish the bounds of the current segment. It is
    -- requested at offset (wid * per_warp_bytes), the same offset used for
    -- 'smem' below, so the two arrays presumably alias -- which is why each
    -- use is fenced with __syncwarp.
    lim   <- do
      off <- A.mul numType wid (int32 per_warp_bytes)
      dynamicSharedMem (TupRsingle scalarTypeInt) TypeInt32 (liftInt32 2) off

    -- Per-warp scratch array of per_warp_elems (1.5 * ws) elements handed to
    -- 'reduceWarpSMem'.
    smem  <- do
      off <- A.mul numType wid (int32 per_warp_bytes)
      dynamicSharedMem tp TypeInt32 (int32 per_warp_elems) off

    -- Extent of the innermost dimension of the input, and one less than the
    -- length of the segment array (consecutive pairs of entries delimit a
    -- segment, so this is the number of segments per innermost row).
    sz    <- indexHead <$> delayedExtent arrIn
    ss    <- do
      n <- indexHead <$> delayedExtent arrSeg
      A.sub numType n (liftInt 1)

    -- Warp 'gwid' handles output elements gwid, gwid + step, ... below 'end',
    -- where step = (warps per block) * gridDim.
    s0    <- int gwid
    gd    <- int =<< gridDim
    wpb'  <- int wpb
    step  <- A.mul numType wpb' gd
    end   <- shapeSize shr (irArrayShape arrOut)

    imapFromStepTo s0 step end $ \s -> do

      __syncwarp

      -- Lanes 0 and 1 fetch the two adjacent segment-array entries for this
      -- segment and store them, converted to Int, at lim[lane].
      lane <- laneId
      when (A.lt singleType lane (liftInt32 2)) $ do
        a <- case shr of
               ShapeRsnoc ShapeRz -> return s
               _                  -> A.rem integralType s ss
        b <- A.add numType a =<< int lane
        c <- app1 (delayedLinearIndex arrSeg) b
        writeArray TypeInt32 lim lane =<< A.fromIntegral intTp numType c

      __syncwarp

      -- Range [inf, sup) of linear input indices to reduce. For arrays of
      -- rank > 1 both bounds are shifted by (s `quot` ss) * sz.
      (inf, sup) <- do
        u <- readArray TypeInt32 lim (liftInt32 0)
        v <- readArray TypeInt32 lim (liftInt32 1)
        A.unpair <$> case shr of
                       ShapeRsnoc ShapeRz -> return (A.pair u v)
                       _                  -> do
                         q <- A.quot integralType s ss
                         a <- A.mul numType q sz
                         A.pair <$> A.add numType u a <*> A.add numType v a

      __syncwarp

      void $
        if (TupRunit, A.eq singleType inf sup)
          -- Empty segment: the result is the seed if there is one; without a
          -- seed nothing is written.
          then do
            case mseed of
              Nothing -> return (lift TupRunit ())
              Just z  -> do
                when (A.eq singleType lane (liftInt32 0)) $
                  writeArray TypeInt arrOut s =<< z
                return (lift TupRunit ())

          -- Non-empty segment
          else do
            -- First stride: each lane loads element (inf + lane) if it lies
            -- inside the segment, and contributes undef operands otherwise.
            i0  <- A.add numType inf =<< int lane
            x0  <- if (tp, A.lt singleType i0 sup)
                     then app1 (delayedLinearIndex arrIn) i0
                     else return (undefOperands tp)

            -- When at least ws elements remain no bound is passed to the
            -- reduction ('Nothing'); otherwise the number of remaining
            -- elements is passed.
            v0  <- A.sub numType sup inf
            v0' <- i32 v0
            r0  <- if (tp, A.gte singleType v0 (liftInt ws))
                     then reduceWarpSMem dev tp combine smem Nothing    x0
                     else reduceWarpSMem dev tp combine smem (Just v0') x0

            -- Remaining strides, starting at (inf + ws) and advancing by ws
            -- until sup, threading the running result through as 'acc'.
            nx  <- A.add numType inf (liftInt ws)
            r   <- iterFromStepTo tp nx (liftInt ws) sup r0 $ \offset acc -> do

                     -- NOTE(review): this is a block-wide barrier inside a
                     -- loop whose trip count depends on the segment handled
                     -- by this warp; confirm that this is intended rather
                     -- than __syncwarp.
                     __syncthreads

                     i' <- A.add numType offset =<< int lane
                     v' <- A.sub numType sup offset
                     r' <- if (tp, A.gte singleType v' (liftInt ws))
                             then do
                               x <- app1 (delayedLinearIndex arrIn) i'
                               reduceWarpSMem dev tp combine smem Nothing x
                             else do
                               x <- if (tp, A.lt singleType i' sup)
                                      then app1 (delayedLinearIndex arrIn) i'
                                      else return (undefOperands tp)
                               z <- i32 v'
                               reduceWarpSMem dev tp combine smem (Just z) x

                     -- Only lane 0 folds the new partial result into the
                     -- accumulator; other lanes just carry r' forward.
                     if (tp, A.eq singleType lane (liftInt32 0))
                       then app2 combine acc r'
                       else return r'

            -- Lane 0 writes the result for this segment, applying
            -- (combine seed result) when a seed is present.
            when (A.eq singleType lane (liftInt32 0)) $
              writeArray TypeInt arrOut s =<<
                case mseed of
                  Nothing -> return r
                  Just z  -> flip (app2 combine) r =<< z

            return (lift TupRunit ())

    return_
-- | Convert an operand of any integral type to 'Int32', via 'A.fromIntegral'
-- with the source and target types resolved from the class dictionaries.
--
i32 :: IsIntegral i => Operands i -> CodeGen PTX (Operands Int32)
i32 = A.fromIntegral integralType numType
-- | Convert an operand of any integral type to 'Int', via 'A.fromIntegral'
-- with the source and target types resolved from the class dictionaries.
--
int :: IsIntegral i => Operands i -> CodeGen PTX (Operands Int)
int = A.fromIntegral integralType numType