{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RebindableSyntax    #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.CodeGen.FoldSeg
-- Copyright   : [2016..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.CodeGen.FoldSeg
  where

import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Elt
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type

import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic                as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Loop                      as Loop
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar

import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Base
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Fold                  ( reduceBlockSMem, reduceWarpSMem, imapFromTo )
import Data.Array.Accelerate.LLVM.PTX.Target

import LLVM.AST.Type.Representation

import qualified Foreign.CUDA.Analysis                              as CUDA

import Control.Monad                                                ( void )
import Control.Monad.State                                          ( gets )
import Data.String                                                  ( fromString )
import Prelude                                                      as P


-- | Segmented reduction along the innermost dimension of an array. Performs
-- one reduction per segment of the source array.
--
-- Two kernel variants are generated and combined with '(+++)': one that
-- assigns a thread block per segment ('mkFoldSegP_block') and one that
-- assigns a warp per segment ('mkFoldSegP_warp'). Both receive identical
-- arguments; only their work distribution differs.
--
mkFoldSeg
    :: forall aenv sh i e.
       Gamma            aenv                            -- ^ free-variable environment
    -> ArrayR (Array (sh, Int) e)                       -- ^ representation of the result array
    -> IntegralType i                                   -- ^ integral type of the segment descriptor
    -> IRFun2       PTX aenv (e -> e -> e)              -- ^ associative combination function
    -> Maybe (IRExp PTX aenv e)                         -- ^ seed element ('Just' for exclusive reductions)
    -> MIRDelayed   PTX aenv (Array (sh, Int) e)        -- ^ input array
    -> MIRDelayed   PTX aenv (Segments i)               -- ^ segment-offset array
    -> CodeGen      PTX      (IROpenAcc PTX aenv (Array (sh, Int) e))
mkFoldSeg aenv repr intTp combine seed arr seg =
  (+++) <$> mkFoldSegP_block aenv repr intTp combine seed arr seg
        <*> mkFoldSegP_warp  aenv repr intTp combine seed arr seg


-- This implementation assumes that the segments array represents the offset
-- indices to the source array, rather than the lengths of each segment. The
-- segment-offset approach is required for parallel implementations.
--
-- Each segment is computed by a single thread block, meaning we don't have to
-- worry about inter-block synchronisation.
--
mkFoldSegP_block
    :: forall aenv sh i e.
       Gamma          aenv
    -> ArrayR (Array (sh, Int) e)
    -> IntegralType i
    -> IRFun2     PTX aenv (e -> e -> e)
    -> MIRExp     PTX aenv e
    -> MIRDelayed PTX aenv (Array (sh, Int) e)
    -> MIRDelayed PTX aenv (Segments i)
    -> CodeGen    PTX      (IROpenAcc PTX aenv (Array (sh, Int) e))
mkFoldSegP_block :: Gamma aenv
-> ArrayR (Array (sh, Int) e)
-> IntegralType i
-> IRFun2 PTX aenv (e -> e -> e)
-> MIRExp PTX aenv e
-> MIRDelayed PTX aenv (Array (sh, Int) e)
-> MIRDelayed PTX aenv (Segments i)
-> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e))
mkFoldSegP_block Gamma aenv
aenv repr :: ArrayR (Array (sh, Int) e)
repr@(ArrayR ShapeR sh
shr TypeR e
tp) IntegralType i
intTp IRFun2 PTX aenv (e -> e -> e)
combine MIRExp PTX aenv e
mseed MIRDelayed PTX aenv (Array (sh, Int) e)
marr MIRDelayed PTX aenv (Segments i)
mseg = do
  DeviceProperties
dev <- LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties
forall arch a. LLVM arch a -> CodeGen arch a
liftCodeGen (LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties)
-> LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties
forall a b. (a -> b) -> a -> b
$ (PTX -> DeviceProperties) -> LLVM PTX DeviceProperties
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets PTX -> DeviceProperties
ptxDeviceProperties
  --
  let
      (IRArray (Array (sh, Int) e)
arrOut, [Parameter]
paramOut)  = ArrayR (Array (sh, Int) e)
-> Name (Array (sh, Int) e)
-> (IRArray (Array (sh, Int) e), [Parameter])
forall sh e.
ArrayR (Array sh e)
-> Name (Array sh e) -> (IRArray (Array sh e), [Parameter])
mutableArray ArrayR (Array (sh, Int) e)
repr Name (Array (sh, Int) e)
"out"
      (IRDelayed PTX aenv (Array (sh, Int) e)
arrIn,  [Parameter]
paramIn)   = Name (Array (sh, Int) e)
-> MIRDelayed PTX aenv (Array (sh, Int) e)
-> (IRDelayed PTX aenv (Array (sh, Int) e), [Parameter])
forall sh e arch aenv.
Name (Array sh e)
-> MIRDelayed arch aenv (Array sh e)
-> (IRDelayed arch aenv (Array sh e), [Parameter])
delayedArray Name (Array (sh, Int) e)
"in"  MIRDelayed PTX aenv (Array (sh, Int) e)
marr
      (IRDelayed PTX aenv (Segments i)
arrSeg, [Parameter]
paramSeg)  = Name (Segments i)
-> MIRDelayed PTX aenv (Segments i)
-> (IRDelayed PTX aenv (Segments i), [Parameter])
forall sh e arch aenv.
Name (Array sh e)
-> MIRDelayed arch aenv (Array sh e)
-> (IRDelayed arch aenv (Array sh e), [Parameter])
delayedArray Name (Segments i)
"seg" MIRDelayed PTX aenv (Segments i)
mseg
      paramEnv :: [Parameter]
paramEnv            = Gamma aenv -> [Parameter]
forall aenv. Gamma aenv -> [Parameter]
envParam Gamma aenv
aenv
      --
      config :: LaunchConfig
config              = DeviceProperties
-> [Int]
-> (Int -> Int)
-> (Int -> Int -> Int)
-> Q (TExp (Int -> Int -> Int))
-> LaunchConfig
launchConfig DeviceProperties
dev (DeviceProperties -> [Int]
CUDA.decWarp DeviceProperties
dev) Int -> Int
dsmem Int -> Int -> Int
forall a b. a -> b -> a
const [|| const ||]
      dsmem :: Int -> Int
dsmem Int
n             = Int
warps Int -> Int -> Int
forall a. Num a => a -> a -> a
* (Int
1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
per_warp) Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
bytes
        where
          ws :: Int
ws        = DeviceProperties -> Int
CUDA.warpSize DeviceProperties
dev
          warps :: Int
warps     = Int
n Int -> Int -> Int
forall a. Integral a => a -> a -> a
`P.quot` Int
ws
          per_warp :: Int
per_warp  = Int
ws Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
ws Int -> Int -> Int
forall a. Integral a => a -> a -> a
`P.quot` Int
2
          bytes :: Int
bytes     = TypeR e -> Int
forall e. TypeR e -> Int
bytesElt TypeR e
tp
  --
  LaunchConfig
-> Label
-> [Parameter]
-> CodeGen PTX ()
-> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e))
forall aenv a.
LaunchConfig
-> Label
-> [Parameter]
-> CodeGen PTX ()
-> CodeGen PTX (IROpenAcc PTX aenv a)
makeOpenAccWith LaunchConfig
config Label
"foldSeg_block" ([Parameter]
paramOut [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramIn [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramSeg [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramEnv) (CodeGen PTX ()
 -> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e)))
-> CodeGen PTX ()
-> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e))
forall a b. (a -> b) -> a -> b
$ do

    -- We use a dynamically scheduled work queue in order to evenly distribute
    -- the uneven workload, due to the variable length of each segment, over the
    -- available thread blocks.
    -- queue <- globalWorkQueue

    -- All threads in the block need to know what the start and end indices of
    -- this segment are in order to participate in the reduction. We use
    -- variables in __shared__ memory to communicate these values between
    -- threads in the block. Furthermore, by using a 2-element array, we can
    -- have the first two threads of the block read the start and end indices as
    -- a single coalesced read, since they will be sequential in the
    -- segment-offset array.
    --
    IRArray (Vector Int)
smem  <- TypeR Int -> Word64 -> CodeGen PTX (IRArray (Vector Int))
forall e. TypeR e -> Word64 -> CodeGen PTX (IRArray (Vector e))
staticSharedMem (ScalarType Int -> TypeR Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt) Word64
2

    -- Compute the number of segments and size of the innermost dimension. These
    -- are required if we are reducing a rank-2 or higher array, to properly
    -- compute the start and end indices of the portion of the array this thread
    -- block reduces. Note that this is a segment-offset array computed by
    -- 'scanl (+) 0' of the segment length array, so its size has increased by
    -- one.
    --
    Operands Int
sz    <- Operands (sh, Int) -> Operands Int
forall sh sz. Operands (sh, sz) -> Operands sz
indexHead (Operands (sh, Int) -> Operands Int)
-> CodeGen PTX (Operands (sh, Int)) -> CodeGen PTX (Operands Int)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IRDelayed PTX aenv (Array (sh, Int) e)
-> CodeGen PTX (Operands (sh, Int))
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRExp arch aenv sh
delayedExtent IRDelayed PTX aenv (Array (sh, Int) e)
arrIn
    Operands Int
ss    <- do Operands Int
n <- Operands ((), Int) -> Operands Int
forall sh sz. Operands (sh, sz) -> Operands sz
indexHead (Operands ((), Int) -> Operands Int)
-> CodeGen PTX (Operands ((), Int)) -> CodeGen PTX (Operands Int)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IRDelayed PTX aenv (Segments i) -> CodeGen PTX (Operands ((), Int))
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRExp arch aenv sh
delayedExtent IRDelayed PTX aenv (Segments i)
arrSeg
                NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
n (Int -> Operands Int
liftInt Int
1)

    -- Each thread block cooperatively reduces a segment.
    -- s0    <- dequeue queue (lift 1)
    -- for s0 (\s -> A.lt singleType s end) (\_ -> dequeue queue (lift 1)) $ \s -> do

    Operands Int
start <- Operands Int -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (Int -> Operands Int
liftInt Int
0)
    Operands Int
end   <- ShapeR sh -> Operands sh -> CodeGen PTX (Operands Int)
forall sh arch.
ShapeR sh -> Operands sh -> CodeGen arch (Operands Int)
shapeSize ShapeR sh
shr (IRArray (Array (sh, Int) e) -> Operands (sh, Int)
forall sh e. IRArray (Array sh e) -> Operands sh
irArrayShape IRArray (Array (sh, Int) e)
arrOut)

    Operands Int
-> Operands Int
-> (Operands Int -> CodeGen PTX ())
-> CodeGen PTX ()
imapFromTo Operands Int
start Operands Int
end ((Operands Int -> CodeGen PTX ()) -> CodeGen PTX ())
-> (Operands Int -> CodeGen PTX ()) -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ \Operands Int
s -> do

      -- The first two threads of the block determine the indices of the
      -- segments array that we will reduce between and distribute those values
      -- to the other threads in the block.
      Operands Int32
tid <- CodeGen PTX (Operands Int32)
threadIdx
      CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
tid (Int32 -> Operands Int32
liftInt32 Int32
2)) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ do
        Operands Int
i <- case ShapeR sh
shr of
               ShapeRsnoc ShapeR sh1
ShapeRz -> Operands Int -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int
s
               ShapeR sh
_ -> IntegralType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
IntegralType a
-> Operands a -> Operands a -> CodeGen arch (Operands a)
A.rem IntegralType Int
forall a. IsIntegral a => IntegralType a
integralType Operands Int
s Operands Int
ss
        Operands Int
j <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
i (Operands Int -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Operands Int32 -> CodeGen PTX (Operands Int)
forall i. IsIntegral i => Operands i -> CodeGen PTX (Operands Int)
int Operands Int32
tid
        Operands i
v <- IROpenFun1 PTX () aenv (Int -> i)
-> Operands Int -> IROpenExp PTX ((), Int) aenv i
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed PTX aenv (Segments i)
-> IROpenFun1 PTX () aenv (Int -> i)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed PTX aenv (Segments i)
arrSeg) Operands Int
j
        IntegralType Int32
-> IRArray (Vector Int)
-> Operands Int32
-> Operands Int
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int32
TypeInt32 IRArray (Vector Int)
smem Operands Int32
tid (Operands Int -> CodeGen PTX ())
-> CodeGen PTX (Operands Int) -> CodeGen PTX ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< IntegralType i
-> NumType Int -> Operands i -> CodeGen PTX (Operands Int)
forall arch a b.
IntegralType a
-> NumType b -> Operands a -> CodeGen arch (Operands b)
A.fromIntegral IntegralType i
intTp NumType Int
forall a. IsNum a => NumType a
numType Operands i
v

      -- Once all threads have caught up, begin work on the new segment.
      CodeGen PTX ()
__syncthreads

      Operands Int
u <- IntegralType Int32
-> IRArray (Vector Int)
-> Operands Int32
-> CodeGen PTX (Operands Int)
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> CodeGen arch (Operands e)
readArray IntegralType Int32
TypeInt32 IRArray (Vector Int)
smem (Int32 -> Operands Int32
liftInt32 Int32
0)
      Operands Int
v <- IntegralType Int32
-> IRArray (Vector Int)
-> Operands Int32
-> CodeGen PTX (Operands Int)
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> CodeGen arch (Operands e)
readArray IntegralType Int32
TypeInt32 IRArray (Vector Int)
smem (Int32 -> Operands Int32
liftInt32 Int32
1)

      -- Determine the index range of the input array we will reduce over.
      -- Necessary for multidimensional segmented reduction.
      (Operands Int
inf,Operands Int
sup) <- Operands (Int, Int) -> (Operands Int, Operands Int)
forall a b. Operands (a, b) -> (Operands a, Operands b)
A.unpair (Operands (Int, Int) -> (Operands Int, Operands Int))
-> CodeGen PTX (Operands (Int, Int))
-> CodeGen PTX (Operands Int, Operands Int)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> case ShapeR sh
shr of
                                  ShapeRsnoc ShapeR sh1
ShapeRz -> Operands (Int, Int) -> CodeGen PTX (Operands (Int, Int))
forall (m :: * -> *) a. Monad m => a -> m a
return (Operands Int -> Operands Int -> Operands (Int, Int)
forall a b. Operands a -> Operands b -> Operands (a, b)
A.pair Operands Int
u Operands Int
v)
                                  ShapeR sh
_ -> do Operands Int
q <- IntegralType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
IntegralType a
-> Operands a -> Operands a -> CodeGen arch (Operands a)
A.quot IntegralType Int
forall a. IsIntegral a => IntegralType a
integralType Operands Int
s Operands Int
ss
                                          Operands Int
a <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.mul NumType Int
forall a. IsNum a => NumType a
numType Operands Int
q Operands Int
sz
                                          Operands Int -> Operands Int -> Operands (Int, Int)
forall a b. Operands a -> Operands b -> Operands (a, b)
A.pair (Operands Int -> Operands Int -> Operands (Int, Int))
-> CodeGen PTX (Operands Int)
-> CodeGen PTX (Operands Int -> Operands (Int, Int))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
u Operands Int
a
                                                 CodeGen PTX (Operands Int -> Operands (Int, Int))
-> CodeGen PTX (Operands Int) -> CodeGen PTX (Operands (Int, Int))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
v Operands Int
a

      CodeGen PTX (Operands ()) -> CodeGen PTX ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (CodeGen PTX (Operands ()) -> CodeGen PTX ())
-> CodeGen PTX (Operands ()) -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$
        if (TupR ScalarType ()
forall (s :: * -> *). TupR s ()
TupRunit, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
inf Operands Int
sup)
          -- This segment is empty. If this is an exclusive reduction the
          -- first thread writes out the initial element for this segment.
          then do
            case MIRExp PTX aenv e
mseed of
              MIRExp PTX aenv e
Nothing -> Operands () -> CodeGen PTX (Operands ())
forall (m :: * -> *) a. Monad m => a -> m a
return (TupR ScalarType () -> () -> Operands ()
forall a. TypeR a -> a -> Operands a
lift TupR ScalarType ()
forall (s :: * -> *). TupR s ()
TupRunit ())
              Just IRExp PTX aenv e
z  -> do
                CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
tid (Int32 -> Operands Int32
liftInt32 Int32
0)) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ IntegralType Int
-> IRArray (Array (sh, Int) e)
-> Operands Int
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Array (sh, Int) e)
arrOut Operands Int
s (Operands e -> CodeGen PTX ())
-> IRExp PTX aenv e -> CodeGen PTX ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< IRExp PTX aenv e
z
                Operands () -> CodeGen PTX (Operands ())
forall (m :: * -> *) a. Monad m => a -> m a
return (TupR ScalarType () -> () -> Operands ()
forall a. TypeR a -> a -> Operands a
lift TupR ScalarType ()
forall (s :: * -> *). TupR s ()
TupRunit ())

          -- This is a non-empty segment.
          else do
            -- Step 1: initialise local sums
            --
            -- NOTE: We require all threads to enter this branch and execute the
            -- first step, even if they do not have a valid element and must
            -- return 'undef'. If we attempt to skip this entire section for
            -- non-participating threads (i.e. 'when (i0 < sup)'), it seems that
            -- those threads die and will not participate in the computation of
            -- _any_ further segment. I'm not sure if this is a CUDA oddity
            -- (e.g. we must have all threads convergent on __syncthreads) or
            -- a bug in NVPTX / ptxas.
            --
            Operands Int
i0 <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
inf (Operands Int -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Operands Int32 -> CodeGen PTX (Operands Int)
forall i. IsIntegral i => Operands i -> CodeGen PTX (Operands Int)
int Operands Int32
tid
            Operands e
x0 <- if (TypeR e
tp, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
i0 Operands Int
sup)
                    then IROpenFun1 PTX () aenv (Int -> e)
-> Operands Int -> IRExp PTX aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed PTX aenv (Array (sh, Int) e)
-> IROpenFun1 PTX () aenv (Int -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed PTX aenv (Array (sh, Int) e)
arrIn) Operands Int
i0
                    else let
                             go :: TypeR a -> Operands a
                             go :: TypeR a -> Operands a
go TypeR a
TupRunit       = Operands a
Operands ()
OP_Unit
                             go (TupRpair TupR ScalarType a1
a TupR ScalarType b
b) = Operands a1 -> Operands b -> Operands (a1, b)
forall a b. Operands a -> Operands b -> Operands (a, b)
OP_Pair (TupR ScalarType a1 -> Operands a1
forall a. TypeR a -> Operands a
go TupR ScalarType a1
a) (TupR ScalarType b -> Operands b
forall a. TypeR a -> Operands a
go TupR ScalarType b
b)
                             go (TupRsingle ScalarType a
t) = ScalarType a -> Operand a -> Operands a
forall (dict :: * -> *) a.
(IROP dict, HasCallStack) =>
dict a -> Operand a -> Operands a
ir ScalarType a
t (ScalarType a -> Operand a
forall a. ScalarType a -> Operand a
undef ScalarType a
t)
                         in
                         Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return (Operands e -> CodeGen PTX (Operands e))
-> Operands e -> CodeGen PTX (Operands e)
forall a b. (a -> b) -> a -> b
$ TypeR e -> Operands e
forall a. TypeR a -> Operands a
go TypeR e
tp

            Operands Int
bd  <- Operands Int32 -> CodeGen PTX (Operands Int)
forall i. IsIntegral i => Operands i -> CodeGen PTX (Operands Int)
int (Operands Int32 -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int32) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< CodeGen PTX (Operands Int32)
blockDim
            Operands Int
v0  <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
sup Operands Int
inf
            Operands Int32
v0' <- Operands Int -> CodeGen PTX (Operands Int32)
forall i.
IsIntegral i =>
Operands i -> CodeGen PTX (Operands Int32)
i32 Operands Int
v0
            Operands e
r0  <- if (TypeR e
tp, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.gte SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
v0 Operands Int
bd)
                     then DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
reduceBlockSMem DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine Maybe (Operands Int32)
forall a. Maybe a
Nothing    Operands e
x0
                     else DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
reduceBlockSMem DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine (Operands Int32 -> Maybe (Operands Int32)
forall a. a -> Maybe a
Just Operands Int32
v0') Operands e
x0

            -- Step 2: keep walking over the input
            Operands Int
nxt <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
inf Operands Int
bd
            Operands e
r   <- TypeR e
-> Operands Int
-> Operands Int
-> Operands Int
-> Operands e
-> (Operands Int -> Operands e -> CodeGen PTX (Operands e))
-> CodeGen PTX (Operands e)
forall i a arch.
IsNum i =>
TypeR a
-> Operands i
-> Operands i
-> Operands i
-> Operands a
-> (Operands i -> Operands a -> CodeGen arch (Operands a))
-> CodeGen arch (Operands a)
iterFromStepTo TypeR e
tp Operands Int
nxt Operands Int
bd Operands Int
sup Operands e
r0 ((Operands Int -> Operands e -> CodeGen PTX (Operands e))
 -> CodeGen PTX (Operands e))
-> (Operands Int -> Operands e -> CodeGen PTX (Operands e))
-> CodeGen PTX (Operands e)
forall a b. (a -> b) -> a -> b
$ \Operands Int
offset Operands e
r -> do

                     -- Wait for threads to catch up before starting the next stripe
                     CodeGen PTX ()
__syncthreads

                     Operands Int
i' <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
offset (Operands Int -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Operands Int32 -> CodeGen PTX (Operands Int)
forall i. IsIntegral i => Operands i -> CodeGen PTX (Operands Int)
int Operands Int32
tid
                     Operands Int
v' <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
sup Operands Int
offset
                     Operands e
r' <- if (TypeR e
tp, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.gte SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
v' Operands Int
bd)
                             -- All threads in the block are in bounds, so we
                             -- can avoid bounds checks.
                             then do
                               Operands e
x <- IROpenFun1 PTX () aenv (Int -> e)
-> Operands Int -> IRExp PTX aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed PTX aenv (Array (sh, Int) e)
-> IROpenFun1 PTX () aenv (Int -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed PTX aenv (Array (sh, Int) e)
arrIn) Operands Int
i'
                               Operands e
y <- DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
reduceBlockSMem DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine Maybe (Operands Int32)
forall a. Maybe a
Nothing Operands e
Operands e
x
                               Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands e
y

                             -- Not all threads are valid. Note that we still
                             -- have all threads enter the reduction procedure
                             -- to avoid thread divergence on synchronisation
                             -- points, similar to the above NOTE.
                             else do
                               Operands e
x <- if (TypeR e
tp, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
i' Operands Int
sup)
                                      then IROpenFun1 PTX () aenv (Int -> e)
-> Operands Int -> IRExp PTX aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed PTX aenv (Array (sh, Int) e)
-> IROpenFun1 PTX () aenv (Int -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed PTX aenv (Array (sh, Int) e)
arrIn) Operands Int
i'
                                      else let
                                               go :: TypeR a -> Operands a
                                               go :: TypeR a -> Operands a
go TypeR a
TupRunit       = Operands a
Operands ()
OP_Unit
                                               go (TupRpair TupR ScalarType a1
a TupR ScalarType b
b) = Operands a1 -> Operands b -> Operands (a1, b)
forall a b. Operands a -> Operands b -> Operands (a, b)
OP_Pair (TupR ScalarType a1 -> Operands a1
forall a. TypeR a -> Operands a
go TupR ScalarType a1
a) (TupR ScalarType b -> Operands b
forall a. TypeR a -> Operands a
go TupR ScalarType b
b)
                                               go (TupRsingle ScalarType a
t) = ScalarType a -> Operand a -> Operands a
forall (dict :: * -> *) a.
(IROP dict, HasCallStack) =>
dict a -> Operand a -> Operands a
ir ScalarType a
t (ScalarType a -> Operand a
forall a. ScalarType a -> Operand a
undef ScalarType a
t)
                                           in
                                           Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return (Operands e -> CodeGen PTX (Operands e))
-> Operands e -> CodeGen PTX (Operands e)
forall a b. (a -> b) -> a -> b
$ TypeR e -> Operands e
forall a. TypeR a -> Operands a
go TypeR e
tp

                               Operands Int32
z <- Operands Int -> CodeGen PTX (Operands Int32)
forall i.
IsIntegral i =>
Operands i -> CodeGen PTX (Operands Int32)
i32 Operands Int
v'
                               Operands e
y <- DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
reduceBlockSMem DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine (Operands Int32 -> Maybe (Operands Int32)
forall a. a -> Maybe a
Just Operands Int32
z) Operands e
x
                               Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands e
y

                     -- first thread incorporates the result from the previous
                     -- iteration
                     if (TypeR e
tp, SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
tid (Int32 -> Operands Int32
liftInt32 Int32
0))
                       then IRFun2 PTX aenv (e -> e -> e)
-> Operands e -> Operands e -> IRExp PTX aenv e
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 PTX aenv (e -> e -> e)
combine Operands e
Operands e
r Operands e
Operands e
r'
                       else Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands e
r'

            -- Step 3: Thread zero writes the aggregate reduction for this
            -- segment to memory. If this is an exclusive fold combine with the
            -- initial element as well.
            CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
tid (Int32 -> Operands Int32
liftInt32 Int32
0)) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$
             IntegralType Int
-> IRArray (Array (sh, Int) e)
-> Operands Int
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Array (sh, Int) e)
arrOut Operands Int
s (Operands e -> CodeGen PTX ())
-> IRExp PTX aenv e -> CodeGen PTX ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<<
               case MIRExp PTX aenv e
mseed of
                 MIRExp PTX aenv e
Nothing -> Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands e
r
                 Just IRExp PTX aenv e
z  -> (Operands e -> Operands e -> IRExp PTX aenv e)
-> Operands e -> Operands e -> IRExp PTX aenv e
forall a b c. (a -> b -> c) -> b -> a -> c
flip (IRFun2 PTX aenv (e -> e -> e)
-> Operands e -> Operands e -> IRExp PTX aenv e
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 PTX aenv (e -> e -> e)
combine) Operands e
Operands e
r (Operands e -> IRExp PTX aenv e)
-> IRExp PTX aenv e -> IRExp PTX aenv e
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< IRExp PTX aenv e
z  -- Note: initial element on the left

            Operands () -> CodeGen PTX (Operands ())
forall (m :: * -> *) a. Monad m => a -> m a
return (TupR ScalarType () -> () -> Operands ()
forall a. TypeR a -> a -> Operands a
lift TupR ScalarType ()
forall (s :: * -> *). TupR s ()
TupRunit ())

    CodeGen PTX ()
forall arch. HasCallStack => CodeGen arch ()
return_


-- This implementation assumes that the segments array represents the offset
-- indices to the source array, rather than the lengths of each segment. The
-- segment-offset approach is required for parallel implementations.
--
-- Each segment is computed by a single warp, meaning we don't have to worry
-- about inter- or intra-block synchronisation.
--
mkFoldSegP_warp
    :: forall aenv sh i e.
       Gamma          aenv
    -> ArrayR (Array (sh, Int) e)
    -> IntegralType i
    -> IRFun2     PTX aenv (e -> e -> e)
    -> MIRExp     PTX aenv e
    -> MIRDelayed PTX aenv (Array (sh, Int) e)
    -> MIRDelayed PTX aenv (Segments i)
    -> CodeGen    PTX      (IROpenAcc PTX aenv (Array (sh, Int) e))
mkFoldSegP_warp :: Gamma aenv
-> ArrayR (Array (sh, Int) e)
-> IntegralType i
-> IRFun2 PTX aenv (e -> e -> e)
-> MIRExp PTX aenv e
-> MIRDelayed PTX aenv (Array (sh, Int) e)
-> MIRDelayed PTX aenv (Segments i)
-> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e))
mkFoldSegP_warp Gamma aenv
aenv repr :: ArrayR (Array (sh, Int) e)
repr@(ArrayR ShapeR sh
shr TypeR e
tp) IntegralType i
intTp IRFun2 PTX aenv (e -> e -> e)
combine MIRExp PTX aenv e
mseed MIRDelayed PTX aenv (Array (sh, Int) e)
marr MIRDelayed PTX aenv (Segments i)
mseg = do
  DeviceProperties
dev <- LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties
forall arch a. LLVM arch a -> CodeGen arch a
liftCodeGen (LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties)
-> LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties
forall a b. (a -> b) -> a -> b
$ (PTX -> DeviceProperties) -> LLVM PTX DeviceProperties
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets PTX -> DeviceProperties
ptxDeviceProperties
  --
  let
      (IRArray (Array (sh, Int) e)
arrOut, [Parameter]
paramOut)  = ArrayR (Array (sh, Int) e)
-> Name (Array (sh, Int) e)
-> (IRArray (Array (sh, Int) e), [Parameter])
forall sh e.
ArrayR (Array sh e)
-> Name (Array sh e) -> (IRArray (Array sh e), [Parameter])
mutableArray ArrayR (Array (sh, Int) e)
repr Name (Array (sh, Int) e)
"out"
      (IRDelayed PTX aenv (Array (sh, Int) e)
arrIn,  [Parameter]
paramIn)   = Name (Array (sh, Int) e)
-> MIRDelayed PTX aenv (Array (sh, Int) e)
-> (IRDelayed PTX aenv (Array (sh, Int) e), [Parameter])
forall sh e arch aenv.
Name (Array sh e)
-> MIRDelayed arch aenv (Array sh e)
-> (IRDelayed arch aenv (Array sh e), [Parameter])
delayedArray Name (Array (sh, Int) e)
"in"  MIRDelayed PTX aenv (Array (sh, Int) e)
marr
      (IRDelayed PTX aenv (Segments i)
arrSeg, [Parameter]
paramSeg)  = Name (Segments i)
-> MIRDelayed PTX aenv (Segments i)
-> (IRDelayed PTX aenv (Segments i), [Parameter])
forall sh e arch aenv.
Name (Array sh e)
-> MIRDelayed arch aenv (Array sh e)
-> (IRDelayed arch aenv (Array sh e), [Parameter])
delayedArray Name (Segments i)
"seg" MIRDelayed PTX aenv (Segments i)
mseg
      paramEnv :: [Parameter]
paramEnv            = Gamma aenv -> [Parameter]
forall aenv. Gamma aenv -> [Parameter]
envParam Gamma aenv
aenv
      --
      config :: LaunchConfig
config              = DeviceProperties
-> [Int]
-> (Int -> Int)
-> (Int -> Int -> Int)
-> Q (TExp (Int -> Int -> Int))
-> LaunchConfig
launchConfig DeviceProperties
dev (DeviceProperties -> [Int]
CUDA.decWarp DeviceProperties
dev) Int -> Int
dsmem Int -> Int -> Int
grid Q (TExp (Int -> Int -> Int))
gridQ
      dsmem :: Int -> Int
dsmem Int
n             = Int
warps Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
per_warp_bytes
        where
          warps :: Int
warps           = (Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
ws Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1) Int -> Int -> Int
forall a. Integral a => a -> a -> a
`P.quot` Int
ws
      --
      grid :: Int -> Int -> Int
grid Int
n Int
m            = Int -> Int -> Int
multipleOf Int
n (Int
m Int -> Int -> Int
forall a. Integral a => a -> a -> a
`P.quot` Int
ws)
      gridQ :: Q (TExp (Int -> Int -> Int))
gridQ               = [|| \n m -> $$multipleOfQ n (m `P.quot` ws) ||]
      --
      per_warp_bytes :: Int
per_warp_bytes      = (Int
per_warp_elems Int -> Int -> Int
forall a. Num a => a -> a -> a
* TypeR e -> Int
forall e. TypeR e -> Int
bytesElt TypeR e
tp) Int -> Int -> Int
forall a. Ord a => a -> a -> a
`P.max` (Int
2 Int -> Int -> Int
forall a. Num a => a -> a -> a
* TypeR e -> Int
forall e. TypeR e -> Int
bytesElt TypeR e
tp)
      per_warp_elems :: Int
per_warp_elems      = Int
ws Int -> Int -> Int
forall a. Num a => a -> a -> a
+ (Int
ws Int -> Int -> Int
forall a. Integral a => a -> a -> a
`P.quot` Int
2)
      ws :: Int
ws                  = DeviceProperties -> Int
CUDA.warpSize DeviceProperties
dev

      int32 :: Integral a => a -> Operands Int32
      int32 :: a -> Operands Int32
int32 = Int32 -> Operands Int32
liftInt32 (Int32 -> Operands Int32) -> (a -> Int32) -> a -> Operands Int32
forall b c a. (b -> c) -> (a -> b) -> a -> c
. a -> Int32
forall a b. (Integral a, Num b) => a -> b
P.fromIntegral
  --
  LaunchConfig
-> Label
-> [Parameter]
-> CodeGen PTX ()
-> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e))
forall aenv a.
LaunchConfig
-> Label
-> [Parameter]
-> CodeGen PTX ()
-> CodeGen PTX (IROpenAcc PTX aenv a)
makeOpenAccWith LaunchConfig
config Label
"foldSeg_warp" ([Parameter]
paramOut [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramIn [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramSeg [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramEnv) (CodeGen PTX ()
 -> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e)))
-> CodeGen PTX ()
-> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e))
forall a b. (a -> b) -> a -> b
$ do

    -- Each warp works independently.
    -- Determine the ID of this warp within the thread block.
    Operands Int32
tid   <- CodeGen PTX (Operands Int32)
threadIdx
    Operands Int32
wid   <- IntegralType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Int32)
forall a arch.
IntegralType a
-> Operands a -> Operands a -> CodeGen arch (Operands a)
A.quot IntegralType Int32
forall a. IsIntegral a => IntegralType a
integralType Operands Int32
tid (Int -> Operands Int32
forall a. Integral a => a -> Operands Int32
int32 Int
ws)

    -- Number of warps per thread block
    Operands Int32
bd    <- CodeGen PTX (Operands Int32)
blockDim
    Operands Int32
wpb   <- IntegralType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Int32)
forall a arch.
IntegralType a
-> Operands a -> Operands a -> CodeGen arch (Operands a)
A.quot IntegralType Int32
forall a. IsIntegral a => IntegralType a
integralType Operands Int32
bd (Int -> Operands Int32
forall a. Integral a => a -> Operands Int32
int32 Int
ws)

    -- ID of this warp within the grid
    Operands Int32
bid   <- CodeGen PTX (Operands Int32)
blockIdx
    Operands Int32
gwid  <- do Operands Int32
a <- NumType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Int32)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.mul NumType Int32
forall a. IsNum a => NumType a
numType Operands Int32
bid Operands Int32
wpb
                Operands Int32
b <- NumType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Int32)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int32
forall a. IsNum a => NumType a
numType Operands Int32
wid Operands Int32
a
                Operands Int32 -> CodeGen PTX (Operands Int32)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int32
b

    -- All threads in the warp need to know what the start and end indices of
    -- this segment are in order to participate in the reduction. We use
    -- variables in __shared__ memory to communicate these values between
    -- threads. Furthermore, by using a 2-element array, we can have the first
    -- two threads of the warp read the start and end indices as a single
    -- coalesced read, as these elements will be adjacent in the segment-offset
    -- array.
    --
    -- Note that this is aliased with the memory used to communicate reduction
    -- values within the warp.
    --
    IRArray (Vector Int)
lim   <- do
      Operands Int32
a <- NumType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Int32)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.mul NumType Int32
forall a. IsNum a => NumType a
numType Operands Int32
wid (Int -> Operands Int32
forall a. Integral a => a -> Operands Int32
int32 Int
per_warp_bytes)
      IRArray (Vector Int)
b <- TypeR Int
-> IntegralType Int32
-> Operands Int32
-> Operands Int32
-> CodeGen PTX (IRArray (Vector Int))
forall e int.
TypeR e
-> IntegralType int
-> Operands int
-> Operands int
-> CodeGen PTX (IRArray (Vector e))
dynamicSharedMem (ScalarType Int -> TypeR Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt) IntegralType Int32
TypeInt32 (Int32 -> Operands Int32
liftInt32 Int32
2) Operands Int32
a
      IRArray (Vector Int) -> CodeGen PTX (IRArray (Vector Int))
forall (m :: * -> *) a. Monad m => a -> m a
return IRArray (Vector Int)
b

    -- Allocate (1.5 * warpSize) elements of shared memory for each warp to
    -- communicate reduction values.
    --
    -- Note that this is aliased with the memory used to communicate the start
    -- and end indices of this segment.
    --
    IRArray (Vector e)
smem  <- do
      Operands Int32
a <- NumType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Int32)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.mul NumType Int32
forall a. IsNum a => NumType a
numType Operands Int32
wid (Int -> Operands Int32
forall a. Integral a => a -> Operands Int32
int32 Int
per_warp_bytes)
      IRArray (Vector e)
b <- TypeR e
-> IntegralType Int32
-> Operands Int32
-> Operands Int32
-> CodeGen PTX (IRArray (Vector e))
forall e int.
TypeR e
-> IntegralType int
-> Operands int
-> Operands int
-> CodeGen PTX (IRArray (Vector e))
dynamicSharedMem TypeR e
tp IntegralType Int32
TypeInt32 (Int -> Operands Int32
forall a. Integral a => a -> Operands Int32
int32 Int
per_warp_elems) Operands Int32
a
      IRArray (Vector e) -> CodeGen PTX (IRArray (Vector e))
forall (m :: * -> *) a. Monad m => a -> m a
return IRArray (Vector e)
b

    -- Compute the number of segments and size of the innermost dimension. These
    -- are required if we are reducing a rank-2 or higher array, to properly
    -- compute the start and end indices of the portion of the array this warp
    -- reduces. Note that this is a segment-offset array computed by 'scanl (+) 0'
    -- of the segment length array, so its size has increased by one.
    --
    Operands Int
sz    <- Operands (sh, Int) -> Operands Int
forall sh sz. Operands (sh, sz) -> Operands sz
indexHead (Operands (sh, Int) -> Operands Int)
-> CodeGen PTX (Operands (sh, Int)) -> CodeGen PTX (Operands Int)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IRDelayed PTX aenv (Array (sh, Int) e)
-> CodeGen PTX (Operands (sh, Int))
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRExp arch aenv sh
delayedExtent IRDelayed PTX aenv (Array (sh, Int) e)
arrIn
    Operands Int
ss    <- do Operands Int
a <- Operands ((), Int) -> Operands Int
forall sh sz. Operands (sh, sz) -> Operands sz
indexHead (Operands ((), Int) -> Operands Int)
-> CodeGen PTX (Operands ((), Int)) -> CodeGen PTX (Operands Int)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IRDelayed PTX aenv (Segments i) -> CodeGen PTX (Operands ((), Int))
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRExp arch aenv sh
delayedExtent IRDelayed PTX aenv (Segments i)
arrSeg
                Operands Int
b <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
a (Int -> Operands Int
liftInt Int
1)
                Operands Int -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int
b

    -- Each thread reduces a segment independently
    Operands Int
s0    <- Operands Int32 -> CodeGen PTX (Operands Int)
forall i. IsIntegral i => Operands i -> CodeGen PTX (Operands Int)
int Operands Int32
gwid
    Operands Int
gd    <- Operands Int32 -> CodeGen PTX (Operands Int)
forall i. IsIntegral i => Operands i -> CodeGen PTX (Operands Int)
int (Operands Int32 -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int32) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< CodeGen PTX (Operands Int32)
gridDim
    Operands Int
wpb'  <- Operands Int32 -> CodeGen PTX (Operands Int)
forall i. IsIntegral i => Operands i -> CodeGen PTX (Operands Int)
int Operands Int32
wpb
    Operands Int
step  <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.mul NumType Int
forall a. IsNum a => NumType a
numType Operands Int
wpb' Operands Int
gd
    Operands Int
end   <- ShapeR sh -> Operands sh -> CodeGen PTX (Operands Int)
forall sh arch.
ShapeR sh -> Operands sh -> CodeGen arch (Operands Int)
shapeSize ShapeR sh
shr (IRArray (Array (sh, Int) e) -> Operands (sh, Int)
forall sh e. IRArray (Array sh e) -> Operands sh
irArrayShape IRArray (Array (sh, Int) e)
arrOut)
    Operands Int
-> Operands Int
-> Operands Int
-> (Operands Int -> CodeGen PTX ())
-> CodeGen PTX ()
forall i arch.
IsNum i =>
Operands i
-> Operands i
-> Operands i
-> (Operands i -> CodeGen arch ())
-> CodeGen arch ()
imapFromStepTo Operands Int
s0 Operands Int
step Operands Int
end ((Operands Int -> CodeGen PTX ()) -> CodeGen PTX ())
-> (Operands Int -> CodeGen PTX ()) -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ \Operands Int
s -> do

      CodeGen PTX ()
HasCallStack => CodeGen PTX ()
__syncwarp

      -- The first two threads of the warp determine the indices of the segments
      -- array that we will reduce between and distribute those values to the
      -- other threads in the warp
      Operands Int32
lane <- CodeGen PTX (Operands Int32)
laneId
      CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
lane (Int32 -> Operands Int32
liftInt32 Int32
2)) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ do
        Operands Int
a <- case ShapeR sh
shr of
               ShapeRsnoc ShapeR sh1
ShapeRz -> Operands Int -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int
s
               ShapeR sh
_ -> IntegralType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
IntegralType a
-> Operands a -> Operands a -> CodeGen arch (Operands a)
A.rem IntegralType Int
forall a. IsIntegral a => IntegralType a
integralType Operands Int
s Operands Int
ss
        Operands Int
b <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
a (Operands Int -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Operands Int32 -> CodeGen PTX (Operands Int)
forall i. IsIntegral i => Operands i -> CodeGen PTX (Operands Int)
int Operands Int32
lane
        Operands i
c <- IROpenFun1 PTX () aenv (Int -> i)
-> Operands Int -> IROpenExp PTX ((), Int) aenv i
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed PTX aenv (Segments i)
-> IROpenFun1 PTX () aenv (Int -> i)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed PTX aenv (Segments i)
arrSeg) Operands Int
b
        IntegralType Int32
-> IRArray (Vector Int)
-> Operands Int32
-> Operands Int
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int32
TypeInt32 IRArray (Vector Int)
lim Operands Int32
lane (Operands Int -> CodeGen PTX ())
-> CodeGen PTX (Operands Int) -> CodeGen PTX ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< IntegralType i
-> NumType Int -> Operands i -> CodeGen PTX (Operands Int)
forall arch a b.
IntegralType a
-> NumType b -> Operands a -> CodeGen arch (Operands b)
A.fromIntegral IntegralType i
intTp NumType Int
forall a. IsNum a => NumType a
numType Operands i
c

      CodeGen PTX ()
HasCallStack => CodeGen PTX ()
__syncwarp

      -- Determine the index range of the input array we will reduce over.
      -- Necessary for multidimensional segmented reduction.
      (Operands Int
inf,Operands Int
sup) <- do
        Operands Int
u <- IntegralType Int32
-> IRArray (Vector Int)
-> Operands Int32
-> CodeGen PTX (Operands Int)
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> CodeGen arch (Operands e)
readArray IntegralType Int32
TypeInt32 IRArray (Vector Int)
lim (Int32 -> Operands Int32
liftInt32 Int32
0)
        Operands Int
v <- IntegralType Int32
-> IRArray (Vector Int)
-> Operands Int32
-> CodeGen PTX (Operands Int)
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> CodeGen arch (Operands e)
readArray IntegralType Int32
TypeInt32 IRArray (Vector Int)
lim (Int32 -> Operands Int32
liftInt32 Int32
1)
        Operands (Int, Int) -> (Operands Int, Operands Int)
forall a b. Operands (a, b) -> (Operands a, Operands b)
A.unpair (Operands (Int, Int) -> (Operands Int, Operands Int))
-> CodeGen PTX (Operands (Int, Int))
-> CodeGen PTX (Operands Int, Operands Int)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> case ShapeR sh
shr of
                       ShapeRsnoc ShapeR sh1
ShapeRz -> Operands (Int, Int) -> CodeGen PTX (Operands (Int, Int))
forall (m :: * -> *) a. Monad m => a -> m a
return (Operands Int -> Operands Int -> Operands (Int, Int)
forall a b. Operands a -> Operands b -> Operands (a, b)
A.pair Operands Int
u Operands Int
v)
                       ShapeR sh
_ -> do Operands Int
q <- IntegralType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
IntegralType a
-> Operands a -> Operands a -> CodeGen arch (Operands a)
A.quot IntegralType Int
forall a. IsIntegral a => IntegralType a
integralType Operands Int
s Operands Int
ss
                               Operands Int
a <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.mul NumType Int
forall a. IsNum a => NumType a
numType Operands Int
q Operands Int
sz
                               Operands Int -> Operands Int -> Operands (Int, Int)
forall a b. Operands a -> Operands b -> Operands (a, b)
A.pair (Operands Int -> Operands Int -> Operands (Int, Int))
-> CodeGen PTX (Operands Int)
-> CodeGen PTX (Operands Int -> Operands (Int, Int))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
u Operands Int
a
                                      CodeGen PTX (Operands Int -> Operands (Int, Int))
-> CodeGen PTX (Operands Int) -> CodeGen PTX (Operands (Int, Int))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
v Operands Int
a

      CodeGen PTX ()
HasCallStack => CodeGen PTX ()
__syncwarp

      CodeGen PTX (Operands ()) -> CodeGen PTX ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (CodeGen PTX (Operands ()) -> CodeGen PTX ())
-> CodeGen PTX (Operands ()) -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$
        if (TupR ScalarType ()
forall (s :: * -> *). TupR s ()
TupRunit, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
inf Operands Int
sup)
          -- This segment is empty. If this is an exclusive reduction the first
          -- lane writes out the initial element for this segment.
          then do
            case MIRExp PTX aenv e
mseed of
              MIRExp PTX aenv e
Nothing -> Operands () -> CodeGen PTX (Operands ())
forall (m :: * -> *) a. Monad m => a -> m a
return (TupR ScalarType () -> () -> Operands ()
forall a. TypeR a -> a -> Operands a
lift TupR ScalarType ()
forall (s :: * -> *). TupR s ()
TupRunit ())
              Just IRExp PTX aenv e
z  -> do
                CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
lane (Int32 -> Operands Int32
liftInt32 Int32
0)) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ IntegralType Int
-> IRArray (Array (sh, Int) e)
-> Operands Int
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Array (sh, Int) e)
arrOut Operands Int
s (Operands e -> CodeGen PTX ())
-> IRExp PTX aenv e -> CodeGen PTX ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< IRExp PTX aenv e
z
                Operands () -> CodeGen PTX (Operands ())
forall (m :: * -> *) a. Monad m => a -> m a
return (TupR ScalarType () -> () -> Operands ()
forall a. TypeR a -> a -> Operands a
lift TupR ScalarType ()
forall (s :: * -> *). TupR s ()
TupRunit ())

          -- This is a non-empty segment.
          else do
            -- Step 1: initialise local sums
            --
            -- See comment above why we initialise the loop in this way
            --
            Operands Int
i0  <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
inf (Operands Int -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Operands Int32 -> CodeGen PTX (Operands Int)
forall i. IsIntegral i => Operands i -> CodeGen PTX (Operands Int)
int Operands Int32
lane
            Operands e
x0  <- if (TypeR e
tp, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
i0 Operands Int
sup)
                     then IROpenFun1 PTX () aenv (Int -> e)
-> Operands Int -> IRExp PTX aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed PTX aenv (Array (sh, Int) e)
-> IROpenFun1 PTX () aenv (Int -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed PTX aenv (Array (sh, Int) e)
arrIn) Operands Int
i0
                     else let
                              go :: TypeR a -> Operands a
                              go :: TypeR a -> Operands a
go TypeR a
TupRunit       = Operands a
Operands ()
OP_Unit
                              go (TupRpair TupR ScalarType a1
a TupR ScalarType b
b) = Operands a1 -> Operands b -> Operands (a1, b)
forall a b. Operands a -> Operands b -> Operands (a, b)
OP_Pair (TupR ScalarType a1 -> Operands a1
forall a. TypeR a -> Operands a
go TupR ScalarType a1
a) (TupR ScalarType b -> Operands b
forall a. TypeR a -> Operands a
go TupR ScalarType b
b)
                              go (TupRsingle ScalarType a
t) = ScalarType a -> Operand a -> Operands a
forall (dict :: * -> *) a.
(IROP dict, HasCallStack) =>
dict a -> Operand a -> Operands a
ir ScalarType a
t (ScalarType a -> Operand a
forall a. ScalarType a -> Operand a
undef ScalarType a
t)
                          in
                          Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return (Operands e -> CodeGen PTX (Operands e))
-> Operands e -> CodeGen PTX (Operands e)
forall a b. (a -> b) -> a -> b
$ TypeR e -> Operands e
forall a. TypeR a -> Operands a
go TypeR e
tp

            Operands Int
v0  <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
sup Operands Int
inf
            Operands Int32
v0' <- Operands Int -> CodeGen PTX (Operands Int32)
forall i.
IsIntegral i =>
Operands i -> CodeGen PTX (Operands Int32)
i32 Operands Int
v0
            Operands e
r0  <- if (TypeR e
tp, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.gte SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
v0 (Int -> Operands Int
liftInt Int
ws))
                     then DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> IRArray (Vector e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> IRArray (Vector e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
reduceWarpSMem DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine IRArray (Vector e)
smem Maybe (Operands Int32)
forall a. Maybe a
Nothing    Operands e
x0
                     else DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> IRArray (Vector e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> IRArray (Vector e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
reduceWarpSMem DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine IRArray (Vector e)
smem (Operands Int32 -> Maybe (Operands Int32)
forall a. a -> Maybe a
Just Operands Int32
v0') Operands e
x0

            -- Step 2: Keep walking over the rest of the segment
            Operands Int
nx  <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
inf (Int -> Operands Int
liftInt Int
ws)
            Operands e
r   <- TypeR e
-> Operands Int
-> Operands Int
-> Operands Int
-> Operands e
-> (Operands Int -> Operands e -> CodeGen PTX (Operands e))
-> CodeGen PTX (Operands e)
forall i a arch.
IsNum i =>
TypeR a
-> Operands i
-> Operands i
-> Operands i
-> Operands a
-> (Operands i -> Operands a -> CodeGen arch (Operands a))
-> CodeGen arch (Operands a)
iterFromStepTo TypeR e
tp Operands Int
nx (Int -> Operands Int
liftInt Int
ws) Operands Int
sup Operands e
r0 ((Operands Int -> Operands e -> CodeGen PTX (Operands e))
 -> CodeGen PTX (Operands e))
-> (Operands Int -> Operands e -> CodeGen PTX (Operands e))
-> CodeGen PTX (Operands e)
forall a b. (a -> b) -> a -> b
$ \Operands Int
offset Operands e
r -> do

                    -- __syncwarp
                    CodeGen PTX ()
__syncthreads -- TLM: why is this necessary?

                    Operands Int
i' <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
offset (Operands Int -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Operands Int32 -> CodeGen PTX (Operands Int)
forall i. IsIntegral i => Operands i -> CodeGen PTX (Operands Int)
int Operands Int32
lane
                    Operands Int
v' <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
sup Operands Int
offset
                    Operands e
r' <- if (TypeR e
tp, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.gte SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
v' (Int -> Operands Int
liftInt Int
ws))
                            then do
                              -- All lanes are in bounds, so avoid bounds checks
                              Operands e
x <- IROpenFun1 PTX () aenv (Int -> e)
-> Operands Int -> IRExp PTX aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed PTX aenv (Array (sh, Int) e)
-> IROpenFun1 PTX () aenv (Int -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed PTX aenv (Array (sh, Int) e)
arrIn) Operands Int
i'
                              Operands e
y <- DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> IRArray (Vector e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> IRArray (Vector e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
reduceWarpSMem DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine IRArray (Vector e)
smem Maybe (Operands Int32)
forall a. Maybe a
Nothing Operands e
Operands e
x
                              Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands e
y

                            else do
                              Operands e
x <- if (TypeR e
tp, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
i' Operands Int
sup)
                                     then IROpenFun1 PTX () aenv (Int -> e)
-> Operands Int -> IRExp PTX aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed PTX aenv (Array (sh, Int) e)
-> IROpenFun1 PTX () aenv (Int -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed PTX aenv (Array (sh, Int) e)
arrIn) Operands Int
i'
                                     else let
                                              go :: TypeR a -> Operands a
                                              go :: TypeR a -> Operands a
go TypeR a
TupRunit       = Operands a
Operands ()
OP_Unit
                                              go (TupRpair TupR ScalarType a1
a TupR ScalarType b
b) = Operands a1 -> Operands b -> Operands (a1, b)
forall a b. Operands a -> Operands b -> Operands (a, b)
OP_Pair (TupR ScalarType a1 -> Operands a1
forall a. TypeR a -> Operands a
go TupR ScalarType a1
a) (TupR ScalarType b -> Operands b
forall a. TypeR a -> Operands a
go TupR ScalarType b
b)
                                              go (TupRsingle ScalarType a
t) = ScalarType a -> Operand a -> Operands a
forall (dict :: * -> *) a.
(IROP dict, HasCallStack) =>
dict a -> Operand a -> Operands a
ir ScalarType a
t (ScalarType a -> Operand a
forall a. ScalarType a -> Operand a
undef ScalarType a
t)
                                          in
                                          Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return (Operands e -> CodeGen PTX (Operands e))
-> Operands e -> CodeGen PTX (Operands e)
forall a b. (a -> b) -> a -> b
$ TypeR e -> Operands e
forall a. TypeR a -> Operands a
go TypeR e
tp

                              Operands Int32
z <- Operands Int -> CodeGen PTX (Operands Int32)
forall i.
IsIntegral i =>
Operands i -> CodeGen PTX (Operands Int32)
i32 Operands Int
v'
                              Operands e
y <- DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> IRArray (Vector e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> IRArray (Vector e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
reduceWarpSMem DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine IRArray (Vector e)
smem (Operands Int32 -> Maybe (Operands Int32)
forall a. a -> Maybe a
Just Operands Int32
z) Operands e
x
                              Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands e
y

                    -- The first lane incorporates the result from the previous
                    -- iteration
                    if (TypeR e
tp, SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
lane (Int32 -> Operands Int32
liftInt32 Int32
0))
                      then IRFun2 PTX aenv (e -> e -> e)
-> Operands e -> Operands e -> IRExp PTX aenv e
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 PTX aenv (e -> e -> e)
combine Operands e
Operands e
r Operands e
Operands e
r'
                      else Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands e
r'

            -- Step 3: Lane zero writes the aggregate reduction for this
            -- segment to memory. If this is an exclusive reduction, also
            -- combine with the initial element
            CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
lane (Int32 -> Operands Int32
liftInt32 Int32
0)) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$
              IntegralType Int
-> IRArray (Array (sh, Int) e)
-> Operands Int
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Array (sh, Int) e)
arrOut Operands Int
s (Operands e -> CodeGen PTX ())
-> IRExp PTX aenv e -> CodeGen PTX ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<<
                case MIRExp PTX aenv e
mseed of
                  MIRExp PTX aenv e
Nothing -> Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands e
r
                  Just IRExp PTX aenv e
z  -> (Operands e -> Operands e -> IRExp PTX aenv e)
-> Operands e -> Operands e -> IRExp PTX aenv e
forall a b c. (a -> b -> c) -> b -> a -> c
flip (IRFun2 PTX aenv (e -> e -> e)
-> Operands e -> Operands e -> IRExp PTX aenv e
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 PTX aenv (e -> e -> e)
combine) Operands e
Operands e
r (Operands e -> IRExp PTX aenv e)
-> IRExp PTX aenv e -> IRExp PTX aenv e
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< IRExp PTX aenv e
z    -- Note: initial element on the left

            Operands () -> CodeGen PTX (Operands ())
forall (m :: * -> *) a. Monad m => a -> m a
return (TupR ScalarType () -> () -> Operands ()
forall a. TypeR a -> a -> Operands a
lift TupR ScalarType ()
forall (s :: * -> *). TupR s ()
TupRunit ())

    CodeGen PTX ()
forall arch. HasCallStack => CodeGen arch ()
return_


-- | Narrow an integral value to a 32-bit signed integer.
--
-- Used to pass segment lengths and lane offsets to the warp-level reduction,
-- which takes its valid-element count as 'Int32'.
--
i32 :: IsIntegral i => Operands i -> CodeGen PTX (Operands Int32)
i32 = A.fromIntegral integralType numType

-- | Widen an integral value to the native machine 'Int'.
--
-- Used to convert 32-bit lane/thread indices into linear array indices.
--
int :: IsIntegral i => Operands i -> CodeGen PTX (Operands Int)
int = A.fromIntegral integralType numType