{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.LLVM.PTX.CodeGen.Scan (
mkScanl, mkScanl1, mkScanl',
mkScanr, mkScanr1, mkScanr',
) where
import Data.Array.Accelerate.Analysis.Type
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.LLVM.Analysis.Match
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Loop
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Base
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Generate
import Data.Array.Accelerate.LLVM.PTX.Context
import Data.Array.Accelerate.LLVM.PTX.Target
import LLVM.AST.Type.Representation
import qualified Foreign.CUDA.Analysis as CUDA
import Control.Applicative
import Control.Monad ( (>=>), void )
import Data.String ( fromString )
import Data.Coerce as Safe
import Data.Bits as P
import Prelude as P hiding ( last )
-- | Direction of a scan. The @scanl@ family of generators in this module
-- passes 'L' to the kernels and the @scanr@ family passes 'R'; kernels branch
-- on the value to choose their index arithmetic and the argument order given
-- to the combining function.
data Direction = L | R
-- | Generate the kernels for a seeded scan using direction 'L'.
--
-- When the outer shape is 'Z' (i.e. the input is a vector) the three-phase
-- kernels @scanP1@, @scanP2@ and @scanP3@ are generated; otherwise the single
-- per-dimension kernel is used. In both cases a fill kernel that writes only
-- the seed value is appended.
mkScanl
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanl ptx@(deviceProperties . ptxContext -> dev) aenv combine seed arr
  -- Vector input: multi-kernel scan
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = do
      phase1 <- mkScanAllP1 L dev aenv combine (Just seed) arr
      phase2 <- mkScanAllP2 L dev aenv combine
      phase3 <- mkScanAllP3 L dev aenv combine (Just seed)
      fill   <- mkScanFill ptx aenv seed
      return (phase1 +++ (phase2 +++ (phase3 +++ fill)))
  -- Higher-dimensional input: scan along the innermost dimension
  | otherwise
  = do
      scan <- mkScanDim L dev aenv combine (Just seed) arr
      fill <- mkScanFill ptx aenv seed
      return (scan +++ fill)
-- | Generate the kernels for an unseeded scan using direction 'L'.
--
-- Vector inputs use the three-phase kernels; all other shapes use the
-- per-dimension kernel. No fill kernel is generated because there is no seed.
mkScanl1
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanl1 (deviceProperties . ptxContext -> dev) aenv combine arr
  -- Vector input: multi-kernel scan
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = do
      phase1 <- mkScanAllP1 L dev aenv combine Nothing arr
      phase2 <- mkScanAllP2 L dev aenv combine
      phase3 <- mkScanAllP3 L dev aenv combine Nothing
      return (phase1 +++ (phase2 +++ phase3))
  -- Higher-dimensional input: scan along the innermost dimension
  | otherwise
  = mkScanDim L dev aenv combine Nothing arr
-- | Generate the kernels for a seeded scan using direction 'L' that returns
-- the scanned array together with a separate array of final values.
--
-- Vector inputs use the three-phase @scan'@ kernels; all other shapes use the
-- per-dimension kernel. A fill kernel writing only the seed is appended in
-- both cases.
mkScanl'
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e, Array sh e))
mkScanl' ptx@(deviceProperties . ptxContext -> dev) aenv combine seed arr
  -- Vector input: multi-kernel scan
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = do
      phase1 <- mkScan'AllP1 L dev aenv combine seed arr
      phase2 <- mkScan'AllP2 L dev aenv combine
      phase3 <- mkScan'AllP3 L dev aenv combine
      fill   <- mkScan'Fill ptx aenv seed
      return (phase1 +++ (phase2 +++ (phase3 +++ fill)))
  -- Higher-dimensional input: scan along the innermost dimension
  | otherwise
  = do
      scan <- mkScan'Dim L dev aenv combine seed arr
      fill <- mkScan'Fill ptx aenv seed
      return (scan +++ fill)
-- | Generate the kernels for a seeded scan using direction 'R'.
--
-- When the outer shape is 'Z' (i.e. the input is a vector) the three-phase
-- kernels @scanP1@, @scanP2@ and @scanP3@ are generated; otherwise the single
-- per-dimension kernel is used. In both cases a fill kernel that writes only
-- the seed value is appended.
mkScanr
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanr ptx@(deviceProperties . ptxContext -> dev) aenv combine seed arr
  -- Vector input: multi-kernel scan
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = do
      phase1 <- mkScanAllP1 R dev aenv combine (Just seed) arr
      phase2 <- mkScanAllP2 R dev aenv combine
      phase3 <- mkScanAllP3 R dev aenv combine (Just seed)
      fill   <- mkScanFill ptx aenv seed
      return (phase1 +++ (phase2 +++ (phase3 +++ fill)))
  -- Higher-dimensional input: scan along the innermost dimension
  | otherwise
  = do
      scan <- mkScanDim R dev aenv combine (Just seed) arr
      fill <- mkScanFill ptx aenv seed
      return (scan +++ fill)
-- | Generate the kernels for an unseeded scan using direction 'R'.
--
-- Vector inputs use the three-phase kernels; all other shapes use the
-- per-dimension kernel. No fill kernel is generated because there is no seed.
mkScanr1
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanr1 (deviceProperties . ptxContext -> dev) aenv combine arr
  -- Vector input: multi-kernel scan
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = do
      phase1 <- mkScanAllP1 R dev aenv combine Nothing arr
      phase2 <- mkScanAllP2 R dev aenv combine
      phase3 <- mkScanAllP3 R dev aenv combine Nothing
      return (phase1 +++ (phase2 +++ phase3))
  -- Higher-dimensional input: scan along the innermost dimension
  | otherwise
  = mkScanDim R dev aenv combine Nothing arr
-- | Generate the kernels for a seeded scan using direction 'R' that returns
-- the scanned array together with a separate array of final values.
--
-- Vector inputs use the three-phase @scan'@ kernels; all other shapes use the
-- per-dimension kernel. A fill kernel writing only the seed is appended in
-- both cases.
mkScanr'
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e, Array sh e))
mkScanr' ptx@(deviceProperties . ptxContext -> dev) aenv combine seed arr
  -- Vector input: multi-kernel scan
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = do
      phase1 <- mkScan'AllP1 R dev aenv combine seed arr
      phase2 <- mkScan'AllP2 R dev aenv combine
      phase3 <- mkScan'AllP3 R dev aenv combine
      fill   <- mkScan'Fill ptx aenv seed
      return (phase1 +++ (phase2 +++ (phase3 +++ fill)))
  -- Higher-dimensional input: scan along the innermost dimension
  | otherwise
  = do
      scan <- mkScan'Dim R dev aenv combine seed arr
      fill <- mkScan'Fill ptx aenv seed
      return (scan +++ fill)
-- | First phase (kernel @scanP1@) of the multi-kernel scan over a vector.
--
-- Each thread block walks over chunks of the input (chunk indices run from
-- @start + blockIdx@ up to @end@ in steps of @gridDim@; a chunk starts at
-- @chunk * blockDim@). The chunk is scanned with 'scanBlockSMem' and every
-- participating thread writes its result to @out@. When the grid has more than
-- one block, the thread with index @blockDim - 1@ also stores its value into
-- @tmp@.
--
-- If a seed is given, thread 0 of chunk 0 evaluates it, writes it to @out@
-- (index 0 for 'L', index @sz@ for 'R') and combines it with its own input
-- element before the block scan.
mkScanAllP1
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> Maybe (IRExp PTX aenv e)
    -> IRDelayed PTX aenv (Vector e)
    -> CodeGen (IROpenAcc PTX aenv (Vector e))
mkScanAllP1 dir dev aenv combine mseed IRDelayed{..} =
  let
      (start, end, paramGang) = gangParam
      (arrOut, paramOut) = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp) = mutableArray ("tmp" :: Name (Vector e))
      paramEnv = envParam aenv
      config = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
      -- Shared memory requirement for a block of @n@ threads: per warp, one
      -- element plus 1.5 * warpSize elements (presumably in bytes, given the
      -- 'sizeOf' factor -- confirm against 'launchConfig').
      smem n = warps * (1 + per_warp) * bytes
        where
          ws = CUDA.warpSize dev
          warps = n `P.quot` ws
          per_warp = ws + ws `P.quot` 2
          bytes = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scanP1" (paramGang ++ paramTmp ++ paramOut ++ paramEnv) $ do
    -- Innermost extent of the input array
    sz <- indexHead <$> delayedExtent
    bid <- blockIdx
    gd <- gridDim
    gd' <- int gd
    s0 <- A.add numType start =<< int bid
    -- Iterate over the chunks assigned to this thread block
    imapFromStepTo s0 gd' end $ \chunk -> do
      bd <- blockDim
      bd' <- int bd
      -- Offset of this chunk, counted from the end the scan starts at
      inf <- A.mul numType chunk bd'
      tid <- threadIdx
      tid' <- int tid
      -- i0: index this thread reads from (for 'R' counted back from sz-1)
      i0 <- case dir of
              L -> A.add numType inf tid'
              R -> do x <- A.sub numType sz inf
                      y <- A.sub numType x tid'
                      z <- A.sub numType y (lift 1)
                      return z
      -- j0: index this thread writes to. With a seed and direction 'L' the
      -- output is shifted along by one relative to the input.
      j0 <- case mseed of
              Nothing -> return i0
              Just _ -> case dir of
                          L -> A.add numType i0 (lift 1)
                          R -> return i0
      -- Whether index i lies inside the input
      let valid i = case dir of
                      L -> A.lt singleType i sz
                      R -> A.gte singleType i (lift 0)
      when (valid i0) $ do
        x0 <- app1 delayedLinearIndex i0
        -- Thread 0 of chunk 0 stores the seed and folds it into its input
        x1 <- case mseed of
                Nothing -> return x0
                Just seed ->
                  if A.eq singleType tid (lift 0) `A.land` A.eq singleType chunk (lift 0)
                    then do
                      z <- seed
                      case dir of
                        L -> writeArray arrOut (lift 0 :: IR Int32) z >> app2 combine z x0
                        R -> writeArray arrOut sz z >> app2 combine x0 z
                    else
                      return x0
        -- Elements remaining from the start of this chunk; a chunk with fewer
        -- than blockDim elements passes that count to the block scan.
        n <- A.sub numType sz inf
        n' <- i32 n
        x2 <- if A.gte singleType n bd'
                then scanBlockSMem dir dev combine Nothing x1
                else scanBlockSMem dir dev combine (Just n') x1
        writeArray arrOut j0 x2
        -- In a multi-block grid, thread (blockDim - 1) records its value in
        -- the tmp array (slot chunk for 'L', slot end-chunk-1 for 'R').
        last <- A.sub numType bd (lift 1)
        when (A.gt singleType gd (lift 1) `land` A.eq singleType tid last) $
          case dir of
            L -> writeArray arrTmp chunk x2
            R -> do u <- A.sub numType end chunk
                    v <- A.sub numType u (lift 1)
                    writeArray arrTmp v x2
    return_
-- | Second phase (kernel @scanP2@) of the multi-kernel scan over a vector.
--
-- The launch configuration fixes the grid size at 1. The block steps over
-- the @tmp@ array from @start@ to @end@ in strides of @blockDim@, scanning it
-- in place with 'scanBlockSMem'. A one-element buffer obtained from
-- 'staticSharedMem' carries the value of thread @blockDim - 1@ from one
-- stride into thread 0 of the next.
mkScanAllP2
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> CodeGen (IROpenAcc PTX aenv (Vector e))
mkScanAllP2 dir dev aenv combine =
  let
      (start, end, paramGang) = gangParam
      (arrTmp, paramTmp) = mutableArray ("tmp" :: Name (Vector e))
      paramEnv = envParam aenv
      config = launchConfig dev (CUDA.incWarp dev) smem grid gridQ
      -- Grid size is always 1, in both the plain and the quoted form
      grid _ _ = 1
      gridQ = [|| \_ _ -> 1 ||]
      -- Shared memory requirement for a block of @n@ threads (presumably in
      -- bytes, given the 'sizeOf' factor -- confirm against 'launchConfig')
      smem n = warps * (1 + per_warp) * bytes
        where
          ws = CUDA.warpSize dev
          warps = n `P.quot` ws
          per_warp = ws + ws `P.quot` 2
          bytes = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scanP2" (paramGang ++ paramTmp ++ paramEnv) $ do
    -- Slot used to pass the carry value between loop iterations
    carry <- staticSharedMem 1
    bd <- blockDim
    bd' <- int bd
    imapFromStepTo start bd' end $ \offset -> do
      tid <- threadIdx
      tid' <- int tid
      -- Index of the tmp array handled by this thread
      i0 <- case dir of
              L -> A.add numType offset tid'
              R -> do x <- A.sub numType end offset
                      y <- A.sub numType x tid'
                      z <- A.sub numType y (lift 1)
                      return z
      let valid i = case dir of
                      L -> A.lt singleType i end
                      R -> A.gte singleType i start
      when (valid i0) $ do
        -- Barrier before the carry slot written in the previous iteration is
        -- read
        __syncthreads
        x0 <- readArray arrTmp i0
        -- After the first stride, thread 0 folds in the carry value
        x1 <- if A.gt singleType offset (lift 0) `land` A.eq singleType tid (lift 0)
                then do
                  c <- readArray carry (lift 0 :: IR Int32)
                  case dir of
                    L -> app2 combine c x0
                    R -> app2 combine x0 c
                else do
                  return x0
        -- Elements remaining from this offset; fewer than blockDim means a
        -- partial block scan
        n <- A.sub numType end offset
        n' <- i32 n
        x2 <- if A.gte singleType n bd'
                then scanBlockSMem dir dev combine Nothing x1
                else scanBlockSMem dir dev combine (Just n') x1
        writeArray arrTmp i0 x2
        -- Thread (blockDim - 1) publishes its value as the next carry
        last <- A.sub numType bd (lift 1)
        when (A.eq singleType tid last) $
          writeArray carry (lift 0 :: IR Int32) x2
    return_
-- | Third phase (kernel @scanP3@) of the multi-kernel scan over a vector.
--
-- For every chunk index (from @start + blockIdx@ to @end@ in steps of
-- @gridDim@) a value is read from @tmp@ and combined into each element of
-- @out@ in the index range @[inf, sup)@ computed for that chunk. The chunk
-- width is the kernel parameter @ix.stride@; threads whose index is not below
-- the stride do nothing.
mkScanAllP3
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> Maybe (IRExp PTX aenv e)
    -> CodeGen (IROpenAcc PTX aenv (Vector e))
mkScanAllP3 dir dev aenv combine mseed =
  let
      (start, end, paramGang) = gangParam
      (arrOut, paramOut) = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp) = mutableArray ("tmp" :: Name (Vector e))
      paramEnv = envParam aenv
      -- The stride is both a kernel parameter and the local operand naming it
      stride = local scalarType ("ix.stride" :: Name Int)
      paramStride = scalarParameter scalarType ("ix.stride" :: Name Int)
      -- This kernel requests a shared memory size of zero
      config = launchConfig dev (CUDA.incWarp dev) (const 0) const [|| const ||]
  in
  makeOpenAccWith config "scanP3" (paramGang ++ paramTmp ++ paramOut ++ paramStride : paramEnv) $ do
    -- Innermost extent of the output array
    sz <- return $ indexHead (irArrayShape arrOut)
    tid <- int =<< threadIdx
    when (A.lt singleType tid stride) $ do
      bid <- int =<< blockIdx
      gd <- int =<< gridDim
      c0 <- A.add numType start bid
      imapFromStepTo c0 gd end $ \chunk -> do
        -- Range [inf, sup) of the output array updated for this chunk. With a
        -- seed the range is shifted by one element, and it is clamped to the
        -- array bounds (sz for 'L', 0 for 'R').
        (inf,sup) <- case dir of
                       L -> do
                         a <- A.add numType chunk (lift 1)
                         b <- A.mul numType stride a
                         case mseed of
                           Just{} -> do
                             c <- A.add numType b (lift 1)
                             d <- A.add numType c stride
                             e <- A.min singleType d sz
                             return (c,e)
                           Nothing -> do
                             c <- A.add numType b stride
                             d <- A.min singleType c sz
                             return (b,d)
                       R -> do
                         a <- A.sub numType end chunk
                         b <- A.mul numType stride a
                         c <- A.sub numType sz b
                         case mseed of
                           Just{} -> do
                             d <- A.sub numType c (lift 1)
                             e <- A.sub numType d stride
                             f <- A.max singleType e (lift 0)
                             return (f,d)
                           Nothing -> do
                             d <- A.sub numType c stride
                             e <- A.max singleType d (lift 0)
                             return (e,c)
        -- Value to combine into the chunk: tmp[chunk] for 'L',
        -- tmp[chunk + 1] for 'R'
        carry <- case dir of
                   L -> readArray arrTmp chunk
                   R -> do
                     a <- A.add numType chunk (lift 1)
                     b <- readArray arrTmp a
                     return b
        -- Threads stride over the range by blockDim, updating out in place
        bd <- int =<< blockDim
        i0 <- A.add numType inf tid
        imapFromStepTo i0 bd sup $ \i -> do
          v <- readArray arrOut i
          u <- case dir of
                 L -> app2 combine carry v
                 R -> app2 combine v carry
          writeArray arrOut i u
    return_
-- | First phase (kernel @scanP1@) of the multi-kernel @scan'@ over a vector.
--
-- Each thread block walks over segments of the input (segment indices run
-- from @start + blockIdx@ up to @end@ in steps of @gridDim@; a segment starts
-- at @seg * blockDim@) and scans each one with 'scanBlockSMem'. Results are
-- written to @out@ shifted by one position relative to the input, skipping
-- the write when the shifted index falls outside @[0, sz)@. Thread 0 of
-- segment 0 evaluates the seed, stores it at its own read index and folds it
-- into its input element. The last active thread of each segment stores its
-- value in @tmp@.
--
-- The thread index is read once and reused for every comparison.
mkScan'AllP1
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Vector e)
    -> CodeGen (IROpenAcc PTX aenv (Vector e, Scalar e))
mkScan'AllP1 dir dev aenv combine seed IRDelayed{..} =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      --
      config                    = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
      -- Shared memory requirement for a block of @n@ threads (presumably in
      -- bytes, given the 'sizeOf' factor -- confirm against 'launchConfig')
      smem n                    = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scanP1" (paramGang ++ paramTmp ++ paramOut ++ paramEnv) $ do

    -- Innermost extent of the input array
    sz  <- indexHead <$> delayedExtent

    bid <- int =<< blockIdx
    gd  <- int =<< gridDim
    s0  <- A.add numType start bid

    -- Iterate over the segments assigned to this thread block
    imapFromStepTo s0 gd end $ \seg -> do

      bd  <- int =<< blockDim
      inf <- A.mul numType seg bd

      -- i0: index this thread reads from (for 'R' counted back from sz-1)
      tid <- int =<< threadIdx
      i0  <- case dir of
               L -> A.add numType inf tid
               R -> do x <- A.sub numType sz inf
                       y <- A.sub numType x tid
                       z <- A.sub numType y (lift 1)
                       return z

      -- j0: index this thread writes to, one step further along the scan
      j0  <- case dir of
               L -> A.add numType i0 (lift 1)
               R -> A.sub numType i0 (lift 1)

      -- Whether index i lies inside the input
      let valid i = case dir of
                      L -> A.lt  singleType i sz
                      R -> A.gte singleType i (lift 0)

      when (valid i0) $ do
        x0 <- app1 delayedLinearIndex i0

        -- Thread 0 of segment 0 stores the seed and folds it into its input
        x1 <- if A.eq singleType tid (lift 0) `A.land` A.eq singleType seg (lift 0)
                then do
                  z <- seed
                  writeArray arrOut i0 z
                  case dir of
                    L -> app2 combine z x0
                    R -> app2 combine x0 z
                else
                  return x0

        -- Elements remaining from the start of this segment; a segment with
        -- fewer than blockDim elements passes that count to the block scan.
        n  <- A.sub numType sz inf
        n' <- i32 n
        x2 <- if A.gte singleType n bd
                then scanBlockSMem dir dev combine Nothing   x1
                else scanBlockSMem dir dev combine (Just n') x1

        -- Store the result unless the shifted index is out of bounds
        case dir of
          L -> when (A.lt  singleType j0 sz)       $ writeArray arrOut j0 x2
          R -> when (A.gte singleType j0 (lift 0)) $ writeArray arrOut j0 x2

        -- The last active thread (index min n bd - 1) records its value in
        -- the tmp array (slot seg for 'L', slot end-seg-1 for 'R').
        m  <- do x <- A.min singleType n bd
                 y <- A.sub numType x (lift 1)
                 return y
        when (A.eq singleType tid m) $
          case dir of
            L -> writeArray arrTmp seg x2
            R -> do x <- A.sub numType end seg
                    y <- A.sub numType x (lift 1)
                    writeArray arrTmp y x2

    return_
-- | Second phase (kernel @scanP2@) of the multi-kernel @scan'@ over a vector.
--
-- The launch configuration fixes the grid size at 1. The block steps over
-- the @tmp@ array from @start@ to @end@ in strides of @blockDim@, scanning it
-- in place with 'scanBlockSMem' and passing the value of the last active
-- thread to the next stride through a one-element buffer obtained from
-- 'staticSharedMem'. After the loop, thread 0 copies that buffer into index
-- 0 of the @sum@ array.
mkScan'AllP2
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> CodeGen (IROpenAcc PTX aenv (Vector e, Scalar e))
mkScan'AllP2 dir dev aenv combine =
  let
      (start, end, paramGang) = gangParam
      (arrTmp, paramTmp) = mutableArray ("tmp" :: Name (Vector e))
      (arrSum, paramSum) = mutableArray ("sum" :: Name (Scalar e))
      paramEnv = envParam aenv
      config = launchConfig dev (CUDA.incWarp dev) smem grid gridQ
      -- Grid size is always 1, in both the plain and the quoted form
      grid _ _ = 1
      gridQ = [|| \_ _ -> 1 ||]
      -- Shared memory requirement for a block of @n@ threads (presumably in
      -- bytes, given the 'sizeOf' factor -- confirm against 'launchConfig')
      smem n = warps * (1 + per_warp) * bytes
        where
          ws = CUDA.warpSize dev
          warps = n `P.quot` ws
          per_warp = ws + ws `P.quot` 2
          bytes = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scanP2" (paramGang ++ paramTmp ++ paramSum ++ paramEnv) $ do
    -- Slot used to pass the carry value between loop iterations
    carry <- staticSharedMem 1
    tid <- threadIdx
    tid' <- int tid
    bd <- int =<< blockDim
    imapFromStepTo start bd end $ \offset -> do
      -- Index of the tmp array handled by this thread
      i0 <- case dir of
              L -> A.add numType offset tid'
              R -> do x <- A.sub numType end offset
                      y <- A.sub numType x tid'
                      z <- A.sub numType y (lift 1)
                      return z
      let valid i = case dir of
                      L -> A.lt singleType i end
                      R -> A.gte singleType i start
      when (valid i0) $ do
        -- Barrier before the carry slot written in the previous iteration is
        -- read
        __syncthreads
        x0 <- readArray arrTmp i0
        -- After the first stride, thread 0 folds in the carry value
        x1 <- if A.gt singleType offset (lift 0) `A.land` A.eq singleType tid (lift 0)
                then do
                  c <- readArray carry (lift 0 :: IR Int32)
                  case dir of
                    L -> app2 combine c x0
                    R -> app2 combine x0 c
                else
                  return x0
        -- Elements remaining from this offset; fewer than blockDim means a
        -- partial block scan
        n <- A.sub numType end offset
        n' <- i32 n
        x2 <- if A.gte singleType n bd
                then scanBlockSMem dir dev combine Nothing x1
                else scanBlockSMem dir dev combine (Just n') x1
        writeArray arrTmp i0 x2
        -- The last active thread (index min bd n - 1) publishes the carry
        m <- do x <- A.min singleType bd n
                y <- A.sub numType x (lift 1)
                z <- i32 y
                return z
        when (A.eq singleType tid m) $
          writeArray carry (lift 0 :: IR Int32) x2
    -- Barrier, then thread 0 copies the final carry value into the sum array
    __syncthreads
    when (A.eq singleType tid (lift 0)) $
      writeArray arrSum (lift 0 :: IR Int32) =<< readArray carry (lift 0 :: IR Int32)
    return_
-- | Third phase (kernel @scanP3@) of the multi-kernel @scan'@ over a vector.
--
-- For every chunk index (from @start + blockIdx@ to @end@ in steps of
-- @gridDim@) a value is read from @tmp@ and combined into each element of
-- @out@ in the index range @[inf, sup)@ computed for that chunk. The chunk
-- width is the kernel parameter @ix.stride@; threads whose index is not below
-- the stride do nothing.
mkScan'AllP3
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> CodeGen (IROpenAcc PTX aenv (Vector e, Scalar e))
mkScan'AllP3 dir dev aenv combine =
  let
      (start, end, paramGang) = gangParam
      (arrOut, paramOut) = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp) = mutableArray ("tmp" :: Name (Vector e))
      paramEnv = envParam aenv
      -- The stride is both a kernel parameter and the local operand naming it
      stride = local scalarType ("ix.stride" :: Name Int)
      paramStride = scalarParameter scalarType ("ix.stride" :: Name Int)
      -- This kernel requests a shared memory size of zero
      config = launchConfig dev (CUDA.incWarp dev) (const 0) const [|| const ||]
  in
  makeOpenAccWith config "scanP3" (paramGang ++ paramTmp ++ paramOut ++ paramStride : paramEnv) $ do
    -- Innermost extent of the output array
    sz <- return $ indexHead (irArrayShape arrOut)
    tid <- int =<< threadIdx
    when (A.lt singleType tid stride) $ do
      bid <- int =<< blockIdx
      gd <- int =<< gridDim
      c0 <- A.add numType start bid
      imapFromStepTo c0 gd end $ \chunk -> do
        -- Range [inf, sup) of the output array updated for this chunk,
        -- shifted by one element and clamped to the array bounds (sz for
        -- 'L', 0 for 'R')
        (inf,sup) <- case dir of
                       L -> do
                         a <- A.add numType chunk (lift 1)
                         b <- A.mul numType stride a
                         c <- A.add numType b (lift 1)
                         d <- A.add numType c stride
                         e <- A.min singleType d sz
                         return (c,e)
                       R -> do
                         a <- A.sub numType end chunk
                         b <- A.mul numType stride a
                         c <- A.sub numType sz b
                         d <- A.sub numType c (lift 1)
                         e <- A.sub numType d stride
                         f <- A.max singleType e (lift 0)
                         return (f,d)
        -- Value to combine into the chunk: tmp[chunk] for 'L',
        -- tmp[chunk + 1] for 'R'
        carry <- case dir of
                   L -> readArray arrTmp chunk
                   R -> do
                     a <- A.add numType chunk (lift 1)
                     b <- readArray arrTmp a
                     return b
        -- Threads stride over the range by blockDim, updating out in place
        bd <- int =<< blockDim
        i0 <- A.add numType inf tid
        imapFromStepTo i0 bd sup $ \i -> do
          v <- readArray arrOut i
          u <- case dir of
                 L -> app2 combine carry v
                 R -> app2 combine v carry
          writeArray arrOut i u
    return_
-- | Single-kernel scan (kernel @scan@) along the innermost dimension of a
-- multidimensional array.
--
-- Each thread block processes whole segments (segment indices run from
-- @start + blockIdx@ up to @end@ in steps of @gridDim@). Within a segment the
-- block repeatedly scans @blockDim@ elements with 'scanBlockSMem', passing
-- the value of thread @blockDim - 1@ to thread 0 of the next round through a
-- one-element buffer obtained from 'staticSharedMem'.
--
-- With a seed, thread 0 writes the seed to the output and to the carry slot,
-- and all write indices are shifted by one. Without a seed, the first round
-- of the scan is performed up front to initialise the carry slot.
mkScanDim
    :: forall aenv sh e. (Shape sh, Elt e)
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> Maybe (IRExp PTX aenv e)
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanDim dir dev aenv combine mseed IRDelayed{..} =
  let
      (start, end, paramGang) = gangParam
      (arrOut, paramOut) = mutableArray ("out" :: Name (Array (sh:.Int) e))
      paramEnv = envParam aenv
      config = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
      -- Shared memory requirement for a block of @n@ threads (presumably in
      -- bytes, given the 'sizeOf' factor -- confirm against 'launchConfig')
      smem n = warps * (1 + per_warp) * bytes
        where
          ws = CUDA.warpSize dev
          warps = n `P.quot` ws
          per_warp = ws + ws `P.quot` 2
          bytes = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scan" (paramGang ++ paramOut ++ paramEnv) $ do
    -- Slot used to pass the carry value between loop iterations
    carry <- staticSharedMem 1
    -- Innermost extent of the input array
    sz <- indexHead <$> delayedExtent
    bid <- int =<< blockIdx
    gd <- int =<< gridDim
    s0 <- A.add numType start bid
    imapFromStepTo s0 gd end $ \seg -> do
      tid <- threadIdx
      tid' <- int tid
      -- i0: linear index this thread reads from. Segment seg occupies
      -- [seg*sz, (seg+1)*sz); 'R' counts back from the end of the segment.
      i0 <- case dir of
              L -> do x <- A.mul numType seg sz
                      y <- A.add numType x tid'
                      return y
              R -> do x <- A.add numType seg (lift 1)
                      y <- A.mul numType x sz
                      z <- A.sub numType y tid'
                      w <- A.sub numType z (lift 1)
                      return w
      -- j0: linear index this thread writes to. With a seed, the same
      -- arithmetic is applied using the innermost extent of the output array
      -- instead of the input's.
      j0 <- case mseed of
              Nothing -> return i0
              Just{} -> do szp1 <- return $ indexHead (irArrayShape arrOut)
                           case dir of
                             L -> do x <- A.mul numType seg szp1
                                     y <- A.add numType x tid'
                                     return y
                             R -> do x <- A.add numType seg (lift 1)
                                     y <- A.mul numType x szp1
                                     z <- A.sub numType y tid'
                                     w <- A.sub numType z (lift 1)
                                     return w
      bd <- blockDim
      bd' <- int bd
      -- Advance an index by one block width in the scan direction
      let next ix = case dir of
                      L -> A.add numType ix bd'
                      R -> A.sub numType ix bd'
      -- Initial loop state: (elements remaining, read index, write index)
      r <-
        case mseed of
          -- Seeded: thread 0 stores the seed in the output and carry slot;
          -- every thread shifts its write index by one. All sz elements
          -- remain to be scanned.
          Just seed -> do
            when (A.eq singleType tid (lift 0)) $ do
              z <- seed
              writeArray arrOut j0 z
              writeArray carry (lift 0 :: IR Int32) z
            j1 <- case dir of
                    L -> A.add numType j0 (lift 1)
                    R -> A.sub numType j0 (lift 1)
            return $ A.trip sz i0 j1
          -- Unseeded: scan the first block-width of elements here so that
          -- the carry slot is initialised before the loop.
          Nothing -> do
            when (A.lt singleType tid' sz) $ do
              n' <- i32 sz
              x0 <- app1 delayedLinearIndex i0
              r0 <- if A.gte singleType sz bd'
                      then scanBlockSMem dir dev combine Nothing x0
                      else scanBlockSMem dir dev combine (Just n') x0
              writeArray arrOut j0 r0
              ll <- A.sub numType bd (lift 1)
              when (A.eq singleType tid ll) $
                writeArray carry (lift 0 :: IR Int32) r0
            n1 <- A.sub numType sz bd'
            i1 <- next i0
            j1 <- next j0
            return $ A.trip n1 i1 j1
      -- Loop while elements remain in this segment
      void $ while
        (\(A.fst3 -> n) -> A.gt singleType n (lift 0))
        (\(A.untrip -> (n,i,j)) -> do
          -- Barrier before the carry slot from the previous round is read
          __syncthreads
          -- Threads without an input element get an 'undef' value of the
          -- element type instead of reading, so that every thread still
          -- takes part in the block scan below.
          x <- if A.lt singleType tid' n
                 then app1 delayedLinearIndex i
                 else let
                          go :: TupleType a -> Operands a
                          go TypeRunit = OP_Unit
                          go (TypeRpair a b) = OP_Pair (go a) (go b)
                          go (TypeRscalar t) = ir' t (undef t)
                      in
                      return . IR $ go (eltType (undefined::e))
          -- Thread 0 folds in the carry value
          y <- if A.eq singleType tid (lift 0)
                 then do
                   c <- readArray carry (lift 0 :: IR Int32)
                   case dir of
                     L -> app2 combine c x
                     R -> app2 combine x c
                 else
                   return x
          m <- i32 n
          z <- if A.gte singleType n bd'
                 then scanBlockSMem dir dev combine Nothing y
                 else scanBlockSMem dir dev combine (Just m) y
          -- Only threads that had input store a result; thread
          -- (blockDim - 1) also updates the carry slot.
          when (A.lt singleType tid' n) $ do
            writeArray arrOut j z
            w <- A.sub numType bd (lift 1)
            when (A.eq singleType tid w) $
              writeArray carry (lift 0 :: IR Int32) z
          -- Next round: one block width fewer elements, indices advanced
          n' <- A.sub numType n bd'
          i' <- next i
          j' <- next j
          return $ A.trip n' i' j')
        r
    return_
-- | Single-kernel @scan'@ (kernel @scan@) along the innermost dimension of a
-- multidimensional array.
--
-- Each thread block processes whole segments (segment indices run from
-- @start + blockIdx@ up to @end@ in steps of @gridDim@). Thread 0 writes the
-- seed to the first output position of the segment and to a one-element carry
-- buffer obtained from 'staticSharedMem'; the remaining results are written
-- shifted by one position. After the segment has been processed, thread 0
-- copies the carry buffer into the @sum@ array at index @seg@.
mkScan'Dim
    :: forall aenv sh e. (Shape sh, Elt e)
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e, Array sh e))
mkScan'Dim dir dev aenv combine seed IRDelayed{..} =
  let
      (start, end, paramGang) = gangParam
      (arrOut, paramOut) = mutableArray ("out" :: Name (Array (sh:.Int) e))
      (arrSum, paramSum) = mutableArray ("sum" :: Name (Array sh e))
      paramEnv = envParam aenv
      config = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
      -- Shared memory requirement for a block of @n@ threads (presumably in
      -- bytes, given the 'sizeOf' factor -- confirm against 'launchConfig')
      smem n = warps * (1 + per_warp) * bytes
        where
          ws = CUDA.warpSize dev
          warps = n `P.quot` ws
          per_warp = ws + ws `P.quot` 2
          bytes = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scan" (paramGang ++ paramOut ++ paramSum ++ paramEnv) $ do
    -- Slot used to pass the carry value between loop iterations
    carry <- staticSharedMem 1
    -- Innermost extent of the input array
    sz <- indexHead <$> delayedExtent
    tid <- threadIdx
    tid' <- int tid
    -- Threads with an index above sz skip the kernel body entirely.
    -- NOTE(review): the comparison is <=, so the thread with index == sz also
    -- enters; in the loop below it is excluded by the (tid' < n) check, and
    -- the full-block branch is only taken when n >= blockDim.
    when (A.lte singleType tid' sz) $ do
      bid <- int =<< blockIdx
      gd <- int =<< gridDim
      s0 <- A.add numType start bid
      imapFromStepTo s0 gd end $ \seg -> do
        -- Linear index range [inf, sup) of this segment
        inf <- A.mul numType seg sz
        sup <- A.add numType inf sz
        -- i0: index this thread reads from ('R' counts back from sup - 1)
        i0 <- case dir of
                L -> A.add numType inf tid'
                R -> do x <- A.sub numType sup tid'
                        y <- A.sub numType x (lift 1)
                        return y
        -- j0: index this thread writes to, one step further along the scan
        j0 <- case dir of
                L -> A.add numType i0 (lift 1)
                R -> A.sub numType i0 (lift 1)
        -- Thread 0 stores the seed at its read index and in the carry slot
        when (A.eq singleType tid (lift 0)) $ do
          z <- seed
          writeArray arrOut i0 z
          writeArray carry (lift 0 :: IR Int32) z
        bd <- blockDim
        bd' <- int bd
        -- Advance an index by one block width in the scan direction
        let next ix = case dir of
                        L -> A.add numType ix bd'
                        R -> A.sub numType ix bd'
        -- Number of elements in the segment
        n0 <- A.sub numType sup inf
        -- Loop state: (elements remaining, read index, write index)
        void $ while
          (\(A.fst3 -> n) -> A.gt singleType n (lift 0))
          (\(A.untrip -> (n,i,j)) -> do
            -- Barrier before the carry slot from the previous round is read
            __syncthreads
            _ <- if A.gte singleType n bd'
                   -- At least blockDim elements remain: every thread reads
                   -- and scans without a bounds check on the read.
                   then do
                     x <- app1 delayedLinearIndex i
                     y <- if A.eq singleType tid (lift 0)
                            then do
                              c <- readArray carry (lift 0 :: IR Int32)
                              case dir of
                                L -> app2 combine c x
                                R -> app2 combine x c
                            else
                              return x
                     z <- scanBlockSMem dir dev combine Nothing y
                     -- The shifted write index may fall outside the segment,
                     -- so the write itself is still bounds checked.
                     case dir of
                       L -> when (A.lt singleType j sup) $ writeArray arrOut j z
                       R -> when (A.gte singleType j inf) $ writeArray arrOut j z
                     -- Thread (blockDim - 1) updates the carry slot
                     bd1 <- A.sub numType bd (lift 1)
                     when (A.eq singleType tid bd1) $
                       writeArray carry (lift 0 :: IR Int32) z
                     return (IR OP_Unit :: IR ())
                   -- Fewer than blockDim elements remain: only threads with
                   -- an index below n take part.
                   else do
                     when (A.lt singleType tid' n) $ do
                       x <- app1 delayedLinearIndex i
                       y <- if A.eq singleType tid (lift 0)
                              then do
                                c <- readArray carry (lift 0 :: IR Int32)
                                case dir of
                                  L -> app2 combine c x
                                  R -> app2 combine x c
                              else
                                return x
                       l <- i32 n
                       z <- scanBlockSMem dir dev combine (Just l) y
                       -- Threads below n-1 write to the output; the last
                       -- participating thread writes the carry slot instead.
                       m <- A.sub numType n (lift 1)
                       _ <- if A.lt singleType tid' m
                              then writeArray arrOut j z >> return (IR OP_Unit :: IR ())
                              else writeArray carry (lift 0 :: IR Int32) z >> return (IR OP_Unit :: IR ())
                       return ()
                     return (IR OP_Unit :: IR ())
            -- Next round: one block width fewer elements, indices advanced
            A.trip <$> A.sub numType n bd' <*> next i <*> next j)
          (A.trip n0 i0 j0)
        -- Barrier, then thread 0 copies the carry value into sum[seg]
        __syncthreads
        when (A.eq singleType tid (lift 0)) $
          writeArray arrSum seg =<< readArray carry (lift 0 :: IR Int32)
    return_
-- | Kernel that fills the output array with the seed value: a 'mkGenerate'
-- whose element function ignores the index and evaluates the seed.
mkScanFill
    :: (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRExp PTX aenv e
    -> CodeGen (IROpenAcc PTX aenv (Array sh e))
mkScanFill ptx aenv seed =
  mkGenerate ptx aenv $ IRFun1 (\_ -> seed)
-- | Fill kernel for the @scan'@ variants. The kernel is generated at the type
-- of the sums array (@Array sh e@) and then coerced to the pair result type
-- expected by the callers.
mkScan'Fill
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRExp PTX aenv e
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e, Array sh e))
mkScan'Fill ptx aenv seed = Safe.coerce <$> fill
  where
    fill :: CodeGen (IROpenAcc PTX aenv (Array sh e))
    fill = mkGenerate ptx aenv (IRFun1 (const seed))
-- | Scan the values held by the threads of a block, in two steps: each warp
-- scans its own values ('scanWarpSMem'), then every warp other than warp 0
-- combines the values published by the preceding warps into its result.
--
-- The @Maybe (IR Int32)@ argument, when present, limits how many of the
-- per-warp values are folded into the prefix (see @steps@ below); callers in
-- this module pass the number of valid elements for a partially filled block.
scanBlockSMem
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> IRFun2 PTX aenv (e -> e -> e)
    -> Maybe (IR Int32)
    -> IR e
    -> CodeGen (IR e)
scanBlockSMem dir dev combine nelem = warpScan >=> warpPrefix
  where
    int32 :: Integral a => a -> IR Int32
    int32 = lift . P.fromIntegral
    -- Per-warp scratch space: 1.5 * warpSize elements, and its size in bytes
    warp_smem_elems = CUDA.warpSize dev + (CUDA.warpSize dev `P.quot` 2)
    warp_smem_bytes = warp_smem_elems * sizeOf (eltType (undefined::e))
    -- Step 1: scan within each warp, using a region of dynamic shared memory
    -- selected by the warp id (presumably 'dynamicSharedMem' takes an element
    -- count and a byte offset -- confirm against its definition)
    warpScan :: IR e -> CodeGen (IR e)
    warpScan input = do
      wid <- warpId
      skip <- A.mul numType wid (int32 warp_smem_bytes)
      smem <- dynamicSharedMem (int32 warp_smem_elems) skip
      scanWarpSMem dir dev combine smem input
    -- Step 2: combine the per-warp results
    warpPrefix :: IR e -> CodeGen (IR e)
    warpPrefix input = do
      -- One slot per warp, placed after the per-warp scratch regions
      bd <- blockDim
      warps <- A.quot integralType bd (int32 (CUDA.warpSize dev))
      skip <- A.mul numType warps (int32 warp_smem_bytes)
      smem <- dynamicSharedMem warps skip
      -- The last lane of each warp publishes its value in the warp's slot
      wid <- warpId
      lane <- laneId
      when (A.eq singleType lane (int32 (CUDA.warpSize dev - 1))) $ do
        writeArray smem wid input
      -- Barrier so that all published values are visible before reading
      __syncthreads
      -- Warp 0 has nothing to incorporate
      if A.eq singleType wid (lift 0)
        then return input
        else do
          -- Number of slots to fold: wid, or at most (n / warpSize) when an
          -- element count was given
          steps <- case nelem of
                     Nothing -> return wid
                     Just n -> A.min singleType wid =<< A.quot integralType n (int32 (CUDA.warpSize dev))
          -- Sequentially fold slots 0 .. steps-1 into the prefix value
          p0 <- readArray smem (lift 0 :: IR Int32)
          prefix <- iterFromStepTo (lift 1) (lift 1) steps p0 $ \step x -> do
                      y <- readArray smem step
                      case dir of
                        L -> app2 combine x y
                        R -> app2 combine y x
          -- Combine the prefix with this thread's warp-local result
          case dir of
            L -> app2 combine prefix input
            R -> app2 combine input prefix
-- | Scan the values held by the threads of one warp, exchanging partial
-- results through the given buffer.
--
-- The scan is unrolled at code-generation time into @floor (log2 warpSize)@
-- steps. In step @k@ every lane stores its partial result at slot
-- @lane + warpSize/2@ and lanes with @lane >= 2^k@ combine in the value stored
-- @2^k@ slots below their own.
--
-- The step count is computed with integer bit operations rather than a
-- floating-point logarithm, so it does not depend on the rounding of
-- 'logBase'.
scanWarpSMem
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRArray (Vector e)
    -> IR e
    -> CodeGen (IR e)
scanWarpSMem dir dev combine smem = scan 0
  where
    ws :: Int
    ws = CUDA.warpSize dev

    -- Index of the highest set bit of the warp size, i.e. floor (log2 ws)
    -- for any positive ws
    steps :: Int
    steps = P.finiteBitSize ws - 1 - P.countLeadingZeros ws

    halfWarp = P.fromIntegral (ws `P.quot` 2)

    -- Unrolled recursively while generating code; @step@ is a Haskell-level
    -- counter, not a value in the generated program
    scan :: Int -> IR e -> CodeGen (IR e)
    scan step x
      | step >= steps               = return x
      | offset <- 1 `P.shiftL` step = do
          -- Publish this lane's partial result
          lane <- laneId
          i    <- A.add numType lane (lift halfWarp)
          writeArray smem i x

          -- Lanes at or beyond the offset combine with the value @offset@
          -- slots below
          x'   <- if A.gte singleType lane (lift offset)
                    then do
                      i' <- A.sub numType i (lift offset)
                      x' <- readArray smem i'
                      case dir of
                        L -> app2 combine x' x
                        R -> app2 combine x x'
                    else
                      return x

          scan (step+1) x'
-- | Convert an 'Int' operand to 'Int32' in the generated code.
i32 :: IR Int -> CodeGen (IR Int32)
i32 x = A.fromIntegral integralType numType x
-- | Convert an 'Int32' operand to 'Int' in the generated code.
int :: IR Int32 -> CodeGen (IR Int)
int x = A.fromIntegral integralType numType x