{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE PatternGuards       #-}
{-# LANGUAGE RebindableSyntax    #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.CodeGen.Scan
-- Copyright   : [2016..2017] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- Code generation for the left- and right- scan family of collective
-- operations on the PTX (CUDA) backend.
--

module Data.Array.Accelerate.LLVM.PTX.CodeGen.Scan (

  mkScanl, mkScanl1, mkScanl',
  mkScanr, mkScanr1, mkScanr',

) where

-- accelerate
import Data.Array.Accelerate.Analysis.Type
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.LLVM.Analysis.Match
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic                as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Loop
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Base
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Generate
import Data.Array.Accelerate.LLVM.PTX.Context
import Data.Array.Accelerate.LLVM.PTX.Target

import LLVM.AST.Type.Representation

import qualified Foreign.CUDA.Analysis                              as CUDA

import Control.Applicative
import Control.Monad                                                ( (>=>), void )
import Data.String                                                  ( fromString )
import Data.Coerce                                                  as Safe
import Data.Bits                                                    as P
import Prelude                                                      as P hiding ( last )


-- | Scan traversal direction: 'L' for left-to-right ('scanl'-style),
-- 'R' for right-to-left ('scanr'-style).
--
data Direction = L | R

-- 'Data.List.scanl' style left-to-right exclusive scan, but with the
-- restriction that the combination function must be associative to enable
-- efficient parallel
-- implementation.
--
-- > scanl (+) 10 (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> Array (Z :. 11) [10,10,11,13,16,20,25,31,38,46,55]
--
mkScanl
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanl ptx@(deviceProperties . ptxContext -> dev) aenv combine seed arr
  -- Vector case: three-phase device-wide scan, plus a fill kernel used when
  -- the input is empty (exclusive scan of an empty array is just the seed).
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = foldr1 (+++) <$> sequence [ mkScanAllP1 L dev aenv combine (Just seed) arr
                              , mkScanAllP2 L dev aenv combine
                              , mkScanAllP3 L dev aenv combine (Just seed)
                              , mkScanFill ptx aenv seed
                              ]
  --
  -- Multidimensional case: single-pass scan along the innermost dimension.
  | otherwise
  = (+++) <$> mkScanDim L dev aenv combine (Just seed) arr
          <*> mkScanFill ptx aenv seed


-- 'Data.List.scanl1' style left-to-right inclusive scan, but with the
-- restriction that the combination function must be associative to enable
-- efficient parallel implementation. The array must not be empty.
--
-- > scanl1 (+) (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> Array (Z :. 10) [0,1,3,6,10,15,21,28,36,45]
--
mkScanl1
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanl1 (deviceProperties . ptxContext -> dev) aenv combine arr
  -- Inclusive scan: no seed, and no fill kernel since the input must be
  -- non-empty.
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = foldr1 (+++) <$> sequence [ mkScanAllP1 L dev aenv combine Nothing arr
                              , mkScanAllP2 L dev aenv combine
                              , mkScanAllP3 L dev aenv combine Nothing
                              ]
  --
  | otherwise
  = mkScanDim L dev aenv combine Nothing arr


-- Variant of 'scanl' where the final result is returned in a separate array.
--
-- > scanl' (+) 10 (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> ( Array (Z :. 10) [10,10,11,13,16,20,25,31,38,46]
-- >     , Array Z [55]
-- >     )
--
mkScanl'
    :: forall aenv sh e.
       (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e, Array sh e))
mkScanl' ptx@(deviceProperties . ptxContext -> dev) aenv combine seed arr
  -- Vector case: three-phase scan' with the final reduction value landing in
  -- the separate sums array; fill kernel handles the empty input.
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = foldr1 (+++) <$> sequence [ mkScan'AllP1 L dev aenv combine seed arr
                              , mkScan'AllP2 L dev aenv combine
                              , mkScan'AllP3 L dev aenv combine
                              , mkScan'Fill ptx aenv seed
                              ]
  --
  | otherwise
  = (+++) <$> mkScan'Dim L dev aenv combine seed arr
          <*> mkScan'Fill ptx aenv seed


-- 'Data.List.scanr' style right-to-left exclusive scan, but with the
-- restriction that the combination function must be associative to enable
-- efficient parallel implementation.
--
-- > scanr (+) 10 (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> Array (Z :. 11) [55,55,54,52,49,45,40,34,27,19,10]
--
mkScanr
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanr ptx@(deviceProperties . ptxContext -> dev) aenv combine seed arr
  -- Mirror image of 'mkScanl', traversing right-to-left.
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = foldr1 (+++) <$> sequence [ mkScanAllP1 R dev aenv combine (Just seed) arr
                              , mkScanAllP2 R dev aenv combine
                              , mkScanAllP3 R dev aenv combine (Just seed)
                              , mkScanFill ptx aenv seed
                              ]
  --
  | otherwise
  = (+++) <$> mkScanDim R dev aenv combine (Just seed) arr
          <*> mkScanFill ptx aenv seed


-- 'Data.List.scanr1' style right-to-left inclusive scan, but with the
-- restriction that the combination function must be associative to enable
-- efficient parallel implementation. The array must not be empty.
--
-- > scanr1 (+) (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> Array (Z :. 10) [45,45,44,42,39,35,30,24,17,9]
--
mkScanr1
    :: forall aenv sh e.
       (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanr1 (deviceProperties . ptxContext -> dev) aenv combine arr
  -- Mirror image of 'mkScanl1', traversing right-to-left.
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = foldr1 (+++) <$> sequence [ mkScanAllP1 R dev aenv combine Nothing arr
                              , mkScanAllP2 R dev aenv combine
                              , mkScanAllP3 R dev aenv combine Nothing
                              ]
  --
  | otherwise
  = mkScanDim R dev aenv combine Nothing arr


-- Variant of 'scanr' where the final result is returned in a separate array.
--
-- > scanr' (+) 10 (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> ( Array (Z :. 10) [55,54,52,49,45,40,34,27,19,10]
-- >     , Array Z [55]
-- >     )
--
mkScanr'
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e, Array sh e))
mkScanr' ptx@(deviceProperties . ptxContext -> dev) aenv combine seed arr
  -- Mirror image of 'mkScanl'', traversing right-to-left.
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = foldr1 (+++) <$> sequence [ mkScan'AllP1 R dev aenv combine seed arr
                              , mkScan'AllP2 R dev aenv combine
                              , mkScan'AllP3 R dev aenv combine
                              , mkScan'Fill ptx aenv seed
                              ]
  --
  | otherwise
  = (+++) <$> mkScan'Dim R dev aenv combine seed arr
          <*> mkScan'Fill ptx aenv seed


-- Device wide scans
-- -----------------
--
-- This is a classic two-pass algorithm which proceeds in two phases and
-- requires ~4n data movement to global memory. In future we would like to
-- replace this with a single pass algorithm.
--

-- Parallel scan, step 1.
--
-- Threads scan a stripe of the input into a temporary array, incorporating the
-- initial element and any fused functions on the way. The final reduction
-- result of this chunk is written to a separate array.
--
mkScanAllP1
    :: forall aenv e.
       Elt e
    => Direction
    -> DeviceProperties                             -- ^ properties of the target GPU
    -> Gamma aenv                                   -- ^ array environment
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> Maybe (IRExp PTX aenv e)                     -- ^ seed element, if this is an exclusive scan
    -> IRDelayed PTX aenv (Vector e)                -- ^ input data
    -> CodeGen (IROpenAcc PTX aenv (Vector e))
mkScanAllP1 dir dev aenv combine mseed IRDelayed{..} =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      --
      config                    = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
      -- Shared memory per thread block: each warp needs (1.5 * warpSize)
      -- elements for 'scanBlockSMem', plus one slot.
      smem n                    = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scanP1" (paramGang ++ paramTmp ++ paramOut ++ paramEnv) $ do

    -- Size of the input array
    sz  <- A.fromIntegral integralType numType . indexHead =<< delayedExtent

    -- A thread block scans a non-empty stripe of the input, storing the final
    -- block-wide aggregate into a separate array
    --
    -- For exclusive scans, thread 0 of segment 0 must incorporate the initial
    -- element into the input and output. Threads shuffle their indices
    -- appropriately.
    --
    bid <- blockIdx
    gd  <- gridDim
    s0  <- A.add numType start bid

    -- iterating over thread-block-wide segments
    imapFromStepTo s0 gd end $ \chunk -> do

      bd  <- blockDim
      inf <- A.mul numType chunk bd

      -- index i* is the index that this thread will read data from. Recall that
      -- the supremum index is exclusive
      tid <- threadIdx
      i0  <- case dir of
               L -> A.add numType inf tid
               R -> do x <- A.sub numType sz inf
                       y <- A.sub numType x tid
                       z <- A.sub numType y (lift 1)
                       return z

      -- index j* is the index that we write to. Recall that for exclusive scans
      -- the output array is one larger than the input; the initial element will
      -- be written into this spot by thread 0 of the first thread block.
      j0  <- case mseed of
               Nothing -> return i0
               Just _  -> case dir of
                            L -> A.add numType i0 (lift 1)
                            R -> return i0

      -- If this thread has input, read data and participate in thread-block scan
      let valid i = case dir of
                      L -> A.lt  scalarType i sz
                      R -> A.gte scalarType i (lift 0)

      when (valid i0) $ do
        x0 <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType i0
        x1 <- case mseed of
                Nothing   -> return x0
                Just seed ->
                  -- Exclusive scan: thread 0 of chunk 0 evaluates the seed,
                  -- writes it at the boundary of the output, and folds it into
                  -- its own input element.
                  if A.eq scalarType tid (lift 0) `A.land` A.eq scalarType chunk (lift 0)
                    then do
                      z <- seed
                      case dir of
                        L -> writeArray arrOut (lift 0 :: IR Int32) z >> app2 combine z x0
                        R -> writeArray arrOut sz z >> app2 combine x0 z
                    else
                      return x0

        n  <- A.sub numType sz inf
        x2 <- if A.gte scalarType n bd
                then scanBlockSMem dir dev combine Nothing  x1
                else scanBlockSMem dir dev combine (Just n) x1

        -- Write this thread's scan result to memory
        writeArray arrOut j0 x2

        -- The last thread also writes its result---the aggregate for this
        -- thread block---to the temporary partial sums array. This is only
        -- necessary for full blocks in a multi-block scan; the final
        -- partially-full tile does not have a successor block.
        -- NOTE(review): 'land' here is unqualified whereas the rest of the file
        -- writes 'A.land'; both resolve to the same function ('A' is an
        -- unqualified alias import) -- consider normalising.
        last <- A.sub numType bd (lift 1)
        when (A.gt scalarType gd (lift 1) `land` A.eq scalarType tid last) $
          case dir of
            L -> writeArray arrTmp chunk x2
            R -> do u <- A.sub numType end chunk
                    v <- A.sub numType u (lift 1)
                    writeArray arrTmp v x2

    return_


-- Parallel scan, step 2
--
-- A single thread block performs a scan of the per-block aggregates computed in
-- step 1. This gives the per-block prefix which must be added to each element
-- in step 3.
--
mkScanAllP2
    :: forall aenv e.
       Elt e
    => Direction
    -> DeviceProperties                             -- ^ properties of the target GPU
    -> Gamma aenv                                   -- ^ array environment
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> CodeGen (IROpenAcc PTX aenv (Vector e))
mkScanAllP2 dir dev aenv combine =
  let
      (start, end, paramGang)   = gangParam
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      --
      -- This kernel is launched with a single thread block only.
      config                    = launchConfig dev (CUDA.incWarp dev) smem grid gridQ
      grid _ _                  = 1
      gridQ                     = [|| \_ _ -> 1 ||]
      smem n                    = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scanP2" (paramGang ++ paramTmp ++ paramEnv) $ do

    -- The first and last threads of the block need to communicate the
    -- block-wide aggregate as a carry-in value across iterations.
    --
    -- TODO: We could optimise this a bit if we can get access to the shared
    -- memory area used by 'scanBlockSMem', and from there directly read the
    -- value computed by the last thread.
    carry <- staticSharedMem 1

    bd    <- blockDim
    imapFromStepTo start bd end $ \offset -> do

      -- Index of the partial sums array that this thread will process.
      tid <- threadIdx
      i0  <- case dir of
               L -> A.add numType offset tid
               R -> do x <- A.sub numType end offset
                       y <- A.sub numType x tid
                       z <- A.sub numType y (lift 1)
                       return z

      let valid i = case dir of
                      L -> A.lt  scalarType i end
                      R -> A.gte scalarType i start

      when (valid i0) $ do

        -- wait for the carry-in value to be updated
        __syncthreads

        x0 <- readArray arrTmp i0
        x1 <- if A.gt scalarType offset (lift 0) `land` A.eq scalarType tid (lift 0)
                then do
                  -- First thread of each subsequent stripe folds in the
                  -- carry-out of the previous stripe.
                  c <- readArray carry (lift 0 :: IR Int32)
                  case dir of
                    L -> app2 combine c x0
                    R -> app2 combine x0 c
                else do
                  return x0

        n  <- A.sub numType end offset
        x2 <- if A.gte scalarType n bd
                then scanBlockSMem dir dev combine Nothing  x1
                else scanBlockSMem dir dev combine (Just n) x1

        -- Update the temporary array with this thread's result
        writeArray arrTmp i0 x2

        -- The last thread writes the carry-out value. If the last thread is not
        -- active, then this must be the last stripe anyway.
        last <- A.sub numType bd (lift 1)
        when (A.eq scalarType tid last) $
          writeArray carry (lift 0 :: IR Int32) x2

    return_


-- Parallel scan, step 3.
--
-- Threads combine every element of the partial block results with the carry-in
-- value computed in step 2.
--
mkScanAllP3
    :: forall aenv e.
       Elt e
    => Direction
    -> DeviceProperties                             -- ^ properties of the target GPU
    -> Gamma aenv                                   -- ^ array environment
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> Maybe (IRExp PTX aenv e)                     -- ^ seed element, if this is an exclusive scan
    -> CodeGen (IROpenAcc PTX aenv (Vector e))
mkScanAllP3 dir dev aenv combine mseed =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      --
      -- Chunk size used by the step 1 kernel, passed in as an extra scalar
      -- kernel parameter.
      stride                    = local scalarType ("ix.stride" :: Name Int32)
      paramStride               = scalarParameter scalarType ("ix.stride" :: Name Int32)
      --
      config                    = launchConfig dev (CUDA.incWarp dev) (const 0) const [|| const ||]
  in
  makeOpenAccWith config "scanP3" (paramGang ++ paramTmp ++ paramOut ++ paramStride : paramEnv) $ do

    sz  <- A.fromIntegral integralType numType (indexHead (irArrayShape arrOut))
    tid <- threadIdx

    -- Threads that will never contribute can just exit immediately. The size of
    -- each chunk is set by the block dimension of the step 1 kernel, which may
    -- be different from the block size of this kernel.
    when (A.lt scalarType tid stride) $ do

      -- Iterate over the segments computed in phase 1. Note that we have one
      -- fewer chunk to process because the first has no carry-in.
      bid <- blockIdx
      gd  <- gridDim
      c0  <- A.add numType start bid
      imapFromStepTo c0 gd end $ \chunk -> do

        -- Determine the start and end indices of this chunk to which we will
        -- carry-in the value. Returned for left-to-right traversal.
        (inf,sup) <- case dir of
                       L -> do
                         a <- A.add numType chunk (lift 1)
                         b <- A.mul numType stride a
                         case mseed of
                           -- Exclusive scan: output is shifted by one for the
                           -- initial element.
                           Just{}  -> do
                             c <- A.add numType b (lift 1)
                             d <- A.add numType c stride
                             e <- A.min scalarType d sz
                             return (c,e)
                           Nothing -> do
                             c <- A.add numType b stride
                             d <- A.min scalarType c sz
                             return (b,d)
                       R -> do
                         a <- A.sub numType end chunk
                         b <- A.mul numType stride a
                         c <- A.sub numType sz b
                         case mseed of
                           Just{}  -> do
                             d <- A.sub numType c (lift 1)
                             e <- A.sub numType d stride
                             f <- A.max scalarType e (lift 0)
                             return (f,d)
                           Nothing -> do
                             d <- A.sub numType c stride
                             e <- A.max scalarType d (lift 0)
                             return (e,c)

        -- Read the carry-in value
        carry <- case dir of
                   L -> readArray arrTmp chunk
                   R -> do
                     a <- A.add numType chunk (lift 1)
                     b <- readArray arrTmp a
                     return b

        -- Apply the carry-in value to each element in the chunk
        bd <- blockDim
        i0 <- A.add numType inf tid
        imapFromStepTo i0 bd sup $ \i -> do
          v <- readArray arrOut i
          u <- case dir of
                 L -> app2 combine carry v
                 R -> app2 combine v carry
          writeArray arrOut i u

    return_


-- Parallel scan', step 1.
--
-- Similar to mkScanAllP1. Threads scan a stripe of the input into a temporary
-- array, incorporating the initial element and any fused functions on the way.
-- The final reduction result of this chunk is written to a separate array.
--
mkScan'AllP1
    :: forall aenv e.
       Elt e
    => Direction
    -> DeviceProperties                             -- ^ properties of the target GPU
    -> Gamma aenv                                   -- ^ array environment
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> IRExp PTX aenv e                             -- ^ seed element
    -> IRDelayed PTX aenv (Vector e)                -- ^ input data
    -> CodeGen (IROpenAcc PTX aenv (Vector e, Scalar e))
mkScan'AllP1 dir dev aenv combine seed IRDelayed{..} =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      --
      config                    = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
      smem n                    = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scanP1" (paramGang ++ paramTmp ++ paramOut ++ paramEnv) $ do

    -- Size of the input array
    sz  <- A.fromIntegral integralType numType . indexHead =<< delayedExtent

    -- A thread block scans a non-empty stripe of the input, storing the partial
    -- result and the final block-wide aggregate
    bid <- blockIdx
    gd  <- gridDim
    s0  <- A.add numType start bid

    -- iterate over thread-block wide segments
    imapFromStepTo s0 gd end $ \seg -> do

      bd  <- blockDim
      inf <- A.mul numType seg bd

      -- i* is the index that this thread will read data from
      tid <- threadIdx
      i0  <- case dir of
               L -> A.add numType inf tid
               R -> do x <- A.sub numType sz inf
                       y <- A.sub numType x tid
                       z <- A.sub numType y (lift 1)
                       return z

      -- j* is the index this thread will write to. This is just shifted by one
      -- to make room for the initial element
      j0  <- case dir of
               L -> A.add numType i0 (lift 1)
               R -> A.sub numType i0 (lift 1)

      -- If this thread has input it participates in the scan
      let valid i = case dir of
                      L -> A.lt  scalarType i sz
                      R -> A.gte scalarType i (lift 0)

      when (valid i0) $ do
        x0 <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType i0

        -- Thread 0 of the first segment must also evaluate and store the
        -- initial element
        x1 <- if A.eq scalarType tid (lift 0) `A.land` A.eq scalarType seg (lift 0)
                then do
                  z <- seed
                  writeArray arrOut i0 z
                  case dir of
                    L -> app2 combine z x0
                    R -> app2 combine x0 z
                else
                  return x0

        -- Block-wide scan
        n  <- A.sub numType sz inf
        x2 <- if A.gte scalarType n bd
                then scanBlockSMem dir dev combine Nothing  x1
                else scanBlockSMem dir dev combine (Just n) x1

        -- Write this thread's scan result to memory. Recall that we had to make
        -- space for the initial element, so the very last thread does not store
        -- its result here.
        case dir of
          L -> when (A.lt  scalarType j0 sz)       $ writeArray arrOut j0 x2
          R -> when (A.gte scalarType j0 (lift 0)) $ writeArray arrOut j0 x2

        -- Last active thread writes its result to the partial sums array. These
        -- will be used to compute the carry-in value in step 2.
        m <- do x <- A.min scalarType n bd
                y <- A.sub numType x (lift 1)
                return y
        when (A.eq scalarType tid m) $
          case dir of
            L -> writeArray arrTmp seg x2
            R -> do x <- A.sub numType end seg
                    y <- A.sub numType x (lift 1)
                    writeArray arrTmp y x2

    return_


-- Parallel scan', step 2
--
-- A single thread block performs an inclusive scan of the partial sums array to
-- compute the per-block carry-in values, as well as the final reduction result.
--
mkScan'AllP2
    :: forall aenv e.
       Elt e
    => Direction
    -> DeviceProperties                             -- ^ properties of the target GPU
    -> Gamma aenv                                   -- ^ array environment
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> CodeGen (IROpenAcc PTX aenv (Vector e, Scalar e))
mkScan'AllP2 dir dev aenv combine =
  let
      (start, end, paramGang)   = gangParam
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      (arrSum, paramSum)        = mutableArray ("sum" :: Name (Scalar e))
      paramEnv                  = envParam aenv
      --
      -- Launched with a single thread block only.
      config                    = launchConfig dev (CUDA.incWarp dev) smem grid gridQ
      grid _ _                  = 1
      gridQ                     = [|| \_ _ -> 1 ||]
      smem n                    = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scanP2" (paramGang ++ paramTmp ++ paramSum ++ paramEnv) $ do

    -- The first and last threads of the block need to communicate the
    -- block-wide aggregate as a carry-in value across iterations.
    carry <- staticSharedMem 1

    -- A single thread block iterates over the per-block partial results from
    -- step 1
    tid <- threadIdx
    bd  <- blockDim
    imapFromStepTo start bd end $ \offset -> do

      i0 <- case dir of
              L -> A.add numType offset tid
              R -> do x <- A.sub numType end offset
                      y <- A.sub numType x tid
                      z <- A.sub numType y (lift 1)
                      return z

      let valid i = case dir of
                      L -> A.lt  scalarType i end
                      R -> A.gte scalarType i start

      when (valid i0) $ do

        -- wait for the carry-in value to be updated
        __syncthreads

        x0 <- readArray arrTmp i0
        x1 <- if A.gt scalarType offset (lift 0) `A.land` A.eq scalarType tid (lift 0)
                then do
                  c <- readArray carry (lift 0 :: IR Int32)
                  case dir of
                    L -> app2 combine c x0
                    R -> app2 combine x0 c
                else
                  return x0

        n  <- A.sub numType end offset
        x2 <- if A.gte scalarType n bd
                then scanBlockSMem dir dev combine Nothing  x1
                else scanBlockSMem dir dev combine (Just n) x1

        -- Update the partial results array
        writeArray arrTmp i0 x2

        -- The last active thread saves its result as the carry-out value.
        m <- do x <- A.min scalarType bd n
                y <- A.sub numType x (lift 1)
                return y
        when (A.eq scalarType tid m) $
          writeArray carry (lift 0 :: IR Int32) x2

    -- First thread stores the final carry-out values at the final reduction
    -- result for the entire array
    __syncthreads
    when (A.eq scalarType tid (lift 0)) $
      writeArray arrSum (lift 0 :: IR Int32) =<< readArray carry (lift 0 :: IR Int32)

    return_


-- Parallel scan', step 3.
--
-- Threads combine every element of the partial block results with the carry-in
-- value computed in step 2.
--
mkScan'AllP3
    :: forall aenv e.
       Elt e
    => Direction
    -> DeviceProperties                             -- ^ properties of the target GPU
    -> Gamma aenv                                   -- ^ array environment
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> CodeGen (IROpenAcc PTX aenv (Vector e, Scalar e))
mkScan'AllP3 dir dev aenv combine =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      --
      -- Chunk size used by the step 1 kernel, passed as an extra scalar
      -- kernel parameter.
      stride                    = local scalarType ("ix.stride" :: Name Int32)
      paramStride               = scalarParameter scalarType ("ix.stride" :: Name Int32)
      --
      config                    = launchConfig dev (CUDA.incWarp dev) (const 0) const [|| const ||]
  in
  makeOpenAccWith config "scanP3" (paramGang ++ paramTmp ++ paramOut ++ paramStride : paramEnv) $ do

    sz  <- A.fromIntegral integralType numType (indexHead (irArrayShape arrOut))
    tid <- threadIdx

    -- Threads that will never contribute can exit immediately (cf. mkScanAllP3)
    when (A.lt scalarType tid stride) $ do

      bid <- blockIdx
      gd  <- gridDim
      c0  <- A.add numType start bid
      imapFromStepTo c0 gd end $ \chunk -> do

        -- Bounds of this chunk; always shifted by one for the initial element
        -- (scan' is exclusive).
        (inf,sup) <- case dir of
                       L -> do
                         a <- A.add numType chunk (lift 1)
                         b <- A.mul numType stride a
                         c <- A.add numType b (lift 1)
                         d <- A.add numType c stride
                         e <- A.min scalarType d sz
                         return (c,e)
                       R -> do
                         a <- A.sub numType end chunk
                         b <- A.mul numType stride a
                         c <- A.sub numType sz b
                         d <- A.sub numType c (lift 1)
                         e <- A.sub numType d stride
                         f <- A.max scalarType e (lift 0)
                         return (f,d)

        -- Read the carry-in value
        carry <- case dir of
                   L -> readArray arrTmp chunk
                   R -> do
                     a <- A.add numType chunk (lift 1)
                     b <- readArray arrTmp a
                     return b

        -- Apply the carry-in value to each element in the chunk
        bd <- blockDim
        i0 <- A.add numType inf tid
        imapFromStepTo i0 bd sup $ \i -> do
          v <- readArray arrOut i
          u <- case dir of
                 L -> app2 combine carry v
                 R -> app2 combine v carry
          writeArray arrOut i u

    return_


-- Multidimensional scans
-- ----------------------

-- Multidimensional scan along the innermost dimension
--
-- A thread block individually computes along each innermost dimension. This is
-- a single-pass operation.
--
--  * We can assume that the array is non-empty; exclusive scans with empty
--    innermost dimension will be instead filled with the seed element via
--    'mkScanFill'.
--
--  * Small but non-empty innermost dimension arrays (size << thread
--    block size) will have many threads which do no work.
--
mkScanDim
    :: forall aenv sh e. (Shape sh, Elt e)
    => Direction
    -> DeviceProperties                             -- ^ properties of the target GPU
    -> Gamma aenv                                   -- ^ array environment
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> Maybe (IRExp PTX aenv e)                     -- ^ seed element, if this is an exclusive scan
    -> IRDelayed PTX aenv (Array (sh:.Int) e)       -- ^ input data
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanDim dir dev aenv combine mseed IRDelayed{..} =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Array (sh:.Int) e))
      paramEnv                  = envParam aenv
      --
      config                    = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
      smem n                    = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scan" (paramGang ++ paramOut ++ paramEnv) $ do

    -- The first and last threads of the block need to communicate the
    -- block-wide aggregate as a carry-in value across iterations.
    --
    -- TODO: we could optimise this a bit if we can get access to the shared
    -- memory area used by 'scanBlockSMem', and from there directly read the
    -- value computed by the last thread.
    carry <- staticSharedMem 1

    -- Size of the input array
    sz  <- A.fromIntegral integralType numType . indexHead =<< delayedExtent

    -- Thread blocks iterate over the outer dimensions. Threads in a block
    -- cooperatively scan along one dimension, but thread blocks do not
    -- communicate with each other.
    --
    bid <- blockIdx
    gd  <- gridDim
    s0  <- A.add numType start bid
    imapFromStepTo s0 gd end $ \seg -> do

      -- Index this thread reads from
      tid <- threadIdx
      i0  <- case dir of
               L -> do x <- A.mul numType seg sz
                       y <- A.add numType x tid
                       return y
               R -> do x <- A.add numType seg (lift 1)
                       y <- A.mul numType x sz
                       z <- A.sub numType y tid
                       w <- A.sub numType z (lift 1)
                       return w

      -- Index this thread writes to
      j0  <- case mseed of
               Nothing -> return i0     -- inclusive scan: in-place indices
               Just{}  -> do            -- exclusive scan: output is one larger
                 szp1 <- A.fromIntegral integralType numType (indexHead (irArrayShape arrOut))
                 case dir of
                   L -> do x <- A.mul numType seg szp1
                           y <- A.add numType x tid
                           return y
                   R -> do x <- A.add numType seg (lift 1)
                           y <- A.mul numType x szp1
                           z <- A.sub numType y tid
                           w <- A.sub numType z (lift 1)
                           return w

      -- Stride indices by block dimension
      bd <- blockDim
      let next ix = case dir of
                      L -> A.add numType ix bd
                      R -> A.sub numType ix bd

      -- Initialise this scan segment
      --
      -- If this is an exclusive scan then the first thread just evaluates the
      -- seed element and stores this value into the carry-in slot. All threads
      -- shift their write-to index (j) by one, to make space for this element.
      --
      -- If this is an inclusive scan then do a block-wide scan. The last thread
      -- in the block writes the carry-in value.
      --
      r <- case mseed of
             Just seed -> do
               when (A.eq scalarType tid (lift 0)) $ do
                 z <- seed
                 writeArray arrOut j0 z
                 writeArray carry (lift 0 :: IR Int32) z
               j1 <- case dir of
                       L -> A.add numType j0 (lift 1)
                       R -> A.sub numType j0 (lift 1)
               return $ A.trip sz i0 j1

             Nothing   -> do
               when (A.lt scalarType tid sz) $ do
                 x0 <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType i0
                 r0 <- if A.gte scalarType sz bd
                         then scanBlockSMem dir dev combine Nothing   x0
                         else scanBlockSMem dir dev combine (Just sz) x0
                 writeArray arrOut j0 r0
                 ll <- A.sub numType bd (lift 1)
                 when (A.eq scalarType tid ll) $
                   writeArray carry (lift 0 :: IR Int32) r0
               n1 <- A.sub numType sz bd
               i1 <- next i0
               j1 <- next j0
               return $ A.trip n1 i1 j1

      -- Iterate over the remaining elements in this segment
      void $ while
        (\(A.fst3   -> n)       -> A.gt scalarType n (lift 0))
        (\(A.untrip -> (n,i,j)) -> do

          -- Wait for the carry-in value from the previous iteration to be updated
          __syncthreads

          -- Compute and store the next element of the scan
          --
          -- NOTE: As with 'foldSeg' we require all threads to participate in
          -- every iteration of the loop otherwise they will die prematurely.
          -- Out-of-bounds threads return 'undef' at this point, which is really
          -- unfortunate ):
          --
          x <- if A.lt scalarType tid n
                 then app1 delayedLinearIndex =<< A.fromIntegral integralType numType i
                 else let go :: TupleType a -> Operands a
                          go UnitTuple       = OP_Unit
                          go (PairTuple a b) = OP_Pair (go a) (go b)
                          go (SingleTuple t) = ir' t (undef t)
                      in
                      return . IR $ go (eltType (undefined::e))

          -- Thread zero incorporates the carry-in element
          y <- if A.eq scalarType tid (lift 0)
                 then do
                   c <- readArray carry (lift 0 :: IR Int32)
                   case dir of
                     L -> app2 combine c x
                     R -> app2 combine x c
                 else
                   return x

          -- Perform the scan and write the result to memory
          z <- if A.gte scalarType n bd
                 then scanBlockSMem dir dev combine Nothing  y
                 else scanBlockSMem dir dev combine (Just n) y
          when (A.lt scalarType tid n) $ do
            writeArray arrOut j z

            -- The last thread of the block writes its result as the carry-out
            -- value. If this thread is not active then we are on the last
            -- iteration of the loop and it will not be needed.
            w <- A.sub numType bd (lift 1)
            when (A.eq scalarType tid w) $
              writeArray carry (lift 0 :: IR Int32) z

          -- Update indices for the next iteration
          n' <- A.sub numType n bd
          i' <- next i
          j' <- next j
          return $ A.trip n' i' j')
        r

    return_


-- Multidimensional scan' along the innermost dimension
--
-- A thread block individually computes along each innermost dimension. This is
-- a single-pass operation.
--
--  * We can assume that the array is non-empty; exclusive scans with empty
--    innermost dimension will be instead filled with the seed element via
--    'mkScan'Fill'.
--
--  * Small but non-empty innermost dimension arrays (size << thread
--    block size) will have many threads which do no work.
--
mkScan'Dim
    :: forall aenv sh e.
       (Shape sh, Elt e)
    => Direction
    -> DeviceProperties                             -- ^ properties of the target GPU
    -> Gamma aenv                                   -- ^ array environment
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> IRExp PTX aenv e                             -- ^ seed element
    -> IRDelayed PTX aenv (Array (sh:.Int) e)       -- ^ input data
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e, Array sh e))
mkScan'Dim dir dev aenv combine seed IRDelayed{..} =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Array (sh:.Int) e))
      (arrSum, paramSum)        = mutableArray ("sum" :: Name (Array sh e))
      paramEnv                  = envParam aenv
      --
      config                    = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
      smem n                    = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scan" (paramGang ++ paramOut ++ paramSum ++ paramEnv) $ do

    -- The first and last threads of the block need to communicate the
    -- block-wide aggregate as a carry-in value across iterations.
    --
    -- TODO: we could optimise this a bit if we can get access to the shared
    -- memory area used by 'scanBlockSMem', and from there directly read the
    -- value computed by the last thread.
    carry <- staticSharedMem 1

    -- Size of the input array
    sz  <- A.fromIntegral integralType numType . indexHead =<< delayedExtent

    -- If the innermost dimension is smaller than the number of threads in the
    -- block, those threads will never contribute to the output.
    -- NOTE(review): the guard here is 'lte' whereas the analogous guard in
    -- 'mkScanDim' uses 'lt' -- confirm the off-by-one is intentional (the
    -- output of scan' has the same extent as the input).
    tid <- threadIdx
    when (A.lte scalarType tid sz) $ do

      -- Thread blocks iterate over the outer dimensions, each thread block
      -- cooperatively scanning along each outermost index.
      bid <- blockIdx
      gd  <- gridDim
      s0  <- A.add numType start bid
      imapFromStepTo s0 gd end $ \seg -> do

        -- Not necessary to wait for threads to catch up before starting this segment
        -- __syncthreads

        -- Linear index bounds for this segment
        inf <- A.mul numType seg sz
        sup <- A.add numType inf sz

        -- Index that this thread will read from. Recall that the supremum index
        -- is exclusive.
        i0  <- case dir of
                 L -> A.add numType inf tid
                 R -> do x <- A.sub numType sup tid
                         y <- A.sub numType x (lift 1)
                         return y

        -- The index that this thread will write to. This is just shifted along
        -- by one to make room for the initial element.
        j0  <- case dir of
                 L -> A.add numType i0 (lift 1)
                 R -> A.sub numType i0 (lift 1)

        -- Evaluate the initial element. Store it into the carry-in slot as well
        -- as to the array as the first element. This is always valid because if
        -- the input array is empty then we will be evaluating via mkScan'Fill.
        when (A.eq scalarType tid (lift 0)) $ do
          z <- seed
          writeArray arrOut i0 z
          writeArray carry (lift 0 :: IR Int32) z

        bd  <- blockDim
        let next ix = case dir of
                        L -> A.add numType ix bd
                        R -> A.sub numType ix bd

        -- Now, threads iterate over the elements along the innermost dimension.
        -- At each iteration the first thread incorporates the carry-in value
        -- from the previous step.
        --
        -- The index tracks how many elements remain for the thread block, since
        -- indices i* and j* are local to each thread
        n0 <- A.sub numType sup inf
        void $ while
          (\(A.fst3   -> n)       -> A.gt scalarType n (lift 0))
          (\(A.untrip -> (n,i,j)) -> do

            -- Wait for threads to catch up to ensure the carry-in value from
            -- the last iteration has been updated
            __syncthreads

            -- If all threads in the block will participate this round we can
            -- avoid (almost) all bounds checks.
            _ <- if A.gte scalarType n bd
                   -- All threads participate. No bounds checks required but
                   -- the last thread needs to update the carry-in value.
                   then do
                     x <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType i
                     y <- if A.eq scalarType tid (lift 0)
                            then do
                              c <- readArray carry (lift 0 :: IR Int32)
                              case dir of
                                L -> app2 combine c x
                                R -> app2 combine x c
                            else
                              return x
                     z <- scanBlockSMem dir dev combine Nothing y

                     -- Write results to the output array. Note that if we
                     -- align directly on the boundary of the array this is not
                     -- valid for the last thread.
                     case dir of
                       L -> when (A.lt  scalarType j sup) $ writeArray arrOut j z
                       R -> when (A.gte scalarType j inf) $ writeArray arrOut j z

                     -- Last thread of the block also saves its result as the
                     -- carry-in value
                     bd1 <- A.sub numType bd (lift 1)
                     when (A.eq scalarType tid bd1) $
                       writeArray carry (lift 0 :: IR Int32) z

                     return (IR OP_Unit :: IR ())

                   -- Only threads that are in bounds can participate. This is
                   -- the last iteration of the loop. The last active thread
                   -- still needs to store its value into the carry-in slot.
                   else do
                     when (A.lt scalarType tid n) $ do
                       x <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType i
                       y <- if A.eq scalarType tid (lift 0)
                              then do
                                c <- readArray carry (lift 0 :: IR Int32)
                                case dir of
                                  L -> app2 combine c x
                                  R -> app2 combine x c
                              else
                                return x
                       z <- scanBlockSMem dir dev combine (Just n) y
                       m <- A.sub numType n (lift 1)
                       _ <- if A.lt scalarType tid m
                              then writeArray arrOut j z                       >> return (IR OP_Unit :: IR ())
                              else writeArray carry (lift 0 :: IR Int32) z     >> return (IR OP_Unit :: IR ())
                       return ()
                     return (IR OP_Unit :: IR ())

            A.trip <$> A.sub numType n bd <*> next i <*> next j)
          (A.trip n0 i0 j0)

        -- Wait for the carry-in value to be updated
        __syncthreads

        -- Store the carry-in value to the separate final results array
        when (A.eq scalarType tid (lift 0)) $
          writeArray arrSum seg =<< readArray carry (lift 0 :: IR Int32)

    return_


-- Parallel scan, auxiliary
--
-- If this is an exclusive scan of an empty array, we just fill the result with
-- the seed element.
--
-- | Fill the (empty-input) scan result with the seed element. Defined in terms
-- of 'mkGenerate' with a constant-valued generation function.
mkScanFill
    :: (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRExp PTX aenv e
    -> CodeGen (IROpenAcc PTX aenv (Array sh e))
mkScanFill ptx aenv seed =
  mkGenerate ptx aenv (IRFun1 (const seed))

-- | As 'mkScanFill', but for the scan' variants which also produce a separate
-- array of final reduction values. The 'Safe.coerce' only retags the result
-- type of the generated kernel; no code changes.
mkScan'Fill
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRExp PTX aenv e
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e, Array sh e))
mkScan'Fill ptx aenv seed =
  Safe.coerce <$> (mkGenerate ptx aenv (IRFun1 (const seed)) :: CodeGen (IROpenAcc PTX aenv (Array sh e)))


-- Block wide scan
-- ---------------

-- Efficient block-wide (inclusive) scan using the specified operator.
--
-- Each block requires (#warps * (1 + 1.5*warp size)) elements of dynamically
-- allocated shared memory.
--
-- Example: https://github.com/NVlabs/cub/blob/1.5.4/cub/block/specializations/block_scan_warp_scans.cuh
--
scanBlockSMem
    :: forall aenv e.
       Elt e
    => Direction
    -> DeviceProperties                             -- ^ properties of the target device
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> Maybe (IR Int32)                             -- ^ number of valid elements (may be less than block size)
    -> IR e                                         -- ^ calling thread's input element
    -> CodeGen (IR e)
scanBlockSMem dir dev combine nelem = warpScan >=> warpPrefix
  where
    int32 :: Integral a => a -> IR Int32
    int32 = lift . P.fromIntegral

    -- Temporary storage required for each warp
    warp_smem_elems = CUDA.warpSize dev + (CUDA.warpSize dev `P.quot` 2)
    warp_smem_bytes = warp_smem_elems  * sizeOf (eltType (undefined::e))

    -- Step 1: Scan in every warp
    warpScan :: IR e -> CodeGen (IR e)
    warpScan input = do
      -- Allocate (1.5 * warpSize) elements of shared memory for each warp
      -- (individually addressable by each warp)
      wid  <- warpId
      skip <- A.mul numType wid (int32 warp_smem_bytes)
      smem <- dynamicSharedMem (int32 warp_smem_elems) skip
      scanWarpSMem dir dev combine smem input

    -- Step 2: Collect the aggregate results of each warp to compute the prefix
    -- values for each warp and combine with the partial result to compute each
    -- thread's final value.
    warpPrefix :: IR e -> CodeGen (IR e)
    warpPrefix input = do
      -- Allocate #warps elements of shared memory, placed after the per-warp
      -- scratch areas (hence the byte offset 'skip').
      bd    <- blockDim
      warps <- A.quot integralType bd (int32 (CUDA.warpSize dev))
      skip  <- A.mul numType warps (int32 warp_smem_bytes)
      smem  <- dynamicSharedMem warps skip

      -- Share warp aggregates: the last lane of each warp holds that warp's
      -- inclusive total after step 1.
      wid   <- warpId
      lane  <- laneId
      when (A.eq scalarType lane (int32 (CUDA.warpSize dev - 1))) $ do
        writeArray smem wid input

      -- Wait for each warp to finish its local scan and share the aggregate
      __syncthreads

      -- Compute the prefix value for this warp and add to the partial result.
      -- This step is not required for the first warp, which has no carry-in.
      if A.eq scalarType wid (lift 0)
        then return input
        else do
          -- Every thread sequentially scans the warp aggregates to compute
          -- their prefix value. We do this sequentially, but could also have
          -- warp 0 do it cooperatively if we limit thread block sizes to
          -- (warp size ^ 2).
          --
          -- When a partial-block element count is given, only aggregates of
          -- warps that actually contain valid elements are folded in.
          steps  <- case nelem of
                      Nothing -> return wid
                      Just n  -> A.min scalarType wid =<< A.quot integralType n (int32 (CUDA.warpSize dev))

          p0     <- readArray smem (lift 0 :: IR Int32)
          prefix <- iterFromStepTo (lift 1) (lift 1) steps p0 $ \step x -> do
                      y <- readArray smem step
                      case dir of
                        L -> app2 combine x y
                        R -> app2 combine y x

          case dir of
            L -> app2 combine prefix input
            R -> app2 combine input prefix


-- Warp-wide scan
-- --------------

-- Efficient warp-wide (inclusive) scan using the specified operator.
--
-- Each warp requires 48 (1.5 x warp size) elements of shared memory. The
-- routine assumes that it is allocated individually per-warp (i.e. can be
-- indexed in the range [0, warp size)).
--
-- Example: https://github.com/NVlabs/cub/blob/1.5.4/cub/warp/specializations/warp_scan_smem.cuh
--
-- NOTE(review): the shared-memory write below is followed by a read from
-- another lane with no explicit intra-warp barrier, in the warp-synchronous
-- style of the CUB example cited above — presumably safe on pre-Volta devices;
-- TODO confirm behaviour under independent thread scheduling (Volta+).
--
scanWarpSMem
    :: forall aenv e.
       Elt e
    => Direction
    -> DeviceProperties                             -- ^ properties of the target device
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> IRArray (Vector e)                           -- ^ temporary storage array in shared memory (1.5 x warp size elements)
    -> IR e                                         -- ^ calling thread's input element
    -> CodeGen (IR e)
scanWarpSMem dir dev combine smem = scan 0
  where
    log2 :: Double -> Double
    log2 = P.logBase 2

    -- Number of steps required to scan warp
    steps     = P.floor (log2 (P.fromIntegral (CUDA.warpSize dev)))

    -- The scratch area is padded by half a warp so that reads at
    -- (lane + halfWarp - offset) never index below zero.
    halfWarp  = P.fromIntegral (CUDA.warpSize dev `P.quot` 2)

    -- Unfold the scan as a recursive code generation function, doubling the
    -- combine distance ('offset') at every step.
    scan :: Int -> IR e -> CodeGen (IR e)
    scan step x
      | step >= steps               = return x
      | offset <- 1 `P.shiftL` step = do
          -- share partial result through shared memory buffer
          lane <- laneId
          i    <- A.add numType lane (lift halfWarp)
          writeArray smem i x

          -- update partial result if in range
          x'   <- if A.gte scalarType lane (lift offset)
                    then do
                      i' <- A.sub numType i (lift offset)    -- lane + HALF_WARP - offset
                      x' <- readArray smem i'
                      case dir of
                        L -> app2 combine x' x
                        R -> app2 combine x x'
                    else
                      return x
          scan (step+1) x'