{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RebindableSyntax    #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.CodeGen.Fold
-- Copyright   : [2016..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.CodeGen.Fold
  where

import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Elt
import Data.Array.Accelerate.Representation.Shape                   hiding ( size )
import Data.Array.Accelerate.Representation.Type

import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic                as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Loop                      as Loop
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar

import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Base
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Generate
import Data.Array.Accelerate.LLVM.PTX.Target

import LLVM.AST.Type.Representation

import qualified Foreign.CUDA.Analysis                              as CUDA

import Control.Monad                                                ( (>=>) )
import Control.Monad.State                                          ( gets )
import Data.String                                                  ( fromString )
import Data.Bits                                                    as P
import Prelude                                                      as P


-- Reduce an array along the innermost dimension. The reduction function must be
-- associative to allow for an efficient parallel implementation, but the
-- initial element does /not/ need to be a neutral element of the operator.
--
-- TODO: Specialise for commutative operations (such as (+)) and those with
--       a neutral element {(+), 0}
--
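-- As a point of reference, this generates code for applications such as the
-- following (a surface-language sketch; @xs@ is an assumed input of type
-- @Acc (Matrix Float)@, not something defined in this module):
--
-- > sums = fold (+) 0 xs    -- reduce each row of the matrix to a scalar
--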
mkFold
    :: forall aenv sh e.
       Gamma            aenv
    -> ArrayR (Array sh e)
    -> IRFun2       PTX aenv (e -> e -> e)
    -> Maybe (IRExp PTX aenv e)
    -> MIRDelayed   PTX aenv (Array (sh, Int) e)
    -> CodeGen      PTX      (IROpenAcc PTX aenv (Array sh e))
mkFold aenv repr f z acc = case z of
  Just z' -> (+++) <$> codeFold <*> mkFoldFill aenv repr z'
  Nothing -> codeFold
  where
    codeFold :: CodeGen PTX (IROpenAcc PTX aenv (Array sh e))
    codeFold = case repr of
      ArrayR ShapeRz tp -> mkFoldAll aenv tp   f z acc
      _                 -> mkFoldDim aenv repr f z acc


-- Reduce an array to a single element.
--
-- Since reductions consume arrays that have been fused into them, parallel
-- reduction requires two separate kernels. As an example, take vector dot
-- product:
--
-- > dotp xs ys = fold (+) 0 (zipWith (*) xs ys)
--
-- 1. The first pass reads in the fused array data, in this case corresponding
--    to the function (\i -> (xs!i) * (ys!i)).
--
-- 2. The second pass reads in the manifest array data from the first step and
--    directly reduces the array. This can be done recursively in-place until
--    only a single element remains.
--
-- In both phases, thread blocks cooperatively reduce a stripe of the input (one
-- element per thread) to a single element, which is stored to the output array.
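--
-- Schematically, the host side drives these kernels as follows (an
-- illustrative sketch only: 'reduceAll' and 'go' are hypothetical names and
-- not part of the actual scheduler):
--
-- > reduceAll input
-- >   | size input <= blockSize = foldAllS input          -- single-block case
-- >   | otherwise               = go (foldAllM1 input)    -- one partial result per block
-- >   where
-- >     go tmp | size tmp == 1 = tmp
-- >            | otherwise     = go (foldAllM2 tmp)       -- recursive second step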
--
mkFoldAll
    :: forall aenv e.
       Gamma          aenv                      -- ^ array environment
    -> TypeR e
    -> IRFun2     PTX aenv (e -> e -> e)        -- ^ combination function
    -> MIRExp     PTX aenv e                    -- ^ (optional) initial element for exclusive reductions
    -> MIRDelayed PTX aenv (Vector e)           -- ^ input data
    -> CodeGen    PTX      (IROpenAcc PTX aenv (Scalar e))
mkFoldAll aenv tp combine mseed macc = do
  dev <- liftCodeGen $ gets ptxDeviceProperties
  foldr1 (+++) <$> sequence [ mkFoldAllS  dev aenv tp combine mseed macc
                            , mkFoldAllM1 dev aenv tp combine       macc
                            , mkFoldAllM2 dev aenv tp combine mseed
                            ]


-- Reduction of an array to a single element, for small arrays which can be
-- processed by a single thread block.
--
mkFoldAllS
    :: forall aenv e.
       DeviceProperties                         -- ^ properties of the target GPU
    -> Gamma          aenv                      -- ^ array environment
    -> TypeR e
    -> IRFun2     PTX aenv (e -> e -> e)        -- ^ combination function
    -> MIRExp     PTX aenv e                    -- ^ (optional) initial element for exclusive reductions
    -> MIRDelayed PTX aenv (Vector e)           -- ^ input data
    -> CodeGen    PTX      (IROpenAcc PTX aenv (Scalar e))
mkFoldAllS dev aenv tp combine mseed marr =
  let
      (arrOut, paramOut)  = mutableArray (ArrayR dim0 tp) "out"
      (arrIn,  paramIn)   = delayedArray "in" marr
      paramEnv            = envParam aenv
      --
      config              = launchConfig dev (CUDA.incWarp dev) smem multipleOf multipleOfQ
      smem n              = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = bytesElt tp
  in
  makeOpenAccWith config "foldAllS" (paramOut ++ paramIn ++ paramEnv) $ do

    tid     <- threadIdx
    bd      <- blockDim

    sh      <- delayedExtent arrIn
    end     <- shapeSize dim1 sh

    -- We can assume that there is only a single thread block
    start'  <- return (liftInt32 0)
    end'    <- i32 end
    i0      <- A.add numType start' tid
    sz      <- A.sub numType end' start'
    when (A.lt singleType i0 sz) $ do

      -- Thread reads initial element and then participates in block-wide
      -- reduction.
      x0 <- app1 (delayedLinearIndex arrIn) =<< int i0
      r0 <- if (tp, A.eq singleType sz bd)
              then reduceBlockSMem dev tp combine Nothing   x0
              else reduceBlockSMem dev tp combine (Just sz) x0

      when (A.eq singleType tid (liftInt32 0)) $
        writeArray TypeInt32 arrOut tid =<<
          case mseed of
            Nothing -> return r0
            Just z  -> flip (app2 combine) r0 =<< z   -- Note: initial element on the left

    return_


-- Reduction of an entire array to a single element. This kernel implements step
-- one for reducing large arrays which must be processed by multiple thread
-- blocks.
--
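-- Concretely, an input of n elements processed by g thread blocks yields a
-- temporary array of g partial results (one per block), which the second
-- step then reduces further.
--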
mkFoldAllM1
    :: forall aenv e.
       DeviceProperties                         -- ^ properties of the target GPU
    -> Gamma          aenv                      -- ^ array environment
    -> TypeR e
    -> IRFun2     PTX aenv (e -> e -> e)        -- ^ combination function
    -> MIRDelayed PTX aenv (Vector e)           -- ^ input data
    -> CodeGen    PTX      (IROpenAcc PTX aenv (Scalar e))
mkFoldAllM1 dev aenv tp combine marr =
  let
      (arrTmp, paramTmp)  = mutableArray (ArrayR dim1 tp) "tmp"
      (arrIn,  paramIn)   = delayedArray "in" marr
      paramEnv            = envParam aenv
      start               = liftInt 0
      --
      config              = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
      smem n              = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = bytesElt tp
  in
  makeOpenAccWith config "foldAllM1" (paramTmp ++ paramIn ++ paramEnv) $ do

    -- Each thread block cooperatively reduces a stripe of the input and stores
    -- that value into a temporary array at a corresponding index. Since the
    -- order of operations remains fixed, this method supports non-commutative
    -- reductions.
    --
    tid   <- threadIdx
    bd    <- int =<< blockDim
    sz    <- indexHead <$> delayedExtent arrIn
    end   <- shapeSize dim1 (irArrayShape arrTmp)

    imapFromTo start end $ \seg -> do

      -- Wait for all threads to catch up before beginning the stripe
      __syncthreads

      -- Bounds of the input array we will reduce between
      from  <- A.mul numType    seg  bd
      step  <- A.add numType    from bd
      to    <- A.min singleType sz   step

      -- Threads cooperatively reduce this stripe
      reduceFromTo dev tp from to combine
        (app1 (delayedLinearIndex arrIn))
        (when (A.eq singleType tid (liftInt32 0)) . writeArray TypeInt arrTmp seg)

    return_


-- Reduction of an array to a single element: the (recursive) second step of
-- the multi-block reduction algorithm.
--
mkFoldAllM2
    :: forall aenv e.
       DeviceProperties
    -> Gamma       aenv
    -> TypeR e
    -> IRFun2  PTX aenv (e -> e -> e)
    -> MIRExp  PTX aenv e
    -> CodeGen PTX      (IROpenAcc PTX aenv (Scalar e))
mkFoldAllM2 dev aenv tp combine mseed =
  let
      (arrTmp, paramTmp)  = mutableArray (ArrayR dim1 tp) "tmp"
      (arrOut, paramOut)  = mutableArray (ArrayR dim1 tp) "out"
      paramEnv            = envParam aenv
      start               = liftInt 0
      --
      config              = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
      smem n              = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = bytesElt tp
  in
  makeOpenAccWith config "foldAllM2" (paramTmp ++ paramOut ++ paramEnv) $ do

    -- Threads cooperatively reduce a stripe of the input (temporary) array
    -- output from the first phase, storing the results into another temporary.
    -- When only a single thread block remains, we have reached the final
    -- reduction step and add the initial element (for exclusive reductions).
    --
    tid   <- threadIdx
    gd    <- gridDim
    bd    <- int =<< blockDim
    sz    <- return $ indexHead (irArrayShape arrTmp)
    end   <- shapeSize dim1 (irArrayShape arrOut)

    imapFromTo start end $ \seg -> do

      -- Wait for all threads to catch up before beginning the stripe
      __syncthreads

      -- Bounds of the input we will reduce between
      from  <- A.mul numType    seg  bd
      step  <- A.add numType    from bd
      to    <- A.min singleType sz   step

      -- Threads cooperatively reduce this stripe
      reduceFromTo dev tp from to combine (readArray TypeInt arrTmp) $ \r ->
        when (A.eq singleType tid (liftInt32 0)) $
          writeArray TypeInt arrOut seg =<<
            case mseed of
              Nothing -> return r
              Just z  -> if (tp, A.eq singleType gd (liftInt32 1))
                           then flip (app2 combine) r =<< z   -- Note: initial element on the left
                           else return r

    return_


-- Reduce an array of arbitrary rank along the innermost dimension only.
--
-- For simplicity, each element of the output (reduction along an
-- innermost-dimension index) is computed by a single thread block, meaning we
-- don't have to worry about inter-block synchronisation. A more balanced method
-- would be a segmented reduction (specialised, since the length of each segment
-- is known a priori).
--
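-- As a worked instance of the index arithmetic in the kernel body below: for
-- a 3x4 input reduced along the innermost dimension, one thread block reduces
-- linear indices [0..3] to output element 0, another [4..7] to element 1, and
-- another [8..11] to element 2.
--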
mkFoldDim
    :: forall aenv sh e.
       Gamma aenv                                     -- ^ array environment
    -> ArrayR (Array sh e)
    -> IRFun2     PTX aenv (e -> e -> e)              -- ^ combination function
    -> MIRExp     PTX aenv e                          -- ^ (optional) seed element, if this is an exclusive reduction
    -> MIRDelayed PTX aenv (Array (sh, Int) e)        -- ^ input data
    -> CodeGen    PTX      (IROpenAcc PTX aenv (Array sh e))
mkFoldDim :: Gamma aenv
-> ArrayR (Array sh e)
-> IRFun2 PTX aenv (e -> e -> e)
-> MIRExp PTX aenv e
-> MIRDelayed PTX aenv (Array (sh, Int) e)
-> CodeGen PTX (IROpenAcc PTX aenv (Array sh e))
mkFoldDim Gamma aenv
aenv repr :: ArrayR (Array sh e)
repr@(ArrayR ShapeR sh
shr TypeR e
tp) IRFun2 PTX aenv (e -> e -> e)
combine MIRExp PTX aenv e
mseed MIRDelayed PTX aenv (Array (sh, Int) e)
marr = do
  DeviceProperties
dev <- LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties
forall arch a. LLVM arch a -> CodeGen arch a
liftCodeGen (LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties)
-> LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties
forall a b. (a -> b) -> a -> b
$ (PTX -> DeviceProperties) -> LLVM PTX DeviceProperties
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets PTX -> DeviceProperties
ptxDeviceProperties
  --
  let
      (IRArray (Array sh e)
arrOut, [Parameter]
paramOut)  = ArrayR (Array sh e)
-> Name (Array sh e) -> (IRArray (Array sh e), [Parameter])
forall sh e.
ArrayR (Array sh e)
-> Name (Array sh e) -> (IRArray (Array sh e), [Parameter])
mutableArray ArrayR (Array sh e)
repr Name (Array sh e)
"out"
      (IRDelayed PTX aenv (Array (sh, Int) e)
arrIn,  [Parameter]
paramIn)   = Name (Array (sh, Int) e)
-> MIRDelayed PTX aenv (Array (sh, Int) e)
-> (IRDelayed PTX aenv (Array (sh, Int) e), [Parameter])
forall sh e arch aenv.
Name (Array sh e)
-> MIRDelayed arch aenv (Array sh e)
-> (IRDelayed arch aenv (Array sh e), [Parameter])
delayedArray Name (Array (sh, Int) e)
"in" MIRDelayed PTX aenv (Array (sh, Int) e)
marr
      paramEnv :: [Parameter]
paramEnv            = Gamma aenv -> [Parameter]
forall aenv. Gamma aenv -> [Parameter]
envParam Gamma aenv
aenv
      --
      config :: LaunchConfig
config              = DeviceProperties
-> [Int]
-> (Int -> Int)
-> (Int -> Int -> Int)
-> Q (TExp (Int -> Int -> Int))
-> LaunchConfig
launchConfig DeviceProperties
dev (DeviceProperties -> [Int]
CUDA.incWarp DeviceProperties
dev) Int -> Int
smem Int -> Int -> Int
forall a b. a -> b -> a
const [|| const ||]
      smem :: Int -> Int
smem Int
n              = Int
warps Int -> Int -> Int
forall a. Num a => a -> a -> a
* (Int
1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
per_warp) Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
bytes
        where
          ws :: Int
ws        = DeviceProperties -> Int
CUDA.warpSize DeviceProperties
dev
          warps :: Int
warps     = Int
n Int -> Int -> Int
forall a. Integral a => a -> a -> a
`P.quot` Int
ws
          per_warp :: Int
per_warp  = Int
ws Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
ws Int -> Int -> Int
forall a. Integral a => a -> a -> a
`P.quot` Int
2
          bytes :: Int
bytes     = TypeR e -> Int
forall e. TypeR e -> Int
bytesElt TypeR e
tp
  --
  LaunchConfig
-> Label
-> [Parameter]
-> CodeGen PTX DIM0
-> CodeGen PTX (IROpenAcc PTX aenv (Array sh e))
forall aenv a.
LaunchConfig
-> Label
-> [Parameter]
-> CodeGen PTX DIM0
-> CodeGen PTX (IROpenAcc PTX aenv a)
makeOpenAccWith LaunchConfig
config Label
"fold" ([Parameter]
paramOut [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramIn [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramEnv) (CodeGen PTX DIM0 -> CodeGen PTX (IROpenAcc PTX aenv (Array sh e)))
-> CodeGen PTX DIM0
-> CodeGen PTX (IROpenAcc PTX aenv (Array sh e))
forall a b. (a -> b) -> a -> b
$ do

    -- If the innermost dimension is smaller than the number of threads in the
    -- block, those threads will never contribute to the output.
    Operands Int32
tid   <- CodeGen PTX (Operands Int32)
threadIdx
    Operands Int
sz    <- Operands (sh, Int) -> Operands Int
forall sh sz. Operands (sh, sz) -> Operands sz
indexHead (Operands (sh, Int) -> Operands Int)
-> CodeGen PTX (Operands (sh, Int)) -> CodeGen PTX (Operands Int)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IRDelayed PTX aenv (Array (sh, Int) e)
-> CodeGen PTX (Operands (sh, Int))
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRExp arch aenv sh
delayedExtent IRDelayed PTX aenv (Array (sh, Int) e)
arrIn
    Operands Int32
sz'   <- Operands Int -> CodeGen PTX (Operands Int32)
i32 Operands Int
sz

    CodeGen PTX (Operands Bool) -> CodeGen PTX DIM0 -> CodeGen PTX DIM0
forall arch.
CodeGen arch (Operands Bool)
-> CodeGen arch DIM0 -> CodeGen arch DIM0
when (SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
tid Operands Int32
sz') (CodeGen PTX DIM0 -> CodeGen PTX DIM0)
-> CodeGen PTX DIM0 -> CodeGen PTX DIM0
forall a b. (a -> b) -> a -> b
$ do

      Operands Int
start <- Operands Int -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (Int -> Operands Int
liftInt Int
0)
      Operands Int
end   <- ShapeR sh -> Operands sh -> CodeGen PTX (Operands Int)
forall sh arch.
ShapeR sh -> Operands sh -> CodeGen arch (Operands Int)
shapeSize ShapeR sh
shr (IRArray (Array sh e) -> Operands sh
forall sh e. IRArray (Array sh e) -> Operands sh
irArrayShape IRArray (Array sh e)
arrOut)

      -- Thread blocks iterate over the outer dimensions, each thread block
      -- cooperatively reducing along each outermost index to a single value.
      --
      Operands Int
-> Operands Int
-> (Operands Int -> CodeGen PTX DIM0)
-> CodeGen PTX DIM0
imapFromTo Operands Int
start Operands Int
end ((Operands Int -> CodeGen PTX DIM0) -> CodeGen PTX DIM0)
-> (Operands Int -> CodeGen PTX DIM0) -> CodeGen PTX DIM0
forall a b. (a -> b) -> a -> b
$ \Operands Int
seg -> do

        -- Wait for threads to catch up before starting this segment. We could
        -- also place this at the bottom of the loop, but here allows threads to
        -- exit quickly on the last iteration.
        CodeGen PTX DIM0
__syncthreads

        -- Step 1: initialise local sums
        Operands Int
from  <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.mul NumType Int
forall a. IsNum a => NumType a
numType Operands Int
seg  Operands Int
sz          -- first linear index this block will reduce
        Operands Int
to    <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
from Operands Int
sz          -- last linear index this block will reduce (exclusive)

        Operands Int
i0    <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
from (Operands Int -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Operands Int32 -> CodeGen PTX (Operands Int)
int Operands Int32
tid
        Operands e
x0    <- IROpenFun1 PTX DIM0 aenv (Int -> e)
-> Operands Int -> IROpenExp PTX DIM1 aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed PTX aenv (Array (sh, Int) e)
-> IROpenFun1 PTX DIM0 aenv (Int -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed PTX aenv (Array (sh, Int) e)
arrIn) Operands Int
i0
        Operands Int32
bd    <- CodeGen PTX (Operands Int32)
blockDim
        Operands e
r0    <- if (TypeR e
tp, SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.gte SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
sz' Operands Int32
bd)
                   then DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
reduceBlockSMem DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine Maybe (Operands Int32)
forall a. Maybe a
Nothing    Operands e
Operands e
x0
                   else DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
reduceBlockSMem DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine (Operands Int32 -> Maybe (Operands Int32)
forall a. a -> Maybe a
Just Operands Int32
sz') Operands e
Operands e
x0

        -- Step 2: keep walking over the input
        Operands Int
bd'   <- Operands Int32 -> CodeGen PTX (Operands Int)
int Operands Int32
bd
        Operands Int
next  <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
from Operands Int
bd'
        Operands e
r     <- TypeR e
-> Operands Int
-> Operands Int
-> Operands Int
-> Operands e
-> (Operands Int -> Operands e -> CodeGen PTX (Operands e))
-> CodeGen PTX (Operands e)
forall i a arch.
IsNum i =>
TypeR a
-> Operands i
-> Operands i
-> Operands i
-> Operands a
-> (Operands i -> Operands a -> CodeGen arch (Operands a))
-> CodeGen arch (Operands a)
iterFromStepTo TypeR e
tp Operands Int
next Operands Int
bd' Operands Int
to Operands e
r0 ((Operands Int -> Operands e -> CodeGen PTX (Operands e))
 -> CodeGen PTX (Operands e))
-> (Operands Int -> Operands e -> CodeGen PTX (Operands e))
-> CodeGen PTX (Operands e)
forall a b. (a -> b) -> a -> b
$ \Operands Int
offset Operands e
r -> do

          -- Wait for all threads to catch up before starting the next stripe
          CodeGen PTX DIM0
__syncthreads

          -- Threads cooperatively reduce this stripe of the input
          Operands Int
i   <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
offset (Operands Int -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Operands Int32 -> CodeGen PTX (Operands Int)
int Operands Int32
tid
          Operands Int
v'  <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
to Operands Int
offset
          Operands e
r'  <- if (TypeR e
tp, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.gte SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
v' Operands Int
bd')
                   -- All threads of the block are valid, so we can avoid
                   -- bounds checks.
                   then do
                     Operands e
x <- IROpenFun1 PTX DIM0 aenv (Int -> e)
-> Operands Int -> IROpenExp PTX DIM1 aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed PTX aenv (Array (sh, Int) e)
-> IROpenFun1 PTX DIM0 aenv (Int -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed PTX aenv (Array (sh, Int) e)
arrIn) Operands Int
i
                     Operands e
y <- DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
reduceBlockSMem DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine Maybe (Operands Int32)
forall a. Maybe a
Nothing Operands e
Operands e
x
                     Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands e
y

                   -- Otherwise, we require bounds checks when reading the input
                   -- and during the reduction. Note that even though only the
                   -- valid threads will contribute useful work in the
                   -- reduction, we must still have all threads enter the
                   -- reduction procedure to avoid synchronisation divergence.
                   else do
                     Operands e
x <- if (TypeR e
tp, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
i Operands Int
to)
                            then IROpenFun1 PTX DIM0 aenv (Int -> e)
-> Operands Int -> IROpenExp PTX DIM1 aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed PTX aenv (Array (sh, Int) e)
-> IROpenFun1 PTX DIM0 aenv (Int -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed PTX aenv (Array (sh, Int) e)
arrIn) Operands Int
i
                            else let
                                     go :: TypeR a -> Operands a
                                     go :: TypeR a -> Operands a
go TypeR a
TupRunit       = Operands a
Operands DIM0
OP_Unit
                                     go (TupRpair TupR ScalarType a1
a TupR ScalarType b
b) = Operands a1 -> Operands b -> Operands (a1, b)
forall a b. Operands a -> Operands b -> Operands (a, b)
OP_Pair (TupR ScalarType a1 -> Operands a1
forall a. TypeR a -> Operands a
go TupR ScalarType a1
a) (TupR ScalarType b -> Operands b
forall a. TypeR a -> Operands a
go TupR ScalarType b
b)
                                     go (TupRsingle ScalarType a
t) = ScalarType a -> Operand a -> Operands a
forall (dict :: * -> *) a.
(IROP dict, HasCallStack) =>
dict a -> Operand a -> Operands a
ir ScalarType a
t (ScalarType a -> Operand a
forall a. ScalarType a -> Operand a
undef ScalarType a
t)
                                 in
                                 Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return (Operands e -> CodeGen PTX (Operands e))
-> Operands e -> CodeGen PTX (Operands e)
forall a b. (a -> b) -> a -> b
$ TypeR e -> Operands e
forall a. TypeR a -> Operands a
go TypeR e
tp

                     Operands Int32
v <- Operands Int -> CodeGen PTX (Operands Int32)
i32 Operands Int
v'
                     Operands e
y <- DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
reduceBlockSMem DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine (Operands Int32 -> Maybe (Operands Int32)
forall a. a -> Maybe a
Just Operands Int32
v) Operands e
x
                     Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands e
y

          if (TypeR e
tp, SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
tid (Int32 -> Operands Int32
liftInt32 Int32
0))
            then IRFun2 PTX aenv (e -> e -> e)
-> Operands e -> Operands e -> IROpenExp PTX DIM1 aenv e
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 PTX aenv (e -> e -> e)
combine Operands e
Operands e
r Operands e
Operands e
r'
            else Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands e
r'

        -- Step 3: Thread 0 writes the aggregate reduction of this dimension to
        -- memory. If this is an exclusive fold, combine with the initial element.
        --
        when (A.eq singleType tid (liftInt32 0)) $
          writeArray TypeInt arrOut seg =<<
            case mseed of
              Nothing -> return r
              Just z  -> flip (app2 combine) r =<< z  -- Note: initial element on the left

    return_


-- Exclusive reductions over empty arrays (of any dimension) fill the lower
-- dimensions with the initial element.
--
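-- For example, an exclusive fold with seed 0 over an array of shape
-- (Z :. 3 :. 0) produces an array of shape (Z :. 3) whose elements are all 0;
-- this kernel simply delegates to 'mkGenerate' with a constant generator.
--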
mkFoldFill
    :: Gamma       aenv
    -> ArrayR (Array sh e)
    -> IRExp   PTX aenv e
    -> CodeGen PTX      (IROpenAcc PTX aenv (Array sh e))
mkFoldFill aenv repr seed =
  mkGenerate aenv repr (IRFun1 (const seed))


-- Efficient threadblock-wide reduction using the specified operator. The
-- aggregate reduction value is stored in thread zero. Supports non-commutative
-- operators.
--
-- Requires dynamically allocated shared memory: #warps * (1 + 1.5 * warp size)
-- elements.
--
-- Example: https://github.com/NVlabs/cub/blob/1.5.2/cub/block/specializations/block_reduce_warp_reductions.cuh
--
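-- As a concrete illustration (assuming the usual warp size of 32): a block of
-- 256 threads comprises 8 warps and so requests 8 * (1 + 1.5 * 32) = 392
-- elements of dynamic shared memory, i.e. 1568 bytes for 4-byte elements.
--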
reduceBlockSMem
    :: forall aenv e.
       DeviceProperties                         -- ^ properties of the target device
    -> TypeR e
    -> IRFun2 PTX aenv (e -> e -> e)            -- ^ combination function
    -> Maybe (Operands Int32)                   -- ^ number of valid elements (may be less than block size)
    -> Operands e                               -- ^ calling thread's input element
    -> CodeGen PTX (Operands e)                 -- ^ thread-block-wide reduction using the specified operator (lane 0 only)
reduceBlockSMem dev tp combine size = warpReduce >=> warpAggregate
  where
    int32 :: Integral a => a -> Operands Int32
    int32 = liftInt32 . P.fromIntegral

    -- Temporary storage required for each warp
    bytes :: Int
    bytes           = bytesElt tp
    warp_smem_elems :: Int
    warp_smem_elems = CUDA.warpSize dev + (CUDA.warpSize dev `P.quot` 2)
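    -- With the canonical warp size of 32 this is 48 elements per warp: the
    -- extra half warp of padding lets the unguarded tree reduction read at
    -- index (lane + offset) <= 47 without a bounds check, mirroring the CUB
    -- layout referenced above.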

    -- Step 1: Reduction in every warp
    --
    warpReduce :: Operands e -> CodeGen PTX (Operands e)
    warpReduce input = do
      -- Allocate (1.5 * warpSize) elements of shared memory for each warp
      wid   <- warpId
      skip  <- A.mul numType wid (int32 (warp_smem_elems * bytes))
      smem  <- dynamicSharedMem tp TypeInt32 (int32 warp_smem_elems) skip

      -- Are we doing bounds checking for this warp?
      --
      case size of
        -- The entire thread block is valid, so skip bounds checks.
        Nothing ->
          reduceWarpSMem dev tp combine smem Nothing input

        -- Otherwise check how many elements are valid for this warp. If it is
        -- full then we can still skip bounds checks for it.
        Just n -> do
          offset <- A.mul numType wid (int32 (CUDA.warpSize dev))
          valid  <- A.sub numType n offset
          if (tp, A.gte singleType valid (int32 (CUDA.warpSize dev)))
            then reduceWarpSMem dev tp combine smem Nothing      input
            else reduceWarpSMem dev tp combine smem (Just valid) input

    -- Step 2: Aggregate per-warp reductions
    --
    warpAggregate :: Operands e -> CodeGen PTX (Operands e)
    warpAggregate input = do
      -- Allocate #warps elements of shared memory
      bd    <- blockDim
      warps <- A.quot integralType bd (int32 (CUDA.warpSize dev))
      skip  <- A.mul numType warps (int32 (warp_smem_elems * bytes))
      smem  <- dynamicSharedMem tp TypeInt32 warps skip
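      -- The 'skip' offset places this #warps-element array directly after the
      -- per-warp staging buffers allocated by 'warpReduce', so both live in
      -- the same dynamic shared memory allocation.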

      -- Share the per-warp aggregates (each warp's result is held in lane zero)
      wid   <- warpId
      lane  <- laneId
      when (A.eq singleType lane (liftInt32 0)) $ do
        writeArray TypeInt32 smem wid input

      -- Wait for each warp to finish its local reduction
      __syncthreads

      -- Update the total aggregate. Thread 0 just does this sequentially (as is
      -- done in CUB), but we could also do this cooperatively (better for
      -- larger thread blocks?)
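      -- (Current devices cap blocks at 1024 threads, i.e. 32 warps, so the
      -- sequential loop below runs for at most 32 iterations.)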
      tid   <- threadIdx
      if (tp, A.eq singleType tid (liftInt32 0))
        then do
          steps <- case size of
                     Nothing -> return warps
                     Just n  -> do
                       a <- A.add numType n (int32 (CUDA.warpSize dev - 1))
                       b <- A.quot integralType a (int32 (CUDA.warpSize dev))
                       return b
          iterFromStepTo tp (liftInt32 1) (liftInt32 1) steps input $ \step x ->
            app2 combine x =<< readArray TypeInt32 smem step
        else
          return input


-- Efficient warp-wide reduction using shared memory. The aggregate reduction
-- value for the warp is stored in thread lane zero.
--
-- Each warp requires 48 (1.5 x warp size) elements of shared memory. The
-- routine assumes that it is allocated individually per-warp (i.e. can be
-- indexed in the range [0, warp size)).
--
-- Example: https://github.com/NVlabs/cub/blob/1.5.2/cub/warp/specializations/warp_reduce_smem.cuh#L128
--
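-- For the canonical warp size of 32 this takes floor (log2 32) = 5 steps:
-- at step k each (still valid) lane i combines its value with that of lane
-- i + 2^k, so the offsets are 1, 2, 4, 8 and 16, and lane 0 ends up holding
-- the reduction of the whole warp.
--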
reduceWarpSMem
    :: forall aenv e.
       DeviceProperties                         -- ^ properties of the target device
    -> TypeR e
    -> IRFun2 PTX aenv (e -> e -> e)            -- ^ combination function
    -> IRArray (Vector e)                       -- ^ temporary storage array in shared memory (1.5 warp size elements)
    -> Maybe (Operands Int32)                   -- ^ number of items that will be reduced by this warp, otherwise all lanes are valid
    -> Operands e                               -- ^ calling thread's input element
    -> CodeGen PTX (Operands e)                 -- ^ warp-wide reduction using the specified operator (lane 0 only)
reduceWarpSMem dev tp combine smem size = reduce 0
  where
    log2 :: Double -> Double
    log2  = P.logBase 2

    -- Number of steps required to reduce the warp
    steps :: Int
    steps = P.floor . log2 . P.fromIntegral . CUDA.warpSize $ dev

    -- Return whether the index is valid. Assume that constant branches are
    -- optimised away.
    valid :: Operands Int32 -> CodeGen PTX (Operands Bool)
    valid i =
      case size of
        Nothing -> return (liftBool True)
        Just n  -> A.lt singleType i n

    -- Unfold the reduction as a recursive code generation function.
    reduce :: Int -> Operands e -> CodeGen PTX (Operands e)
    reduce step x
      | step >= steps = return x
      | otherwise     = do
          let offset = liftInt32 (1 `P.shiftL` step)

          -- share input through buffer
          lane <- laneId
          writeArray TypeInt32 smem lane x

          __syncwarp
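          -- Lanes of a warp are not guaranteed to execute in lock-step on
          -- Volta and later devices, so synchronise explicitly between the
          -- shared memory writes and the corresponding reads.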

          -- update input if in range
          i   <- A.add numType lane offset
          x'  <- if (tp, valid i)
                   then app2 combine x =<< readArray TypeInt32 smem i
                   else return x

          __syncwarp

          reduce (step+1) x'


-- Efficient warp reduction using __shfl_up instruction (compute >= 3.0)
--
-- Example: https://github.com/NVlabs/cub/blob/1.5.2/cub/warp/specializations/warp_reduce_shfl.cuh#L310
--
-- reduceWarpShfl
--     :: IRFun2 PTX aenv (e -> e -> e)                            -- ^ combination function
--     -> Operands e                                                     -- ^ this thread's input value
--     -> CodeGen (Operands e)                                           -- ^ final result
-- reduceWarpShfl combine input =
--   error "TODO: PTX.reduceWarpShfl"
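-- A shuffle-based variant would follow the same log2 (warp size) schedule as
-- 'reduceWarpSMem', but exchange values directly between registers rather
-- than staging them through shared memory, with each lane folding in its
-- neighbour's value at doubling offsets. A minimal sketch, assuming a
-- hypothetical '__shfl_up' code-generation primitive (not defined in this
-- module):
--
-- > reduceWarpShfl combine x0 =
-- >   foldM (\x step -> app2 combine x =<< __shfl_up x (1 `P.shiftL` step))
-- >         x0 [0 .. steps - 1]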


-- Reduction loops
-- ---------------

reduceFromTo
    :: DeviceProperties
    -> TypeR a
    -> Operands Int                                   -- ^ starting index
    -> Operands Int                                   -- ^ final index (exclusive)
    -> IRFun2 PTX aenv (a -> a -> a)                  -- ^ combination function
    -> (Operands Int -> CodeGen PTX (Operands a))     -- ^ function to retrieve element at index
    -> (Operands a -> CodeGen PTX ())                 -- ^ what to do with the value
    -> CodeGen PTX ()
reduceFromTo dev tp from to combine get set = do

  tid   <- int =<< threadIdx
  bd    <- int =<< blockDim

  valid <- A.sub numType to from
  i     <- A.add numType from tid

  _     <- if (TupRunit, A.gte singleType valid bd)
             then do
               -- All threads in the block will participate in the reduction, so
               -- we can avoid bounds checks
               x <- get i
               r <- reduceBlockSMem dev tp combine Nothing x
               set r

               return (lift TupRunit ())
             else do
               -- Only in-bounds threads can read their input and participate in
               -- the reduction
               when (A.lt singleType i to) $ do
                 x <- get i
                 v <- i32 valid
                 r <- reduceBlockSMem dev tp combine (Just v) x
                 set r

               return (lift TupRunit ())

  return ()


-- Utilities
-- ---------

i32 :: Operands Int -> CodeGen PTX (Operands Int32)
i32 = A.fromIntegral integralType numType

int :: Operands Int32 -> CodeGen PTX (Operands Int)
int = A.fromIntegral integralType numType

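-- Iterate the given body over the half-open index range [start, end),
-- distributing indices block-cyclically over the grid: block b processes
-- start+b, start+b+gridDim, and so on.
--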
imapFromTo
    :: Operands Int
    -> Operands Int
    -> (Operands Int -> CodeGen PTX ())
    -> CodeGen PTX ()
imapFromTo start end body = do
  bid <- int =<< blockIdx
  gd  <- int =<< gridDim
  i0  <- A.add numType start bid
  imapFromStepTo i0 gd end body