{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.LLVM.PTX.CodeGen.FoldSeg
where
import Data.Array.Accelerate.Analysis.Type
import Data.Array.Accelerate.Array.Sugar ( Array, Segments, Shape(rank), (:.), Elt(..) )
import LLVM.AST.Type.Representation
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Loop as Loop
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Base
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Fold ( reduceBlockSMem, reduceWarpSMem, imapFromTo )
import Data.Array.Accelerate.LLVM.PTX.Context
import Data.Array.Accelerate.LLVM.PTX.Target
import qualified Foreign.CUDA.Analysis as CUDA
import Control.Applicative ( (<$>), (<*>) )
import Control.Monad ( void )
import Data.String ( fromString )
import Prelude as P
-- | Generate the code for a seeded segmented reduction.
--
-- Two kernels are generated and combined with '(+++)': a variant in which a
-- whole thread block works on a segment ('mkFoldSegP_block', kernel
-- @foldSeg_block@) and a variant in which a single warp works on a segment
-- ('mkFoldSegP_warp', kernel @foldSeg_warp@). Both receive the seed, so empty
-- segments are written with the seed value.
mkFoldSeg
    :: forall aenv sh i e. (Shape sh, IsIntegral i, Elt i, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Array (sh :. Int) e)
    -> IRDelayed PTX aenv (Segments i)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh :. Int) e))
mkFoldSeg (deviceProperties . ptxContext -> dev) aenv combine seed arr seg = do
  perBlock <- mkFoldSegP_block dev aenv combine (Just seed) arr seg
  perWarp  <- mkFoldSegP_warp  dev aenv combine (Just seed) arr seg
  return (perBlock +++ perWarp)
-- | Generate the code for an unseeded segmented reduction.
--
-- Like 'mkFoldSeg', this combines the block-per-segment and warp-per-segment
-- kernels with '(+++)'. No seed is supplied, so the generated kernels write
-- nothing for empty segments.
mkFold1Seg
    :: forall aenv sh i e. (Shape sh, IsIntegral i, Elt i, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRDelayed PTX aenv (Array (sh :. Int) e)
    -> IRDelayed PTX aenv (Segments i)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh :. Int) e))
mkFold1Seg (deviceProperties . ptxContext -> dev) aenv combine arr seg =
  let withKernel mk = mk dev aenv combine Nothing arr seg
  in  (+++) <$> withKernel mkFoldSegP_block <*> withKernel mkFoldSegP_warp
-- | Segmented reduction kernel (@foldSeg_block@) in which all threads of a
-- thread block cooperate on one segment at a time.
--
-- For each segment index @s@ that 'imapFromTo' assigns to this block:
--
--   * threads 0 and 1 load the segment's start and end offsets from the
--     segment descriptor into a two-element static shared-memory array, and
--     after a barrier every thread reads them back as the input range
--     @[inf, sup)@. For input of rank greater than one, the descriptor is
--     indexed with @s `rem` ss@ and the range is shifted by
--     @(s `quot` ss) * sz@;
--
--   * an empty segment writes the seed (when one is given) from thread 0;
--
--   * a non-empty segment is reduced 'blockDim' elements at a time with
--     'reduceBlockSMem', thread 0 folding each chunk's result into its
--     accumulator and finally writing it out (with the seed, if any,
--     combined on the left).
mkFoldSegP_block
    :: forall aenv sh i e. (Shape sh, IsIntegral i, Elt i, Elt e)
    => DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> Maybe (IRExp PTX aenv e)
    -> IRDelayed PTX aenv (Array (sh :. Int) e)
    -> IRDelayed PTX aenv (Segments i)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh :. Int) e))
mkFoldSegP_block dev aenv combine mseed arr seg =
  let
      (start, end, paramGang) = gangParam
      (arrOut, paramOut)      = mutableArray ("out" :: Name (Array (sh :. Int) e))
      paramEnv                = envParam aenv
      config                  = launchConfig dev (CUDA.decWarp dev) dsmem const [|| const ||]

      -- Dynamic shared memory (bytes) for a block of n threads:
      -- (1 + ws + ws/2) elements of type e per warp. Presumably this is the
      -- scratch space 'reduceBlockSMem' needs -- confirm against Fold.
      dsmem n                 = warps * (1 + per_warp) * bytes
        where
          ws       = CUDA.warpSize dev
          warps    = n `P.quot` ws
          per_warp = ws + ws `P.quot` 2
          bytes    = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "foldSeg_block" (paramGang ++ paramOut ++ paramEnv) $ do

    -- Two-element shared array through which threads 0 and 1 broadcast the
    -- current segment's start and end offsets to the rest of the block.
    smem <- staticSharedMem 2

    -- Innermost extent of the input, and (length of the segment descriptor
    -- minus one), which is used as the number of segments per row.
    sz <- indexHead <$> delayedExtent arr
    ss <- do
      n <- indexHead <$> delayedExtent seg
      A.sub numType n (lift 1)

    imapFromTo start end $ \s -> do

      -- Threads 0 and 1 fetch seg[i] and seg[i+1] and publish them.
      tid <- threadIdx
      when (A.lt singleType tid (lift 2)) $ do
        i <- case rank (undefined::sh) of
               0 -> return s
               _ -> A.rem integralType s ss
        j <- A.add numType i =<< int tid
        v <- app1 (delayedLinearIndex seg) j
        writeArray smem tid =<< int v

      -- Make the offsets visible to every thread in the block.
      __syncthreads

      -- Input range [inf, sup) of this segment.
      u <- readArray smem (lift 0 :: IR Int32)
      v <- readArray smem (lift 1 :: IR Int32)
      (inf,sup) <- A.unpair <$>
        case rank (undefined::sh) of
          0 -> return (A.pair u v)
          _ -> do
            q <- A.quot integralType s ss
            a <- A.mul numType q sz
            A.pair <$> A.add numType u a <*> A.add numType v a

      -- Every thread must have read the offsets before threads 0/1 may
      -- overwrite 'smem' for the next segment. In particular the
      -- empty-segment branch below executes no barrier at all, so without
      -- this a fast thread 0/1 could clobber the offsets while slower
      -- threads have not yet read them (write-after-read race). inf/sup
      -- are uniform across the block, so this barrier is reached by all.
      __syncthreads

      void $
        if A.eq singleType inf sup
          -- Empty segment: only a seeded reduction writes a result.
          then do
            case mseed of
              Nothing -> return (IR OP_Unit :: IR ())
              Just z  -> do
                when (A.eq singleType tid (lift 0)) $ writeArray arrOut s =<< z
                return (IR OP_Unit)

          -- Non-empty segment.
          else do
            -- Step 1: each thread loads its first element (or an undefined
            -- placeholder if it lies past the end of the segment) and the
            -- block reduces this first chunk; when fewer than blockDim
            -- elements are available the valid count is passed along.
            i0 <- A.add numType inf =<< int tid
            x0 <- if A.lt singleType i0 sup
                    then app1 (delayedLinearIndex arr) i0
                    else
                      let
                          go :: TupleType a -> Operands a
                          go TypeRunit       = OP_Unit
                          go (TypeRpair a b) = OP_Pair (go a) (go b)
                          go (TypeRscalar t) = ir' t (undef t)
                      in
                      return . IR $ go (eltType (undefined::e))
            bd  <- int =<< blockDim
            v0  <- A.sub numType sup inf
            v0' <- i32 v0
            r0  <- if A.gte singleType v0 bd
                     then reduceBlockSMem dev combine Nothing x0
                     else reduceBlockSMem dev combine (Just v0') x0

            -- Step 2: walk the remainder of the segment in blockDim-sized
            -- chunks; thread 0 folds each chunk result into its running
            -- accumulator, other threads just carry the chunk result.
            nxt <- A.add numType inf bd
            r   <- iterFromStepTo nxt bd sup r0 $ \offset r -> do
              __syncthreads
              i' <- A.add numType offset =<< int tid
              v' <- A.sub numType sup offset
              r' <- if A.gte singleType v' bd
                      then do
                        x <- app1 (delayedLinearIndex arr) i'
                        y <- reduceBlockSMem dev combine Nothing x
                        return y
                      else do
                        x <- if A.lt singleType i' sup
                               then app1 (delayedLinearIndex arr) i'
                               else return r
                        z <- i32 v'
                        y <- reduceBlockSMem dev combine (Just z) x
                        return y
              if A.eq singleType tid (lift 0)
                then app2 combine r r'
                else return r'

            -- Step 3: thread 0 writes the result, with the seed (if any)
            -- combined on the left.
            when (A.eq singleType tid (lift 0)) $
              writeArray arrOut s =<<
                case mseed of
                  Nothing -> return r
                  Just z  -> flip (app2 combine) r =<< z
            return (IR OP_Unit)

    return_
-- | Segmented reduction kernel (@foldSeg_warp@) in which each warp reduces
-- one segment at a time.
--
-- Each warp computes its grid-wide index
-- @gwid = blockIdx * warpsPerBlock + warpIdInBlock@ and visits segment
-- indices starting at @start + gwid@, stepping by the total number of warps
-- in the grid (@warpsPerBlock * gridDim@) up to @end@.
--
-- Dynamic shared memory is laid out per block (@wpb@ warps) as:
--
--   * @wpb@ two-element 'Int' arrays ('lim'), one per warp, holding the
--     current segment's start and end offsets;
--
--   * followed by @wpb@ scratch regions ('smem') of @ws + ws/2@ elements of
--     type @e@, one per warp, used by 'reduceWarpSMem'.
--
-- Empty segments write the seed (when one is given) from lane 0; non-empty
-- segments are reduced 'warpSize' elements at a time, lane 0 folding each
-- chunk into its accumulator and writing the result (seed on the left).
mkFoldSegP_warp
    :: forall aenv sh i e. (Shape sh, IsIntegral i, Elt i, Elt e)
    => DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> Maybe (IRExp PTX aenv e)
    -> IRDelayed PTX aenv (Array (sh :. Int) e)
    -> IRDelayed PTX aenv (Segments i)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh :. Int) e))
mkFoldSegP_warp dev aenv combine mseed arr seg =
  let
      (start, end, paramGang) = gangParam
      (arrOut, paramOut)      = mutableArray ("out" :: Name (Array (sh :. Int) e))
      paramEnv                = envParam aenv
      config                  = launchConfig dev (CUDA.decWarp dev) dsmem grid gridQ

      -- Dynamic shared memory (bytes) for a block of n threads: per warp,
      -- the two-Int 'lim' array plus the reduction scratch region.
      dsmem n                 = warps * (2 * int_bytes + per_warp_bytes)
        where
          warps = n `P.quot` ws

      grid n m                = multipleOf n (m `P.quot` ws)
      gridQ                   = [|| \n m -> $$multipleOfQ n (m `P.quot` ws) ||]

      per_warp_bytes          = per_warp_elems * bytes
      per_warp_elems          = ws + (ws `P.quot` 2)
      ws                      = CUDA.warpSize dev
      bytes                   = sizeOf (eltType (undefined :: e))

      -- The 'lim' arrays hold Ints, so they must be sized (and offset) by
      -- the size of Int rather than of e. Sizing them by e made adjacent
      -- warps' arrays overlap, and the Int accesses potentially misaligned,
      -- whenever e is narrower than Int.
      int_bytes               = sizeOf (eltType (undefined :: Int))

      int32 :: Integral a => a -> IR Int32
      int32 = lift . P.fromIntegral
  in
  makeOpenAccWith config "foldSeg_warp" (paramGang ++ paramOut ++ paramEnv) $ do

    -- Warp index within the block, warps per block, and grid-wide warp index.
    tid  <- threadIdx
    wid  <- A.quot integralType tid (int32 ws)
    bd   <- blockDim
    wpb  <- A.quot integralType bd (int32 ws)
    bid  <- blockIdx
    gwid <- do
      a <- A.mul numType bid wpb
      b <- A.add numType wid a
      return b

    -- This warp's two-element Int array for the segment's start/end offsets.
    lim <- do
      a <- A.mul numType wid (int32 (2 * int_bytes))
      b <- dynamicSharedMem (lift 2) a
      return b

    -- This warp's reduction scratch space, placed after all 'lim' arrays.
    smem <- do
      a <- A.mul numType wpb (int32 (2 * int_bytes))
      b <- A.mul numType wid (int32 per_warp_bytes)
      c <- A.add numType a b
      d <- dynamicSharedMem (int32 per_warp_elems) c
      return d

    -- Innermost extent of the input, and (length of the segment descriptor
    -- minus one), which is used as the number of segments per row.
    sz <- indexHead <$> delayedExtent arr
    ss <- do
      a <- indexHead <$> delayedExtent seg
      b <- A.sub numType a (lift 1)
      return b

    -- Start at start + gwid and step by the total number of warps.
    s0   <- A.add numType start =<< int gwid
    gd   <- int =<< gridDim
    wpb' <- int wpb
    step <- A.mul numType wpb' gd

    imapFromStepTo s0 step end $ \s -> do

      -- Lanes 0 and 1 fetch seg[a] and seg[a+1] into this warp's 'lim'.
      lane <- laneId
      when (A.lt singleType lane (lift 2)) $ do
        a <- case rank (undefined::sh) of
               0 -> return s
               _ -> A.rem integralType s ss
        b <- A.add numType a =<< int lane
        c <- app1 (delayedLinearIndex seg) b
        writeArray lim lane =<< int c

      -- Input range [inf, sup) of this segment.
      -- NOTE(review): lanes >= 2 read 'lim' immediately after lanes 0/1
      -- wrote it, with no intra-warp synchronisation in between. This looks
      -- like it relies on warp-synchronous execution -- confirm it is still
      -- valid with independent thread scheduling (sm_70+).
      (inf,sup) <- do
        u <- readArray lim (lift 0 :: IR Int32)
        v <- readArray lim (lift 1 :: IR Int32)
        A.unpair <$>
          case rank (undefined::sh) of
            0 -> return (A.pair u v)
            _ -> do
              q <- A.quot integralType s ss
              a <- A.mul numType q sz
              A.pair <$> A.add numType u a <*> A.add numType v a

      -- NOTE(review): this barrier (and the one in the chunk loop below) is
      -- executed inside per-warp control flow -- the segment loop starts at
      -- a gwid-dependent index and the chunk loop bound depends on the
      -- segment length -- so warps of one block may reach different barrier
      -- instances. Confirm this is intended / safe.
      __syncthreads

      void $
        if A.eq singleType inf sup
          -- Empty segment: only a seeded reduction writes a result.
          then do
            case mseed of
              Nothing -> return (IR OP_Unit :: IR ())
              Just z  -> do
                when (A.eq singleType lane (lift 0)) $ writeArray arrOut s =<< z
                return (IR OP_Unit)

          -- Non-empty segment.
          else do
            -- Step 1: each lane loads its first element (or an undefined
            -- placeholder past the end of the segment) and the warp reduces
            -- this first chunk, passing the valid count when it is short.
            i0 <- A.add numType inf =<< int lane
            x0 <- if A.lt singleType i0 sup
                    then app1 (delayedLinearIndex arr) i0
                    else
                      let
                          go :: TupleType a -> Operands a
                          go TypeRunit       = OP_Unit
                          go (TypeRpair a b) = OP_Pair (go a) (go b)
                          go (TypeRscalar t) = ir' t (undef t)
                      in
                      return . IR $ go (eltType (undefined::e))
            v0  <- A.sub numType sup inf
            v0' <- i32 v0
            r0  <- if A.gte singleType v0 (lift ws)
                     then reduceWarpSMem dev combine smem Nothing x0
                     else reduceWarpSMem dev combine smem (Just v0') x0

            -- Step 2: remaining warpSize-sized chunks; lane 0 folds each
            -- chunk result into its accumulator.
            nx  <- A.add numType inf (lift ws)
            r   <- iterFromStepTo nx (lift ws) sup r0 $ \offset r -> do
              __syncthreads
              i' <- A.add numType offset =<< int lane
              v' <- A.sub numType sup offset
              r' <- if A.gte singleType v' (lift ws)
                      then do
                        x <- app1 (delayedLinearIndex arr) i'
                        y <- reduceWarpSMem dev combine smem Nothing x
                        return y
                      else do
                        x <- if A.lt singleType i' sup
                               then app1 (delayedLinearIndex arr) i'
                               else return r
                        z <- i32 v'
                        y <- reduceWarpSMem dev combine smem (Just z) x
                        return y
              if A.eq singleType lane (lift 0)
                then app2 combine r r'
                else return r'

            -- Step 3: lane 0 writes the result, seed (if any) on the left.
            when (A.eq singleType lane (lift 0)) $
              writeArray arrOut s =<<
                case mseed of
                  Nothing -> return r
                  Just z  -> flip (app2 combine) r =<< z
            return (IR OP_Unit)

    return_
-- | Convert an integral IR value to an 'Int32' IR value.
i32 :: IsIntegral i => IR i -> CodeGen (IR Int32)
i32 x = A.fromIntegral integralType numType x
-- | Convert an integral IR value to an 'Int' IR value.
int :: IsIntegral i => IR i -> CodeGen (IR Int)
int = \x -> A.fromIntegral integralType numType x