{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE PatternGuards       #-}
{-# LANGUAGE RebindableSyntax    #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.CodeGen.Scan
-- Copyright   : [2016..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.CodeGen.Scan (

  mkScan, mkScan',

) where

import Data.Array.Accelerate.AST                                    ( Direction(..) )
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Elt
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type

import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic                as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Loop
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Base
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Generate
import Data.Array.Accelerate.LLVM.PTX.Target

import LLVM.AST.Type.Representation

import qualified Foreign.CUDA.Analysis                              as CUDA

import Control.Applicative
import Control.Monad                                                ( (>=>), void )
import Control.Monad.State                                          ( gets )
import Data.String                                                  ( fromString )
import Data.Coerce                                                  as Safe
import Data.Bits                                                    as P
import Prelude                                                      as P hiding ( last )



-- 'Data.List.scanl' (seeded) or 'Data.List.scanl1' (unseeded) style scan, but
-- with the restriction that the combination function must be associative to
-- enable efficient parallel implementation.
--
-- > scanl (+) 10 (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> Array (Z :. 11) [10,10,11,13,16,20,25,31,38,46,55]
--
-- | Generate the kernels for a scan along the innermost dimension.
--
-- When a seed is supplied ('Just') an extra fill kernel ('mkScanFill') is
-- emitted in addition to the scan kernels; without a seed ('Nothing') only the
-- scan kernels are generated. All generated kernels are merged into a single
-- module with '+++'.
--
mkScan
    :: forall aenv sh e.
       Gamma            aenv
    -> ArrayR (Array (sh, Int) e)
    -> Direction
    -> IRFun2       PTX aenv (e -> e -> e)
    -> Maybe (IRExp PTX aenv e)
    -> MIRDelayed   PTX aenv (Array (sh, Int) e)
    -> CodeGen      PTX      (IROpenAcc PTX aenv (Array (sh, Int) e))
mkScan aenv repr dir combine seed arr
  = foldr1 (+++) <$> sequence (codeScan ++ codeFill)

  where
    -- One-dimensional arrays use the three-phase device-wide scan; all other
    -- ranks go through the single 'mkScanDim' kernel.
    codeScan = case repr of
      ArrayR (ShapeRsnoc ShapeRz) tp -> [ mkScanAllP1 dir aenv tp   combine seed arr
                                        , mkScanAllP2 dir aenv tp   combine
                                        , mkScanAllP3 dir aenv tp   combine seed
                                        ]
      _                              -> [ mkScanDim   dir aenv repr combine seed arr
                                        ]

    -- The fill kernel is only generated for seeded scans.
    codeFill = case seed of
      Just s  -> [ mkScanFill aenv repr s ]
      Nothing -> []

-- Variant of 'scanl' where the final result is returned in a separate array.
--
-- > scanl' (+) 10 (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> ( Array (Z :. 10) [10,10,11,13,16,20,25,31,38,46]
-- >     , Array Z [55]
-- >     )
--
-- | Generate the kernels for a seeded scan which returns the scanned array
-- together with the final reduction value(s) in a separate array.
--
-- Vectors use the three-phase device-wide kernels; every other rank uses
-- 'mkScan'Dim'. In both cases the 'mkScan'Fill' kernel is also generated, and
-- all kernels are merged into a single module with '+++'.
--
mkScan'
    :: forall aenv sh e.
       Gamma          aenv
    -> ArrayR (Array (sh, Int) e)
    -> Direction
    -> IRFun2     PTX aenv (e -> e -> e)
    -> IRExp      PTX aenv e
    -> MIRDelayed PTX aenv (Array (sh, Int) e)
    -> CodeGen    PTX      (IROpenAcc PTX aenv (Array (sh, Int) e, Array sh e))
mkScan' aenv repr dir combine seed arr
  -- One-dimensional input: matching on the shape representation refines
  -- @sh ~ ()@, so the vector-specialised kernels type check here.
  | ArrayR (ShapeRsnoc ShapeRz) tp <- repr
  = foldr1 (+++) <$> sequence [ mkScan'AllP1 dir aenv tp combine seed arr
                              , mkScan'AllP2 dir aenv tp combine
                              , mkScan'AllP3 dir aenv tp combine
                              , mkScan'Fill aenv repr seed
                              ]
  --
  -- Multidimensional input: scan along the innermost dimension
  | otherwise
  = (+++) <$> mkScan'Dim dir aenv repr combine seed arr
          <*> mkScan'Fill  aenv repr seed


-- Device wide scans
-- -----------------
--
-- This is a classic two-pass algorithm which proceeds in two phases and
-- requires ~4n data movement to global memory. In future we would like to
-- replace this with a single pass algorithm.
--

-- Parallel scan, step 1.
--
-- Threads scan a stripe of the input into a temporary array, incorporating the
-- initial element and any fused functions on the way. The final reduction
-- result of this chunk is written to a separate array.
--
-- | First kernel (@scanP1@) of the device-wide vector scan.
--
-- Each thread block iterates over block-sized chunks of the input, scans each
-- chunk with 'scanBlockSMem' and writes the per-element results to the @out@
-- array. When the grid contains more than one block, the last thread of the
-- block additionally stores its result into the @tmp@ array at the position
-- of the chunk. For seeded scans, thread 0 of chunk 0 also writes the seed to
-- the output and combines it with its own input element.
--
mkScanAllP1
    :: forall aenv e.
       Direction
    -> Gamma          aenv                      -- ^ array environment
    -> TypeR e
    -> IRFun2     PTX aenv (e -> e -> e)        -- ^ combination function
    -> MIRExp     PTX aenv e                    -- ^ seed element, if this is an exclusive scan
    -> MIRDelayed PTX aenv (Vector e)           -- ^ input data
    -> CodeGen    PTX (IROpenAcc PTX aenv (Vector e))
mkScanAllP1 dir aenv tp combine mseed marr = do
  dev <- liftCodeGen $ gets ptxDeviceProperties
  --
  let
      (arrOut, paramOut)  = mutableArray (ArrayR dim1 tp) "out"
      (arrTmp, paramTmp)  = mutableArray (ArrayR dim1 tp) "tmp"
      (arrIn,  paramIn)   = delayedArray "in" marr
      -- The number of chunks to process is taken from the extent of the
      -- temporary (per-block aggregates) array.
      end                 = indexHead (irArrayShape arrTmp)
      paramEnv            = envParam aenv
      --
      config              = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
      -- Bytes requested from 'launchConfig' for a thread block of @n@ threads
      -- (presumably dynamic shared memory for 'scanBlockSMem' -- TODO confirm)
      smem n              = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = bytesElt tp
  --
  makeOpenAccWith config "scanP1" (paramTmp ++ paramOut ++ paramIn ++ paramEnv) $ do

    -- Size of the input array
    sz  <- indexHead <$> delayedExtent arrIn

    -- A thread block scans a non-empty stripe of the input, storing the final
    -- block-wide aggregate into a separate array
    --
    -- For exclusive scans, thread 0 of segment 0 must incorporate the initial
    -- element into the input and output. Threads shuffle their indices
    -- appropriately.
    --
    bid <- blockIdx
    gd  <- gridDim
    gd' <- int gd
    s0  <- int bid

    -- iterating over thread-block-wide segments
    imapFromStepTo s0 gd' end $ \chunk -> do

      bd    <- blockDim
      bd'   <- int bd
      inf   <- A.mul numType chunk bd'

      -- index i* is the index that this thread will read data from. Recall that
      -- the supremum index is exclusive
      tid   <- threadIdx
      tid'  <- int tid
      i0    <- case dir of
                 LeftToRight -> A.add numType inf tid'
                 RightToLeft -> do x <- A.sub numType sz inf
                                   y <- A.sub numType x tid'
                                   z <- A.sub numType y (liftInt 1)
                                   return z

      -- index j* is the index that we write to. Recall that for exclusive scans
      -- the output array is one larger than the input; the initial element will
      -- be written into this spot by thread 0 of the first thread block.
      j0    <- case mseed of
                 Nothing -> return i0
                 Just _  -> case dir of
                              LeftToRight -> A.add numType i0 (liftInt 1)
                              RightToLeft -> return i0

      -- If this thread has input, read data and participate in thread-block scan
      let valid i = case dir of
                      LeftToRight -> A.lt  singleType i sz
                      RightToLeft -> A.gte singleType i (liftInt 0)

      when (valid i0) $ do
        x0 <- app1 (delayedLinearIndex arrIn) i0

        -- Thread 0 of chunk 0 writes the seed to the output and folds it into
        -- its own element. The @if@ here takes a (type, condition) pair, which
        -- relies on RebindableSyntax.
        x1 <- case mseed of
                Nothing   -> return x0
                Just seed ->
                  if (tp, A.eq singleType tid (liftInt32 0) `A.land'` A.eq singleType chunk (liftInt 0))
                    then do
                      z <- seed
                      case dir of
                        LeftToRight -> writeArray TypeInt32 arrOut (liftInt32 0) z >> app2 combine z x0
                        RightToLeft -> writeArray TypeInt   arrOut sz            z >> app2 combine x0 z
                    else
                      return x0

        -- Number of elements remaining from the start of this chunk. A full
        -- block scans without a bound; a partial block passes the number of
        -- valid elements to the block-wide scan.
        n  <- A.sub numType sz inf
        n' <- i32 n
        x2 <- if (tp, A.gte singleType n bd')
                then scanBlockSMem dir dev tp combine Nothing   x1
                else scanBlockSMem dir dev tp combine (Just n') x1

        -- Write this thread's scan result to memory
        writeArray TypeInt arrOut j0 x2

        -- The last thread also writes its result---the aggregate for this
        -- thread block---to the temporary partial sums array. This is only
        -- necessary for full blocks in a multi-block scan; the final
        -- partially-full tile does not have a successor block.
        last <- A.sub numType bd (liftInt32 1)
        when (A.gt singleType gd (liftInt32 1) `land'` A.eq singleType tid last) $
          case dir of
            LeftToRight -> writeArray TypeInt arrTmp chunk x2
            RightToLeft -> do u <- A.sub numType end chunk
                              v <- A.sub numType u (liftInt 1)
                              writeArray TypeInt arrTmp v x2

    return_


-- Parallel scan, step 2
--
-- A single thread block performs a scan of the per-block aggregates computed in
-- step 1. This gives the per-block prefix which must be added to each element
-- in step 3.
--
mkScanAllP2
    :: forall aenv e.
       Direction                                -- ^ traversal direction of the scan
    -> Gamma       aenv                         -- ^ array environment
    -> TypeR e                                  -- ^ representation of the element type
    -> IRFun2  PTX aenv (e -> e -> e)           -- ^ combination function
    -> CodeGen PTX      (IROpenAcc PTX aenv (Vector e))
mkScanAllP2 dir aenv tp combine = do
  dev <- liftCodeGen $ gets ptxDeviceProperties
  --
  let
      -- The array of per-block partial sums, which is scanned in place.
      (arrTmp, paramTmp)  = mutableArray (ArrayR dim1 tp) "tmp"
      paramEnv            = envParam aenv
      start               = liftInt 0
      end                 = indexHead (irArrayShape arrTmp)
      --
      -- The grid size functions ignore their arguments and always yield 1.
      config              = launchConfig dev (CUDA.incWarp dev) smem grid gridQ
      grid _ _            = 1
      gridQ               = [|| \_ _ -> 1 ||]
      -- Shared memory request in bytes for an argument 'n' (presumably the
      -- thread block size -- confirm against 'launchConfig'): one element plus
      -- one and a half warp-widths of elements, for each warp.
      smem n              = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = bytesElt tp
  --
  makeOpenAccWith config "scanP2" (paramTmp ++ paramEnv) $ do

    -- The first and last threads of the block need to communicate the
    -- block-wide aggregate as a carry-in value across iterations.
    --
    -- TODO: We could optimise this a bit if we can get access to the shared
    -- memory area used by 'scanBlockSMem', and from there directly read the
    -- value computed by the last thread.
    carry <- staticSharedMem tp 1

    bd    <- blockDim
    bd'   <- int bd

    -- Walk over the partial sums array one block-sized stripe at a time.
    imapFromStepTo start bd' end $ \offset -> do

      -- Index of the partial sums array that this thread will process. For a
      -- right-to-left scan the index is counted back from the end of the array.
      tid   <- threadIdx
      tid'  <- int tid
      i0    <- case dir of
                 LeftToRight -> A.add numType offset tid'
                 RightToLeft -> do x <- A.sub numType end offset
                                   y <- A.sub numType x tid'
                                   A.sub numType y (liftInt 1)

      -- Bounds check on the side of the array that the traversal runs off.
      let valid i = case dir of
                      LeftToRight -> A.lt  singleType i end
                      RightToLeft -> A.gte singleType i start

      when (valid i0) $ do

        -- wait for the carry-in value to be updated
        __syncthreads

        -- On every stripe after the first, thread zero folds the carry-in
        -- from the previous stripe into its element. The operand order
        -- follows the scan direction.
        x0 <- readArray TypeInt arrTmp i0
        x1 <- if (tp, A.gt singleType offset (liftInt 0) `land'` A.eq singleType tid (liftInt32 0))
                then do
                  c <- readArray TypeInt32 carry (liftInt32 0)
                  case dir of
                    LeftToRight -> app2 combine c x0
                    RightToLeft -> app2 combine x0 c
                else do
                  return x0

        -- Number of elements remaining from this stripe onwards. When fewer
        -- than a full block remain, that count is passed to the block scan.
        n  <- A.sub numType end offset
        n' <- i32 n
        x2 <- if (tp, A.gte singleType n bd')
                then scanBlockSMem dir dev tp combine Nothing   x1
                else scanBlockSMem dir dev tp combine (Just n') x1

        -- Update the temporary array with this thread's result
        writeArray TypeInt arrTmp i0 x2

        -- The last thread writes the carry-out value. If the last thread is not
        -- active, then this must be the last stripe anyway.
        last <- A.sub numType bd (liftInt32 1)
        when (A.eq singleType tid last) $
          writeArray TypeInt32 carry (liftInt32 0) x2

    return_


-- Parallel scan, step 3.
--
-- Threads combine every element of the partial block results with the carry-in
-- value computed in step 2.
--
mkScanAllP3
    :: forall aenv e.
       Direction                                -- ^ traversal direction of the scan
    -> Gamma       aenv                         -- ^ array environment
    -> TypeR e                                  -- ^ representation of the element type
    -> IRFun2  PTX aenv (e -> e -> e)           -- ^ combination function
    -> MIRExp  PTX aenv e                       -- ^ seed element, if this is an exclusive scan
    -> CodeGen PTX      (IROpenAcc PTX aenv (Vector e))
mkScanAllP3 dir aenv tp combine mseed = do
  dev <- liftCodeGen $ gets ptxDeviceProperties
  --
  let
      -- 'arrOut' is updated in place; 'arrTmp' supplies the carry-in values.
      (arrOut, paramOut)  = mutableArray (ArrayR dim1 tp) "out"
      (arrTmp, paramTmp)  = mutableArray (ArrayR dim1 tp) "tmp"
      paramEnv            = envParam aenv
      --
      -- The chunk size is received as a kernel parameter named "ix.stride".
      stride              = local     (TupRsingle scalarTypeInt) "ix.stride"
      paramStride         = parameter (TupRsingle scalarTypeInt) "ix.stride"
      --
      -- The shared memory function passed here is constantly zero.
      config              = launchConfig dev (CUDA.incWarp dev) (const 0) const [|| const ||]
  --
  makeOpenAccWith config "scanP3" (paramTmp ++ paramOut ++ paramStride ++ paramEnv) $ do

    let sz = indexHead (irArrayShape arrOut)
    tid <- int =<< threadIdx

    -- Threads that will never contribute can just exit immediately. The size of
    -- each chunk is set by the block dimension of the step 1 kernel, which may
    -- be different from the block size of this kernel.
    when (A.lt singleType tid stride) $ do

      -- Iterate over the segments computed in phase 1. Note that we have one
      -- fewer chunk to process because the first has no carry-in.
      bid <- int =<< blockIdx
      gd  <- int =<< gridDim
      end <- A.sub numType (indexHead (irArrayShape arrTmp)) (liftInt 1)

      imapFromStepTo bid gd end $ \chunk -> do

        -- Determine the start and end indices of this chunk to which we will
        -- carry-in the value. Both directions return the pair as
        -- (lower, upper), as for a left-to-right traversal. When a seed
        -- element is present the range is shifted by one element, and the
        -- result is clamped to the bounds of the output array.
        (inf,sup) <- case dir of
                       LeftToRight -> do
                         a <- A.add numType chunk (liftInt 1)
                         b <- A.mul numType stride a
                         case mseed of
                           Just{}  -> do
                             c <- A.add numType    b (liftInt 1)
                             d <- A.add numType    c stride
                             e <- A.min singleType d sz
                             return (c,e)
                           Nothing -> do
                             c <- A.add numType    b stride
                             d <- A.min singleType c sz
                             return (b,d)
                       RightToLeft -> do
                         a <- A.sub numType end chunk
                         b <- A.mul numType stride a
                         c <- A.sub numType sz b
                         case mseed of
                           Just{}  -> do
                             d <- A.sub numType    c (liftInt 1)
                             e <- A.sub numType    d stride
                             f <- A.max singleType e (liftInt 0)
                             return (f,d)
                           Nothing -> do
                             d <- A.sub numType    c stride
                             e <- A.max singleType d (liftInt 0)
                             return (e,c)

        -- Read the carry-in value. A right-to-left scan reads the entry one
        -- past this chunk's index.
        carry     <- case dir of
                       LeftToRight -> readArray TypeInt arrTmp chunk
                       RightToLeft -> do
                         a <- A.add numType chunk (liftInt 1)
                         readArray TypeInt arrTmp a

        -- Apply the carry-in value to each element in the chunk. The operand
        -- order of the combination follows the scan direction.
        bd        <- int =<< blockDim
        i0        <- A.add numType inf tid
        imapFromStepTo i0 bd sup $ \i -> do
          v <- readArray TypeInt arrOut i
          u <- case dir of
                 LeftToRight -> app2 combine carry v
                 RightToLeft -> app2 combine v carry
          writeArray TypeInt arrOut i u

    return_


-- Parallel scan', step 1.
--
-- Similar to mkScanAllP1. Threads scan a stripe of the input into a temporary
-- array, incorporating the initial element and any fused functions on the way.
-- The final reduction result of this chunk is written to a separate array.
--
mkScan'AllP1
    :: forall aenv e.
       Direction
    -> Gamma          aenv
    -> TypeR e
    -> IRFun2     PTX aenv (e -> e -> e)
    -> IRExp      PTX aenv e
    -> MIRDelayed PTX aenv (Vector e)
    -> CodeGen    PTX      (IROpenAcc PTX aenv (Vector e, Scalar e))
mkScan'AllP1 :: Direction
-> Gamma aenv
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> IRExp PTX aenv e
-> MIRDelayed PTX aenv (Vector e)
-> CodeGen PTX (IROpenAcc PTX aenv (Vector e, Scalar e))
mkScan'AllP1 Direction
dir Gamma aenv
aenv TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
combine IRExp PTX aenv e
seed MIRDelayed PTX aenv (Vector e)
marr = do
  DeviceProperties
dev <- LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties
forall arch a. LLVM arch a -> CodeGen arch a
liftCodeGen (LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties)
-> LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties
forall a b. (a -> b) -> a -> b
$ (PTX -> DeviceProperties) -> LLVM PTX DeviceProperties
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets PTX -> DeviceProperties
ptxDeviceProperties
  --
  let
      (IRArray (Vector e)
arrOut, [Parameter]
paramOut)  = ArrayR (Vector e)
-> Name (Vector e) -> (IRArray (Vector e), [Parameter])
forall sh e.
ArrayR (Array sh e)
-> Name (Array sh e) -> (IRArray (Array sh e), [Parameter])
mutableArray (ShapeR DIM1 -> TypeR e -> ArrayR (Vector e)
forall sh e. ShapeR sh -> TypeR e -> ArrayR (Array sh e)
ArrayR ShapeR DIM1
dim1 TypeR e
tp) Name (Vector e)
"out"
      (IRArray (Vector e)
arrTmp, [Parameter]
paramTmp)  = ArrayR (Vector e)
-> Name (Vector e) -> (IRArray (Vector e), [Parameter])
forall sh e.
ArrayR (Array sh e)
-> Name (Array sh e) -> (IRArray (Array sh e), [Parameter])
mutableArray (ShapeR DIM1 -> TypeR e -> ArrayR (Vector e)
forall sh e. ShapeR sh -> TypeR e -> ArrayR (Array sh e)
ArrayR ShapeR DIM1
dim1 TypeR e
tp) Name (Vector e)
"tmp"
      (IRDelayed PTX aenv (Vector e)
arrIn,  [Parameter]
paramIn)   = Name (Vector e)
-> MIRDelayed PTX aenv (Vector e)
-> (IRDelayed PTX aenv (Vector e), [Parameter])
forall sh e arch aenv.
Name (Array sh e)
-> MIRDelayed arch aenv (Array sh e)
-> (IRDelayed arch aenv (Array sh e), [Parameter])
delayedArray Name (Vector e)
"in" MIRDelayed PTX aenv (Vector e)
marr
      end :: Operands Int
end                 = Operands DIM1 -> Operands Int
forall sh sz. Operands (sh, sz) -> Operands sz
indexHead (IRArray (Vector e) -> Operands DIM1
forall sh e. IRArray (Array sh e) -> Operands sh
irArrayShape IRArray (Vector e)
arrTmp)
      paramEnv :: [Parameter]
paramEnv            = Gamma aenv -> [Parameter]
forall aenv. Gamma aenv -> [Parameter]
envParam Gamma aenv
aenv
      --
      config :: LaunchConfig
config              = DeviceProperties
-> [Int]
-> (Int -> Int)
-> (Int -> Int -> Int)
-> Q (TExp (Int -> Int -> Int))
-> LaunchConfig
launchConfig DeviceProperties
dev (DeviceProperties -> [Int]
CUDA.incWarp DeviceProperties
dev) Int -> Int
smem Int -> Int -> Int
forall a b. a -> b -> a
const [|| const ||]
      smem :: Int -> Int
smem Int
n              = Int
warps Int -> Int -> Int
forall a. Num a => a -> a -> a
* (Int
1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
per_warp) Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
bytes
        where
          ws :: Int
ws        = DeviceProperties -> Int
CUDA.warpSize DeviceProperties
dev
          warps :: Int
warps     = Int
n Int -> Int -> Int
forall a. Integral a => a -> a -> a
`P.quot` Int
ws
          per_warp :: Int
per_warp  = Int
ws Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
ws Int -> Int -> Int
forall a. Integral a => a -> a -> a
`P.quot` Int
2
          bytes :: Int
bytes     = TypeR e -> Int
forall e. TypeR e -> Int
bytesElt TypeR e
tp
  --
  LaunchConfig
-> Label
-> [Parameter]
-> CodeGen PTX ()
-> CodeGen PTX (IROpenAcc PTX aenv (Vector e, Scalar e))
forall aenv a.
LaunchConfig
-> Label
-> [Parameter]
-> CodeGen PTX ()
-> CodeGen PTX (IROpenAcc PTX aenv a)
makeOpenAccWith LaunchConfig
config Label
"scanP1" ([Parameter]
paramTmp [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramOut [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramIn [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramEnv) (CodeGen PTX ()
 -> CodeGen PTX (IROpenAcc PTX aenv (Vector e, Scalar e)))
-> CodeGen PTX ()
-> CodeGen PTX (IROpenAcc PTX aenv (Vector e, Scalar e))
forall a b. (a -> b) -> a -> b
$ do

    -- Size of the input array
    Operands Int
sz  <- Operands DIM1 -> Operands Int
forall sh sz. Operands (sh, sz) -> Operands sz
indexHead (Operands DIM1 -> Operands Int)
-> CodeGen PTX (Operands DIM1) -> CodeGen PTX (Operands Int)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IRDelayed PTX aenv (Vector e) -> CodeGen PTX (Operands DIM1)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRExp arch aenv sh
delayedExtent IRDelayed PTX aenv (Vector e)
arrIn

    -- A thread block scans a non-empty stripe of the input, storing the partial
    -- result and the final block-wide aggregate
    Operands Int
bid <- Operands Int32 -> CodeGen PTX (Operands Int)
int (Operands Int32 -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int32) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< CodeGen PTX (Operands Int32)
blockIdx
    Operands Int
gd  <- Operands Int32 -> CodeGen PTX (Operands Int)
int (Operands Int32 -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int32) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< CodeGen PTX (Operands Int32)
gridDim

    -- iterate over thread-block wide segments
    Operands Int
-> Operands Int
-> Operands Int
-> (Operands Int -> CodeGen PTX ())
-> CodeGen PTX ()
forall i arch.
IsNum i =>
Operands i
-> Operands i
-> Operands i
-> (Operands i -> CodeGen arch ())
-> CodeGen arch ()
imapFromStepTo Operands Int
bid Operands Int
gd Operands Int
end ((Operands Int -> CodeGen PTX ()) -> CodeGen PTX ())
-> (Operands Int -> CodeGen PTX ()) -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ \Operands Int
seg -> do

      Operands Int
bd  <- Operands Int32 -> CodeGen PTX (Operands Int)
int (Operands Int32 -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int32) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< CodeGen PTX (Operands Int32)
blockDim
      Operands Int
inf <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.mul NumType Int
forall a. IsNum a => NumType a
numType Operands Int
seg Operands Int
bd

      -- i* is the index that this thread will read data from
      Operands Int
tid <- Operands Int32 -> CodeGen PTX (Operands Int)
int (Operands Int32 -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int32) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< CodeGen PTX (Operands Int32)
threadIdx
      Operands Int
i0  <- case Direction
dir of
               Direction
LeftToRight -> NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
inf Operands Int
tid
               Direction
RightToLeft -> do Operands Int
x <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
sz Operands Int
inf
                                 Operands Int
y <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
x Operands Int
tid
                                 Operands Int
z <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
y (Int -> Operands Int
liftInt Int
1)
                                 Operands Int -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int
z

      -- j* is the index this thread will write to. This is just shifted by one
      -- to make room for the initial element
      Operands Int
j0  <- case Direction
dir of
               Direction
LeftToRight -> NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
i0 (Int -> Operands Int
liftInt Int
1)
               Direction
RightToLeft -> NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
i0 (Int -> Operands Int
liftInt Int
1)

      -- If this thread has input it participates in the scan
      let valid :: Operands Int -> CodeGen PTX (Operands Bool)
valid Operands Int
i = case Direction
dir of
                      Direction
LeftToRight -> SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt  SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
i Operands Int
sz
                      Direction
RightToLeft -> SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.gte SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
i (Int -> Operands Int
liftInt Int
0)

      CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (Operands Int -> CodeGen PTX (Operands Bool)
valid Operands Int
i0) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ do
        Operands e
x0 <- IROpenFun1 PTX () aenv (Int -> e)
-> Operands Int -> IRExp PTX aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed PTX aenv (Vector e) -> IROpenFun1 PTX () aenv (Int -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed PTX aenv (Vector e)
arrIn) Operands Int
i0

        -- Thread 0 of the first segment must also evaluate and store the
        -- initial element
        Operands Int32
ti <- CodeGen PTX (Operands Int32)
threadIdx
        Operands e
x1 <- if (TypeR e
tp, SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
ti (Int32 -> Operands Int32
liftInt32 Int32
0) CodeGen PTX (Operands Bool)
-> CodeGen PTX (Operands Bool) -> CodeGen PTX (Operands Bool)
forall arch.
CodeGen arch (Operands Bool)
-> CodeGen arch (Operands Bool) -> CodeGen arch (Operands Bool)
`A.land'` SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
seg (Int -> Operands Int
liftInt Int
0))
                then do
                  Operands e
z <- IRExp PTX aenv e
seed
                  IntegralType Int
-> IRArray (Vector e)
-> Operands Int
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Vector e)
arrOut Operands Int
i0 Operands e
z
                  case Direction
dir of
                    Direction
LeftToRight -> IRFun2 PTX aenv (e -> e -> e)
-> Operands e -> Operands e -> IRExp PTX aenv e
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 PTX aenv (e -> e -> e)
combine Operands e
z Operands e
x0
                    Direction
RightToLeft -> IRFun2 PTX aenv (e -> e -> e)
-> Operands e -> Operands e -> IRExp PTX aenv e
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 PTX aenv (e -> e -> e)
combine Operands e
x0 Operands e
z
                else
                  Operands e -> IRExp PTX aenv e
forall (m :: * -> *) a. Monad m => a -> m a
return Operands e
x0

        -- Block-wide scan
        Operands Int
n  <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
sz Operands Int
inf
        Operands Int32
n' <- Operands Int -> CodeGen PTX (Operands Int32)
i32 Operands Int
n
        Operands e
x2 <- if (TypeR e
tp, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.gte SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
n Operands Int
bd)
                then Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> IRExp PTX aenv e
forall aenv e.
Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
scanBlockSMem Direction
dir DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
combine Maybe (Operands Int32)
forall a. Maybe a
Nothing   Operands e
x1
                else Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> IRExp PTX aenv e
forall aenv e.
Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
scanBlockSMem Direction
dir DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
combine (Operands Int32 -> Maybe (Operands Int32)
forall a. a -> Maybe a
Just Operands Int32
n') Operands e
x1

        -- Write this thread's scan result to memory. Recall that we had to make
        -- space for the initial element, so the very last thread does not store
        -- its result here.
        case Direction
dir of
          Direction
LeftToRight -> CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt  SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
j0 Operands Int
sz)          (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ IntegralType Int
-> IRArray (Vector e)
-> Operands Int
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Vector e)
arrOut Operands Int
j0 Operands e
x2
          Direction
RightToLeft -> CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.gte SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
j0 (Int -> Operands Int
liftInt Int
0)) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ IntegralType Int
-> IRArray (Vector e)
-> Operands Int
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Vector e)
arrOut Operands Int
j0 Operands e
x2

        -- Last active thread writes its result to the partial sums array. These
        -- will be used to compute the carry-in value in step 2.
        Operands Int
m  <- do Operands Int
x <- SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands a)
A.min SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
n Operands Int
bd
                 Operands Int
y <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
x (Int -> Operands Int
liftInt Int
1)
                 Operands Int -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int
y
        CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
tid Operands Int
m) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$
          case Direction
dir of
            Direction
LeftToRight -> IntegralType Int
-> IRArray (Vector e)
-> Operands Int
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Vector e)
arrTmp Operands Int
seg Operands e
x2
            Direction
RightToLeft -> do Operands Int
x <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
end Operands Int
seg
                              Operands Int
y <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
x (Int -> Operands Int
liftInt Int
1)
                              IntegralType Int
-> IRArray (Vector e)
-> Operands Int
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Vector e)
arrTmp Operands Int
y Operands e
x2

    CodeGen PTX ()
forall arch. HasCallStack => CodeGen arch ()
return_


-- Parallel scan', step 2
--
-- A single thread block performs an inclusive scan of the partial sums array to
-- compute the per-block carry-in values, as well as the final reduction result.
--
-- The kernel is named "scanP2" and takes its parameters in the order: the
-- partial sums array ("tmp"), the scalar result array ("sum"), followed by
-- the array environment. The partial sums array is updated in place.
--
mkScan'AllP2
    :: forall aenv e.
       Direction                                    -- ^ scan direction
    -> Gamma aenv                                   -- ^ array environment
    -> TypeR e                                      -- ^ element type
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> CodeGen PTX (IROpenAcc PTX aenv (Vector e, Scalar e))
mkScan'AllP2 dir aenv tp combine = do
  dev <- liftCodeGen $ gets ptxDeviceProperties
  --
  let
      (arrTmp, paramTmp)  = mutableArray (ArrayR dim1 tp) "tmp"
      (arrSum, paramSum)  = mutableArray (ArrayR dim0 tp) "sum"
      paramEnv            = envParam aenv
      start               = liftInt 0
      end                 = indexHead (irArrayShape arrTmp)
      --
      -- The grid size is fixed at one: both the host-side function and its
      -- Template Haskell counterpart ignore their arguments.
      config              = launchConfig dev (CUDA.incWarp dev) smem grid gridQ
      grid _ _            = 1
      gridQ               = [|| \_ _ -> 1 ||]
      --
      -- Shared memory size handed to 'launchConfig' for a block of 'n'
      -- threads: for every warp, (1 + ws + ws/2) elements of 'bytesElt tp'
      -- bytes each.
      smem n              = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = bytesElt tp
  --
  makeOpenAccWith config "scanP2" (paramTmp ++ paramSum ++ paramEnv) $ do

    -- The first and last threads of the block need to communicate the
    -- block-wide aggregate as a carry-in value across iterations. This is
    -- done through a one-element shared memory array.
    carry <- staticSharedMem tp 1

    -- A single thread block iterates over the per-block partial results from
    -- step 1
    tid   <- threadIdx
    tid'  <- int tid
    bd    <- int =<< blockDim

    imapFromStepTo start bd end $ \offset -> do

      -- Index into the partial sums array that this thread reads from and
      -- writes back to: counted up from the front for a left-to-right scan,
      -- or down from the back (end - offset - tid - 1) for a right-to-left
      -- scan.
      i0  <- case dir of
               LeftToRight -> A.add numType offset tid'
               RightToLeft -> do x <- A.sub numType end offset
                                 y <- A.sub numType x tid'
                                 A.sub numType y (liftInt 1)

      -- Whether an index lies within the partial sums array, checked against
      -- the bound that the given direction walks towards.
      let valid i = case dir of
                      LeftToRight -> A.lt  singleType i end
                      RightToLeft -> A.gte singleType i start

      -- wait for the carry-in value to be updated
      __syncthreads

      -- Threads with a valid index load their element. The remaining threads
      -- get an undefined placeholder of the right type, built structurally
      -- over the element representation, so that both branches produce a
      -- value. The write-back below is guarded by the same validity check.
      -- NOTE(review): presumably 'scanBlockSMem' ignores these lanes when
      -- given the 'Just n'' bound -- confirm against its definition.
      x0 <- if (tp, valid i0)
              then readArray TypeInt arrTmp i0
              else
                let go :: TypeR a -> Operands a
                    go TupRunit       = OP_Unit
                    go (TupRpair a b) = OP_Pair (go a) (go b)
                    go (TupRsingle t) = ir t (undef t)
                in
                return $ go tp

      -- On every iteration after the first, thread 0 folds in the carry
      -- value saved by the previous iteration. The argument order follows
      -- the scan direction.
      x1 <- if (tp, A.gt singleType offset (liftInt 0) `A.land'` A.eq singleType tid (liftInt32 0))
              then do
                c <- readArray TypeInt32 carry (liftInt32 0)
                case dir of
                  LeftToRight -> app2 combine c x0
                  RightToLeft -> app2 combine x0 c
              else
                return x0

      -- Block-wide scan. 'n' is the number of elements remaining from this
      -- offset; when it is smaller than the block size it is passed along as
      -- the bound.
      n  <- A.sub numType end offset
      n' <- i32 n
      x2 <- if (tp, A.gte singleType n bd)
              then scanBlockSMem dir dev tp combine Nothing   x1
              else scanBlockSMem dir dev tp combine (Just n') x1

      -- Update the partial results array
      when (valid i0) $
        writeArray TypeInt arrTmp i0 x2

      -- The last active thread saves its result as the carry-out value.
      m  <- do x <- A.min singleType bd n
               y <- A.sub numType x (liftInt 1)
               i32 y
      when (A.eq singleType tid m) $
        writeArray TypeInt32 carry (liftInt32 0) x2

    -- First thread stores the final carry-out values at the final reduction
    -- result for the entire array
    __syncthreads

    when (A.eq singleType tid (liftInt32 0)) $
      writeArray TypeInt32 arrSum (liftInt32 0) =<< readArray TypeInt32 carry (liftInt32 0)

    return_


-- Parallel scan', step 3.
--
-- For each chunk of partial block results, every participating thread folds
-- the carry-in value computed in step 2 into the elements of that chunk, in
-- place in the output array.
--
-- Only threads whose index is below the @ix.stride@ kernel parameter do any
-- work. Thread blocks step over chunks by the grid dimension, and threads
-- within a block step over the elements of a chunk by the block dimension.
--
mkScan'AllP3
    :: forall aenv e.
       Direction
    -> Gamma aenv                                   -- ^ array environment
    -> TypeR e
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> CodeGen PTX (IROpenAcc PTX aenv (Vector e, Scalar e))
mkScan'AllP3 dir aenv tp combine = do
  dev <- liftCodeGen (gets ptxDeviceProperties)
  --
  let
      -- Representations shared by the kernel arguments below
      vecR                = ArrayR dim1 tp
      intR                = TupRsingle scalarTypeInt
      --
      (arrOut, paramOut)  = mutableArray vecR "out"
      (arrTmp, paramTmp)  = mutableArray vecR "tmp"
      paramEnv            = envParam aenv
      --
      -- The same name is used for the kernel parameter and for the local
      -- operand which refers to it.
      stride              = local     intR "ix.stride"
      paramStride         = parameter intR "ix.stride"
      --
      config              = launchConfig dev (CUDA.incWarp dev) (const 0) const [|| const ||]
      --
      kernelParams        = paramTmp ++ paramOut ++ paramStride ++ paramEnv
  --
  makeOpenAccWith config "scanP3" kernelParams $ do

    let
        -- Length of the output vector
        len :: Operands Int
        len = indexHead (irArrayShape arrOut)

        -- Lower and upper index bounds handed to the inner loop for the given
        -- chunk. The first argument is the index of the last chunk, which is
        -- only consulted for right-to-left scans.
        chunkBounds :: Operands Int -> Operands Int -> CodeGen PTX (Operands Int, Operands Int)
        chunkBounds lastChunk chunk =
          case dir of
            LeftToRight -> do
              nxt   <- A.add numType    chunk  (liftInt 1)
              base  <- A.mul numType    stride nxt
              lo    <- A.add numType    base   (liftInt 1)
              upper <- A.add numType    lo     stride
              hi    <- A.min singleType upper  len       -- clamp to the array end
              return (lo, hi)
            RightToLeft -> do
              rev   <- A.sub numType    lastChunk chunk
              off   <- A.mul numType    stride    rev
              top   <- A.sub numType    len       off
              hi    <- A.sub numType    top       (liftInt 1)
              lower <- A.sub numType    hi        stride
              lo    <- A.max singleType lower     (liftInt 0)  -- clamp to the array start
              return (lo, hi)

        -- Carry-in value for the given chunk, read from the temporary array
        carryIn :: Operands Int -> CodeGen PTX (Operands e)
        carryIn chunk =
          case dir of
            LeftToRight -> readArray TypeInt arrTmp chunk
            RightToLeft -> readArray TypeInt arrTmp =<< A.add numType chunk (liftInt 1)

        -- Combine the carry-in with an element; the argument order of the
        -- combination function follows the scan direction.
        withCarry :: Operands e -> Operands e -> CodeGen PTX (Operands e)
        withCarry carry v =
          case dir of
            LeftToRight -> app2 combine carry v
            RightToLeft -> app2 combine v carry

    tid <- int =<< threadIdx

    when (A.lt singleType tid stride) $ do

      bid       <- int =<< blockIdx
      gd        <- int =<< gridDim
      lastChunk <- A.sub numType (indexHead (irArrayShape arrTmp)) (liftInt 1)

      imapFromStepTo bid gd lastChunk $ \chunk -> do

        (lo, hi) <- chunkBounds lastChunk chunk
        carry    <- carryIn chunk

        -- Apply the carry-in value to each element in the chunk
        bd       <- int =<< blockDim
        start    <- A.add numType lo tid
        imapFromStepTo start bd hi $ \i -> do
          v <- readArray TypeInt arrOut i
          u <- withCarry carry v
          writeArray TypeInt arrOut i u

    return_


-- Multidimensional scans
-- ----------------------

-- Multidimensional scan along the innermost dimension
--
-- A thread block individually computes along each innermost dimension. This is
-- a single-pass operation.
--
--  * We can assume that the array is non-empty; exclusive scans with empty
--    innermost dimension will be instead filled with the seed element via
--    'mkScanFill'.
--
--  * Small but non-empty innermost dimension arrays (size << thread
--    block size) will have many threads which do no work.
--
mkScanDim
    :: forall aenv sh e.
       Direction
    -> Gamma          aenv                          -- ^ array environment
    -> ArrayR (Array (sh, Int) e)
    -> IRFun2     PTX aenv (e -> e -> e)            -- ^ combination function
    -> MIRExp     PTX aenv e                        -- ^ seed element, if this is an exclusive scan
    -> MIRDelayed PTX aenv (Array (sh, Int) e)      -- ^ input data
    -> CodeGen    PTX (IROpenAcc PTX aenv (Array (sh, Int) e))
mkScanDim :: Direction
-> Gamma aenv
-> ArrayR (Array (sh, Int) e)
-> IRFun2 PTX aenv (e -> e -> e)
-> MIRExp PTX aenv e
-> MIRDelayed PTX aenv (Array (sh, Int) e)
-> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e))
mkScanDim Direction
dir Gamma aenv
aenv repr :: ArrayR (Array (sh, Int) e)
repr@(ArrayR (ShapeRsnoc ShapeR sh1
shr) TypeR e
tp) IRFun2 PTX aenv (e -> e -> e)
combine MIRExp PTX aenv e
mseed MIRDelayed PTX aenv (Array (sh, Int) e)
marr = do
  DeviceProperties
dev <- LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties
forall arch a. LLVM arch a -> CodeGen arch a
liftCodeGen (LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties)
-> LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties
forall a b. (a -> b) -> a -> b
$ (PTX -> DeviceProperties) -> LLVM PTX DeviceProperties
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets PTX -> DeviceProperties
ptxDeviceProperties
  --
  let
      (IRArray (Array (sh, Int) e)
arrOut, [Parameter]
paramOut)  = ArrayR (Array (sh, Int) e)
-> Name (Array (sh, Int) e)
-> (IRArray (Array (sh, Int) e), [Parameter])
forall sh e.
ArrayR (Array sh e)
-> Name (Array sh e) -> (IRArray (Array sh e), [Parameter])
mutableArray ArrayR (Array (sh, Int) e)
repr Name (Array (sh, Int) e)
"out"
      (IRDelayed PTX aenv (Array (sh, Int) e)
arrIn,  [Parameter]
paramIn)   = Name (Array (sh, Int) e)
-> MIRDelayed PTX aenv (Array (sh, Int) e)
-> (IRDelayed PTX aenv (Array (sh, Int) e), [Parameter])
forall sh e arch aenv.
Name (Array sh e)
-> MIRDelayed arch aenv (Array sh e)
-> (IRDelayed arch aenv (Array sh e), [Parameter])
delayedArray Name (Array (sh, Int) e)
"in" MIRDelayed PTX aenv (Array (sh, Int) e)
marr
      paramEnv :: [Parameter]
paramEnv            = Gamma aenv -> [Parameter]
forall aenv. Gamma aenv -> [Parameter]
envParam Gamma aenv
aenv
      --
      config :: LaunchConfig
config              = DeviceProperties
-> [Int]
-> (Int -> Int)
-> (Int -> Int -> Int)
-> Q (TExp (Int -> Int -> Int))
-> LaunchConfig
launchConfig DeviceProperties
dev (DeviceProperties -> [Int]
CUDA.incWarp DeviceProperties
dev) Int -> Int
smem Int -> Int -> Int
forall a b. a -> b -> a
const [|| const ||]
      smem :: Int -> Int
smem Int
n              = Int
warps Int -> Int -> Int
forall a. Num a => a -> a -> a
* (Int
1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
per_warp) Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
bytes
        where
          ws :: Int
ws        = DeviceProperties -> Int
CUDA.warpSize DeviceProperties
dev
          warps :: Int
warps     = Int
n Int -> Int -> Int
forall a. Integral a => a -> a -> a
`P.quot` Int
ws
          per_warp :: Int
per_warp  = Int
ws Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
ws Int -> Int -> Int
forall a. Integral a => a -> a -> a
`P.quot` Int
2
          bytes :: Int
bytes     = TypeR e -> Int
forall e. TypeR e -> Int
bytesElt TypeR e
tp
  --
  LaunchConfig
-> Label
-> [Parameter]
-> CodeGen PTX ()
-> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e))
forall aenv a.
LaunchConfig
-> Label
-> [Parameter]
-> CodeGen PTX ()
-> CodeGen PTX (IROpenAcc PTX aenv a)
makeOpenAccWith LaunchConfig
config Label
"scan" ([Parameter]
paramOut [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramIn [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramEnv) (CodeGen PTX ()
 -> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e)))
-> CodeGen PTX ()
-> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e))
forall a b. (a -> b) -> a -> b
$ do

    -- The first and last threads of the block need to communicate the
    -- block-wide aggregate as a carry-in value across iterations.
    --
    -- TODO: we could optimise this a bit if we can get access to the shared
    -- memory area used by 'scanBlockSMem', and from there directly read the
    -- value computed by the last thread.
    IRArray (Vector e)
carry <- TypeR e -> Word64 -> CodeGen PTX (IRArray (Vector e))
forall e. TypeR e -> Word64 -> CodeGen PTX (IRArray (Vector e))
staticSharedMem TypeR e
tp Word64
1

    -- Size of the input array
    Operands Int
sz  <- Operands (sh, Int) -> Operands Int
forall sh sz. Operands (sh, sz) -> Operands sz
indexHead (Operands (sh, Int) -> Operands Int)
-> CodeGen PTX (Operands (sh, Int)) -> CodeGen PTX (Operands Int)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IRDelayed PTX aenv (Array (sh, Int) e)
-> CodeGen PTX (Operands (sh, Int))
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRExp arch aenv sh
delayedExtent IRDelayed PTX aenv (Array (sh, Int) e)
arrIn

    -- Thread blocks iterate over the outer dimensions. Threads in a block
    -- cooperatively scan along one dimension, but thread blocks do not
    -- communicate with each other.
    --
    Operands Int
bid <- Operands Int32 -> CodeGen PTX (Operands Int)
int (Operands Int32 -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int32) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< CodeGen PTX (Operands Int32)
blockIdx
    Operands Int
gd  <- Operands Int32 -> CodeGen PTX (Operands Int)
int (Operands Int32 -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int32) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< CodeGen PTX (Operands Int32)
gridDim
    Operands Int
end <- ShapeR sh1 -> Operands sh1 -> CodeGen PTX (Operands Int)
forall sh arch.
ShapeR sh -> Operands sh -> CodeGen arch (Operands Int)
shapeSize ShapeR sh1
shr (Operands (sh, Int) -> Operands sh
forall sh sz. Operands (sh, sz) -> Operands sh
indexTail (IRArray (Array (sh, Int) e) -> Operands (sh, Int)
forall sh e. IRArray (Array sh e) -> Operands sh
irArrayShape IRArray (Array (sh, Int) e)
arrOut))

    Operands Int
-> Operands Int
-> Operands Int
-> (Operands Int -> CodeGen PTX ())
-> CodeGen PTX ()
forall i arch.
IsNum i =>
Operands i
-> Operands i
-> Operands i
-> (Operands i -> CodeGen arch ())
-> CodeGen arch ()
imapFromStepTo Operands Int
bid Operands Int
gd Operands Int
end ((Operands Int -> CodeGen PTX ()) -> CodeGen PTX ())
-> (Operands Int -> CodeGen PTX ()) -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ \Operands Int
seg -> do

      -- Index this thread reads from
      Operands Int32
tid   <- CodeGen PTX (Operands Int32)
threadIdx
      Operands Int
tid'  <- Operands Int32 -> CodeGen PTX (Operands Int)
int Operands Int32
tid
      Operands Int
i0    <- case Direction
dir of
                 Direction
LeftToRight -> do Operands Int
x <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.mul NumType Int
forall a. IsNum a => NumType a
numType Operands Int
seg Operands Int
sz
                                   Operands Int
y <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
x Operands Int
tid'
                                   Operands Int -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int
y

                 Direction
RightToLeft -> do Operands Int
x <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
seg (Int -> Operands Int
liftInt Int
1)
                                   Operands Int
y <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.mul NumType Int
forall a. IsNum a => NumType a
numType Operands Int
x Operands Int
sz
                                   Operands Int
z <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
y Operands Int
tid'
                                   Operands Int
w <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
z (Int -> Operands Int
liftInt Int
1)
                                   Operands Int -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int
w

      -- Index this thread writes to
      Operands Int
j0  <- case MIRExp PTX aenv e
mseed of
               MIRExp PTX aenv e
Nothing -> Operands Int -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int
i0
               Just{}  -> do Operands Int
szp1 <- Operands Int -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (Operands Int -> CodeGen PTX (Operands Int))
-> Operands Int -> CodeGen PTX (Operands Int)
forall a b. (a -> b) -> a -> b
$ Operands (sh, Int) -> Operands Int
forall sh sz. Operands (sh, sz) -> Operands sz
indexHead (IRArray (Array (sh, Int) e) -> Operands (sh, Int)
forall sh e. IRArray (Array sh e) -> Operands sh
irArrayShape IRArray (Array (sh, Int) e)
arrOut)
                             case Direction
dir of
                               Direction
LeftToRight -> do Operands Int
x <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.mul NumType Int
forall a. IsNum a => NumType a
numType Operands Int
seg Operands Int
szp1
                                                 Operands Int
y <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
x Operands Int
tid'
                                                 Operands Int -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int
y

                               Direction
RightToLeft -> do Operands Int
x <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
seg (Int -> Operands Int
liftInt Int
1)
                                                 Operands Int
y <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.mul NumType Int
forall a. IsNum a => NumType a
numType Operands Int
x Operands Int
szp1
                                                 Operands Int
z <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
y Operands Int
tid'
                                                 Operands Int
w <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
z (Int -> Operands Int
liftInt Int
1)
                                                 Operands Int -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int
w

      -- Stride indices by block dimension
      Operands Int32
bd  <- CodeGen PTX (Operands Int32)
blockDim
      Operands Int
bd' <- Operands Int32 -> CodeGen PTX (Operands Int)
int Operands Int32
bd
      let next :: Operands Int -> CodeGen PTX (Operands Int)
next Operands Int
ix = case Direction
dir of
                      Direction
LeftToRight -> NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
ix Operands Int
bd'
                      Direction
RightToLeft -> NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
ix Operands Int
bd'

      -- Initialise this scan segment
      --
      -- If this is an exclusive scan then the first thread just evaluates the
      -- seed element and stores this value into the carry-in slot. All threads
      -- shift their write-to index (j) by one, to make space for this element.
      --
      -- If this is an inclusive scan then do a block-wide scan. The last thread
      -- in the block writes the carry-in value.
      --
      Operands (Tup3 Int Int Int)
r <-
        case MIRExp PTX aenv e
mseed of
          Just IRExp PTX aenv e
seed -> do
            CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
tid (Int32 -> Operands Int32
liftInt32 Int32
0)) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ do
              Operands e
z <- IRExp PTX aenv e
seed
              IntegralType Int
-> IRArray (Array (sh, Int) e)
-> Operands Int
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt   IRArray (Array (sh, Int) e)
arrOut Operands Int
j0 Operands e
z
              IntegralType Int32
-> IRArray (Vector e)
-> Operands Int32
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int32
TypeInt32 IRArray (Vector e)
carry (Int32 -> Operands Int32
liftInt32 Int32
0) Operands e
Operands e
z
            Operands Int
j1 <- case Direction
dir of
                   Direction
LeftToRight -> NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
j0 (Int -> Operands Int
liftInt Int
1)
                   Direction
RightToLeft -> NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
j0 (Int -> Operands Int
liftInt Int
1)
            Operands (Tup3 Int Int Int)
-> CodeGen PTX (Operands (Tup3 Int Int Int))
forall (m :: * -> *) a. Monad m => a -> m a
return (Operands (Tup3 Int Int Int)
 -> CodeGen PTX (Operands (Tup3 Int Int Int)))
-> Operands (Tup3 Int Int Int)
-> CodeGen PTX (Operands (Tup3 Int Int Int))
forall a b. (a -> b) -> a -> b
$ Operands Int
-> Operands Int -> Operands Int -> Operands (Tup3 Int Int Int)
forall a b c.
Operands a -> Operands b -> Operands c -> Operands (Tup3 a b c)
A.trip Operands Int
sz Operands Int
i0 Operands Int
j1

          MIRExp PTX aenv e
Nothing -> do
            CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
tid' Operands Int
sz) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ do
              Operands Int32
n' <- Operands Int -> CodeGen PTX (Operands Int32)
i32 Operands Int
sz
              Operands e
x0 <- IROpenFun1 PTX () aenv (Int -> e)
-> Operands Int -> IRExp PTX aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed PTX aenv (Array (sh, Int) e)
-> IROpenFun1 PTX () aenv (Int -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed PTX aenv (Array (sh, Int) e)
arrIn) Operands Int
i0
              Operands e
r0 <- if (TypeR e
tp, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.gte SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
sz Operands Int
bd')
                      then Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
scanBlockSMem Direction
dir DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine Maybe (Operands Int32)
forall a. Maybe a
Nothing   Operands e
Operands e
x0
                      else Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
scanBlockSMem Direction
dir DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine (Operands Int32 -> Maybe (Operands Int32)
forall a. a -> Maybe a
Just Operands Int32
n') Operands e
Operands e
x0
              IntegralType Int
-> IRArray (Array (sh, Int) e)
-> Operands Int
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Array (sh, Int) e)
arrOut Operands Int
j0 Operands e
Operands e
r0

              Operands Int32
ll <- NumType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Int32)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int32
forall a. IsNum a => NumType a
numType Operands Int32
bd (Int32 -> Operands Int32
liftInt32 Int32
1)
              CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
tid Operands Int32
ll) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$
                IntegralType Int32
-> IRArray (Vector e)
-> Operands Int32
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int32
TypeInt32 IRArray (Vector e)
carry (Int32 -> Operands Int32
liftInt32 Int32
0) Operands e
r0

            Operands Int
n1 <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
sz Operands Int
bd'
            Operands Int
i1 <- Operands Int -> CodeGen PTX (Operands Int)
next Operands Int
i0
            Operands Int
j1 <- Operands Int -> CodeGen PTX (Operands Int)
next Operands Int
j0
            Operands (Tup3 Int Int Int)
-> CodeGen PTX (Operands (Tup3 Int Int Int))
forall (m :: * -> *) a. Monad m => a -> m a
return (Operands (Tup3 Int Int Int)
 -> CodeGen PTX (Operands (Tup3 Int Int Int)))
-> Operands (Tup3 Int Int Int)
-> CodeGen PTX (Operands (Tup3 Int Int Int))
forall a b. (a -> b) -> a -> b
$ Operands Int
-> Operands Int -> Operands Int -> Operands (Tup3 Int Int Int)
forall a b c.
Operands a -> Operands b -> Operands c -> Operands (Tup3 a b c)
A.trip Operands Int
n1 Operands Int
i1 Operands Int
j1

      -- Iterate over the remaining elements in this segment
      CodeGen PTX (Operands (Tup3 Int Int Int)) -> CodeGen PTX ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (CodeGen PTX (Operands (Tup3 Int Int Int)) -> CodeGen PTX ())
-> CodeGen PTX (Operands (Tup3 Int Int Int)) -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ TypeR (Tup3 Int Int Int)
-> (Operands (Tup3 Int Int Int) -> CodeGen PTX (Operands Bool))
-> (Operands (Tup3 Int Int Int)
    -> CodeGen PTX (Operands (Tup3 Int Int Int)))
-> Operands (Tup3 Int Int Int)
-> CodeGen PTX (Operands (Tup3 Int Int Int))
forall a arch.
TypeR a
-> (Operands a -> CodeGen arch (Operands Bool))
-> (Operands a -> CodeGen arch (Operands a))
-> Operands a
-> CodeGen arch (Operands a)
while
        (TupR ScalarType ()
forall (s :: * -> *). TupR s ()
TupRunit TupR ScalarType () -> TypeR Int -> TupR ScalarType DIM1
forall (s :: * -> *) a1 b. TupR s a1 -> TupR s b -> TupR s (a1, b)
`TupRpair` ScalarType Int -> TypeR Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt TupR ScalarType DIM1 -> TypeR Int -> TupR ScalarType (DIM1, Int)
forall (s :: * -> *) a1 b. TupR s a1 -> TupR s b -> TupR s (a1, b)
`TupRpair` ScalarType Int -> TypeR Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt TupR ScalarType (DIM1, Int)
-> TypeR Int -> TypeR (Tup3 Int Int Int)
forall (s :: * -> *) a1 b. TupR s a1 -> TupR s b -> TupR s (a1, b)
`TupRpair` ScalarType Int -> TypeR Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt)
        (\(Operands (Tup3 Int Int Int) -> Operands Int
forall a b c. Operands (Tup3 a b c) -> Operands a
A.fst3   -> Operands Int
n)       -> SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.gt SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
n (Int -> Operands Int
liftInt Int
0))
        (\(Operands (Tup3 Int Int Int)
-> (Operands Int, Operands Int, Operands Int)
forall a b c.
Operands (Tup3 a b c) -> (Operands a, Operands b, Operands c)
A.untrip -> (Operands Int
n,Operands Int
i,Operands Int
j)) -> do

          -- Wait for the carry-in value from the previous iteration to be updated
          CodeGen PTX ()
__syncthreads

          -- Compute and store the next element of the scan
          --
          -- NOTE: As with 'foldSeg' we require all threads to participate in
          -- every iteration of the loop otherwise they will die prematurely.
          -- Out-of-bounds threads return 'undef' at this point, which is really
          -- unfortunate ):
          --
          Operands e
x <- if (TypeR e
tp, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
tid' Operands Int
n)
                 then IROpenFun1 PTX () aenv (Int -> e)
-> Operands Int -> IRExp PTX aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed PTX aenv (Array (sh, Int) e)
-> IROpenFun1 PTX () aenv (Int -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed PTX aenv (Array (sh, Int) e)
arrIn) Operands Int
i
                 else let
                          go :: TypeR a -> Operands a
                          go :: TypeR a -> Operands a
go TypeR a
TupRunit       = Operands a
Operands ()
OP_Unit
                          go (TupRpair TupR ScalarType a1
a TupR ScalarType b
b) = Operands a1 -> Operands b -> Operands (a1, b)
forall a b. Operands a -> Operands b -> Operands (a, b)
OP_Pair (TupR ScalarType a1 -> Operands a1
forall a. TypeR a -> Operands a
go TupR ScalarType a1
a) (TupR ScalarType b -> Operands b
forall a. TypeR a -> Operands a
go TupR ScalarType b
b)
                          go (TupRsingle ScalarType a
t) = ScalarType a -> Operand a -> Operands a
forall (dict :: * -> *) a.
(IROP dict, HasCallStack) =>
dict a -> Operand a -> Operands a
ir ScalarType a
t (ScalarType a -> Operand a
forall a. ScalarType a -> Operand a
undef ScalarType a
t)
                      in
                      Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return (Operands e -> CodeGen PTX (Operands e))
-> Operands e -> CodeGen PTX (Operands e)
forall a b. (a -> b) -> a -> b
$ TypeR e -> Operands e
forall a. TypeR a -> Operands a
go TypeR e
tp

          -- Thread zero incorporates the carry-in element
          Operands e
y <- if (TypeR e
tp, SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
tid (Int32 -> Operands Int32
liftInt32 Int32
0))
                 then do
                   Operands e
c <- IntegralType Int32
-> IRArray (Vector e) -> Operands Int32 -> CodeGen PTX (Operands e)
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> CodeGen arch (Operands e)
readArray IntegralType Int32
TypeInt32 IRArray (Vector e)
carry (Int32 -> Operands Int32
liftInt32 Int32
0)
                   case Direction
dir of
                     Direction
LeftToRight -> IRFun2 PTX aenv (e -> e -> e)
-> Operands e -> Operands e -> IRExp PTX aenv e
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 PTX aenv (e -> e -> e)
combine Operands e
Operands e
c Operands e
Operands e
x
                     Direction
RightToLeft -> IRFun2 PTX aenv (e -> e -> e)
-> Operands e -> Operands e -> IRExp PTX aenv e
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 PTX aenv (e -> e -> e)
combine Operands e
Operands e
x Operands e
Operands e
c
                  else
                    Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands e
x

          -- Perform the scan and write the result to memory
          Operands Int32
m <- Operands Int -> CodeGen PTX (Operands Int32)
i32 Operands Int
n
          Operands e
z <- if (TypeR e
tp, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.gte SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
n Operands Int
bd')
                 then Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
scanBlockSMem Direction
dir DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine Maybe (Operands Int32)
forall a. Maybe a
Nothing  Operands e
y
                 else Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
scanBlockSMem Direction
dir DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine (Operands Int32 -> Maybe (Operands Int32)
forall a. a -> Maybe a
Just Operands Int32
m) Operands e
y

          CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
tid' Operands Int
n) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ do
            IntegralType Int
-> IRArray (Array (sh, Int) e)
-> Operands Int
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Array (sh, Int) e)
arrOut Operands Int
j Operands e
Operands e
z

            -- The last thread of the block writes its result as the carry-out
            -- value. If this thread is not active then we are on the last
            -- iteration of the loop and it will not be needed.
            Operands Int32
w <- NumType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Int32)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int32
forall a. IsNum a => NumType a
numType Operands Int32
bd (Int32 -> Operands Int32
liftInt32 Int32
1)
            CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
tid Operands Int32
w) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$
              IntegralType Int32
-> IRArray (Vector e)
-> Operands Int32
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int32
TypeInt32 IRArray (Vector e)
carry (Int32 -> Operands Int32
liftInt32 Int32
0) Operands e
z

          -- Update indices for the next iteration
          Operands Int
n' <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
n Operands Int
bd'
          Operands Int
i' <- Operands Int -> CodeGen PTX (Operands Int)
next Operands Int
i
          Operands Int
j' <- Operands Int -> CodeGen PTX (Operands Int)
next Operands Int
j
          Operands (Tup3 Int Int Int)
-> CodeGen PTX (Operands (Tup3 Int Int Int))
forall (m :: * -> *) a. Monad m => a -> m a
return (Operands (Tup3 Int Int Int)
 -> CodeGen PTX (Operands (Tup3 Int Int Int)))
-> Operands (Tup3 Int Int Int)
-> CodeGen PTX (Operands (Tup3 Int Int Int))
forall a b. (a -> b) -> a -> b
$ Operands Int
-> Operands Int -> Operands Int -> Operands (Tup3 Int Int Int)
forall a b c.
Operands a -> Operands b -> Operands c -> Operands (Tup3 a b c)
A.trip Operands Int
n' Operands Int
i' Operands Int
j')
        Operands (Tup3 Int Int Int)
r

    CodeGen PTX ()
forall arch. HasCallStack => CodeGen arch ()
return_


-- Multidimensional scan' along the innermost dimension
--
-- A thread block individually computes along each innermost dimension. This is
-- a single-pass operation.
--
--  * We can assume that the array is non-empty; exclusive scans with an empty
--    innermost dimension will instead be filled with the seed element via
--    'mkScan'Fill'.
--
--  * Small but non-empty innermost dimension arrays (size << thread
--    block size) will have many threads which do no work.
--
mkScan'Dim
    :: forall aenv sh e.
       Direction
    -> Gamma          aenv                          -- ^ array environment
    -> ArrayR (Array (sh, Int) e)
    -> IRFun2     PTX aenv (e -> e -> e)            -- ^ combination function
    -> IRExp      PTX aenv e                        -- ^ seed element
    -> MIRDelayed PTX aenv (Array (sh, Int) e)      -- ^ input data
    -> CodeGen    PTX      (IROpenAcc PTX aenv (Array (sh, Int) e, Array sh e))
mkScan'Dim :: Direction
-> Gamma aenv
-> ArrayR (Array (sh, Int) e)
-> IRFun2 PTX aenv (e -> e -> e)
-> IRExp PTX aenv e
-> MIRDelayed PTX aenv (Array (sh, Int) e)
-> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e, Array sh e))
mkScan'Dim Direction
dir Gamma aenv
aenv repr :: ArrayR (Array (sh, Int) e)
repr@(ArrayR (ShapeRsnoc ShapeR sh1
shr) TypeR e
tp) IRFun2 PTX aenv (e -> e -> e)
combine IRExp PTX aenv e
seed MIRDelayed PTX aenv (Array (sh, Int) e)
marr = do
  DeviceProperties
dev <- LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties
forall arch a. LLVM arch a -> CodeGen arch a
liftCodeGen (LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties)
-> LLVM PTX DeviceProperties -> CodeGen PTX DeviceProperties
forall a b. (a -> b) -> a -> b
$ (PTX -> DeviceProperties) -> LLVM PTX DeviceProperties
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets PTX -> DeviceProperties
ptxDeviceProperties
  --
  let
      (IRArray (Array sh e)
arrSum, [Parameter]
paramSum)  = ArrayR (Array sh e)
-> Name (Array sh e) -> (IRArray (Array sh e), [Parameter])
forall sh e.
ArrayR (Array sh e)
-> Name (Array sh e) -> (IRArray (Array sh e), [Parameter])
mutableArray (ArrayR (Array (sh, Int) e) -> ArrayR (Array sh e)
forall sh e. ArrayR (Array (sh, Int) e) -> ArrayR (Array sh e)
reduceRank ArrayR (Array (sh, Int) e)
repr) Name (Array sh e)
"sum"
      (IRArray (Array (sh, Int) e)
arrOut, [Parameter]
paramOut)  = ArrayR (Array (sh, Int) e)
-> Name (Array (sh, Int) e)
-> (IRArray (Array (sh, Int) e), [Parameter])
forall sh e.
ArrayR (Array sh e)
-> Name (Array sh e) -> (IRArray (Array sh e), [Parameter])
mutableArray ArrayR (Array (sh, Int) e)
repr Name (Array (sh, Int) e)
"out"
      (IRDelayed PTX aenv (Array (sh, Int) e)
arrIn,  [Parameter]
paramIn)   = Name (Array (sh, Int) e)
-> MIRDelayed PTX aenv (Array (sh, Int) e)
-> (IRDelayed PTX aenv (Array (sh, Int) e), [Parameter])
forall sh e arch aenv.
Name (Array sh e)
-> MIRDelayed arch aenv (Array sh e)
-> (IRDelayed arch aenv (Array sh e), [Parameter])
delayedArray Name (Array (sh, Int) e)
"in" MIRDelayed PTX aenv (Array (sh, Int) e)
marr
      paramEnv :: [Parameter]
paramEnv            = Gamma aenv -> [Parameter]
forall aenv. Gamma aenv -> [Parameter]
envParam Gamma aenv
aenv
      --
      config :: LaunchConfig
config              = DeviceProperties
-> [Int]
-> (Int -> Int)
-> (Int -> Int -> Int)
-> Q (TExp (Int -> Int -> Int))
-> LaunchConfig
launchConfig DeviceProperties
dev (DeviceProperties -> [Int]
CUDA.incWarp DeviceProperties
dev) Int -> Int
smem Int -> Int -> Int
forall a b. a -> b -> a
const [|| const ||]
      smem :: Int -> Int
smem Int
n              = Int
warps Int -> Int -> Int
forall a. Num a => a -> a -> a
* (Int
1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
per_warp) Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
bytes
        where
          ws :: Int
ws        = DeviceProperties -> Int
CUDA.warpSize DeviceProperties
dev
          warps :: Int
warps     = Int
n Int -> Int -> Int
forall a. Integral a => a -> a -> a
`P.quot` Int
ws
          per_warp :: Int
per_warp  = Int
ws Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
ws Int -> Int -> Int
forall a. Integral a => a -> a -> a
`P.quot` Int
2
          bytes :: Int
bytes     = TypeR e -> Int
forall e. TypeR e -> Int
bytesElt TypeR e
tp
  --
  LaunchConfig
-> Label
-> [Parameter]
-> CodeGen PTX ()
-> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e, Array sh e))
forall aenv a.
LaunchConfig
-> Label
-> [Parameter]
-> CodeGen PTX ()
-> CodeGen PTX (IROpenAcc PTX aenv a)
makeOpenAccWith LaunchConfig
config Label
"scan" ([Parameter]
paramOut [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramSum [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramIn [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramEnv) (CodeGen PTX ()
 -> CodeGen
      PTX (IROpenAcc PTX aenv (Array (sh, Int) e, Array sh e)))
-> CodeGen PTX ()
-> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e, Array sh e))
forall a b. (a -> b) -> a -> b
$ do

    -- The first and last threads of the block need to communicate the
    -- block-wide aggregate as a carry-in value across iterations.
    --
    -- TODO: we could optimise this a bit if we can get access to the shared
    -- memory area used by 'scanBlockSMem', and from there directly read the
    -- value computed by the last thread.
    IRArray (Vector e)
carry <- TypeR e -> Word64 -> CodeGen PTX (IRArray (Vector e))
forall e. TypeR e -> Word64 -> CodeGen PTX (IRArray (Vector e))
staticSharedMem TypeR e
tp Word64
1

    -- Size of the input array
    Operands Int
sz    <- Operands (sh, Int) -> Operands Int
forall sh sz. Operands (sh, sz) -> Operands sz
indexHead (Operands (sh, Int) -> Operands Int)
-> CodeGen PTX (Operands (sh, Int)) -> CodeGen PTX (Operands Int)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IRDelayed PTX aenv (Array (sh, Int) e)
-> CodeGen PTX (Operands (sh, Int))
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRExp arch aenv sh
delayedExtent IRDelayed PTX aenv (Array (sh, Int) e)
arrIn

    -- If the innermost dimension is smaller than the number of threads in the
    -- block, those threads will never contribute to the output.
    Operands Int32
tid   <- CodeGen PTX (Operands Int32)
threadIdx
    Operands Int
tid'  <- Operands Int32 -> CodeGen PTX (Operands Int)
int Operands Int32
tid
    CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lte SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
tid' Operands Int
sz) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ do

      -- Thread blocks iterate over the outer dimensions, each thread block
      -- cooperatively scanning the innermost dimension at each outer index.
      Operands Int
bid <- Operands Int32 -> CodeGen PTX (Operands Int)
int (Operands Int32 -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int32) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< CodeGen PTX (Operands Int32)
blockIdx
      Operands Int
gd  <- Operands Int32 -> CodeGen PTX (Operands Int)
int (Operands Int32 -> CodeGen PTX (Operands Int))
-> CodeGen PTX (Operands Int32) -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< CodeGen PTX (Operands Int32)
gridDim
      Operands Int
end <- ShapeR sh1 -> Operands sh1 -> CodeGen PTX (Operands Int)
forall sh arch.
ShapeR sh -> Operands sh -> CodeGen arch (Operands Int)
shapeSize ShapeR sh1
shr (IRArray (Array sh e) -> Operands sh
forall sh e. IRArray (Array sh e) -> Operands sh
irArrayShape IRArray (Array sh e)
arrSum)

      Operands Int
-> Operands Int
-> Operands Int
-> (Operands Int -> CodeGen PTX ())
-> CodeGen PTX ()
forall i arch.
IsNum i =>
Operands i
-> Operands i
-> Operands i
-> (Operands i -> CodeGen arch ())
-> CodeGen arch ()
imapFromStepTo Operands Int
bid Operands Int
gd Operands Int
end ((Operands Int -> CodeGen PTX ()) -> CodeGen PTX ())
-> (Operands Int -> CodeGen PTX ()) -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ \Operands Int
seg -> do

        -- Not necessary to wait for threads to catch up before starting this segment
        -- __syncthreads

        -- Linear index bounds for this segment
        Operands Int
inf <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.mul NumType Int
forall a. IsNum a => NumType a
numType Operands Int
seg Operands Int
sz
        Operands Int
sup <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
inf Operands Int
sz

        -- Index that this thread will read from. Recall that the supremum index
        -- is exclusive.
        Operands Int
i0  <- case Direction
dir of
                 Direction
LeftToRight -> NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
inf Operands Int
tid'
                 Direction
RightToLeft -> do Operands Int
x <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
sup Operands Int
tid'
                                   Operands Int
y <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
x (Int -> Operands Int
liftInt Int
1)
                                   Operands Int -> CodeGen PTX (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int
y

        -- The index that this thread will write to. This is just shifted along
        -- by one to make room for the initial element.
        Operands Int
j0  <- case Direction
dir of
                 Direction
LeftToRight -> NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
i0 (Int -> Operands Int
liftInt Int
1)
                 Direction
RightToLeft -> NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
i0 (Int -> Operands Int
liftInt Int
1)

        -- Evaluate the initial element. Store it into the carry-in slot as well
        -- as to the array as the first element. This is always valid because if
        -- the input array is empty then we will be evaluating via mkScan'Fill.
        CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
tid (Int32 -> Operands Int32
liftInt32 Int32
0)) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ do
          Operands e
z <- IRExp PTX aenv e
seed
          IntegralType Int
-> IRArray (Array (sh, Int) e)
-> Operands Int
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt   IRArray (Array (sh, Int) e)
arrOut Operands Int
i0            Operands e
z
          IntegralType Int32
-> IRArray (Vector e)
-> Operands Int32
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int32
TypeInt32 IRArray (Vector e)
carry  (Int32 -> Operands Int32
liftInt32 Int32
0) Operands e
Operands e
z

        Operands Int32
bd  <- CodeGen PTX (Operands Int32)
blockDim
        Operands Int
bd' <- Operands Int32 -> CodeGen PTX (Operands Int)
int Operands Int32
bd
        let next :: Operands Int -> CodeGen PTX (Operands Int)
next Operands Int
ix = case Direction
dir of
                        Direction
LeftToRight -> NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
ix Operands Int
bd'
                        Direction
RightToLeft -> NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
ix Operands Int
bd'

        -- Now, threads iterate over the elements along the innermost dimension.
        -- At each iteration the first thread incorporates the carry-in value
        -- from the previous step.
        --
        -- The index tracks how many elements remain for the thread block, since
        -- indices i* and j* are local to each thread
        Operands Int
n0  <- NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
sup Operands Int
inf
        CodeGen PTX (Operands (Tup3 Int Int Int)) -> CodeGen PTX ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (CodeGen PTX (Operands (Tup3 Int Int Int)) -> CodeGen PTX ())
-> CodeGen PTX (Operands (Tup3 Int Int Int)) -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ TypeR (Tup3 Int Int Int)
-> (Operands (Tup3 Int Int Int) -> CodeGen PTX (Operands Bool))
-> (Operands (Tup3 Int Int Int)
    -> CodeGen PTX (Operands (Tup3 Int Int Int)))
-> Operands (Tup3 Int Int Int)
-> CodeGen PTX (Operands (Tup3 Int Int Int))
forall a arch.
TypeR a
-> (Operands a -> CodeGen arch (Operands Bool))
-> (Operands a -> CodeGen arch (Operands a))
-> Operands a
-> CodeGen arch (Operands a)
while
          (TupR ScalarType ()
forall (s :: * -> *). TupR s ()
TupRunit TupR ScalarType () -> TypeR Int -> TupR ScalarType DIM1
forall (s :: * -> *) a1 b. TupR s a1 -> TupR s b -> TupR s (a1, b)
`TupRpair` ScalarType Int -> TypeR Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt TupR ScalarType DIM1 -> TypeR Int -> TupR ScalarType (DIM1, Int)
forall (s :: * -> *) a1 b. TupR s a1 -> TupR s b -> TupR s (a1, b)
`TupRpair` ScalarType Int -> TypeR Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt TupR ScalarType (DIM1, Int)
-> TypeR Int -> TypeR (Tup3 Int Int Int)
forall (s :: * -> *) a1 b. TupR s a1 -> TupR s b -> TupR s (a1, b)
`TupRpair` ScalarType Int -> TypeR Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt)
          (\(Operands (Tup3 Int Int Int) -> Operands Int
forall a b c. Operands (Tup3 a b c) -> Operands a
A.fst3   -> Operands Int
n)       -> SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.gt SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
n (Int -> Operands Int
liftInt Int
0))
          (\(Operands (Tup3 Int Int Int)
-> (Operands Int, Operands Int, Operands Int)
forall a b c.
Operands (Tup3 a b c) -> (Operands a, Operands b, Operands c)
A.untrip -> (Operands Int
n,Operands Int
i,Operands Int
j)) -> do

            -- Wait for threads to catch up to ensure the carry-in value from
            -- the last iteration has been updated
            CodeGen PTX ()
__syncthreads

            -- If all threads in the block will participate this round we can
            -- avoid (almost) all bounds checks.
            Operands ()
_ <- if (TupR ScalarType ()
forall (s :: * -> *). TupR s ()
TupRunit, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.gte SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
n Operands Int
bd')
                    -- All threads participate. No bounds checks required but
                    -- the last thread needs to update the carry-in value.
                    then do
                      Operands e
x <- IROpenFun1 PTX () aenv (Int -> e)
-> Operands Int -> IRExp PTX aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed PTX aenv (Array (sh, Int) e)
-> IROpenFun1 PTX () aenv (Int -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed PTX aenv (Array (sh, Int) e)
arrIn) Operands Int
i
                      Operands e
y <- if (TypeR e
tp, SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
tid (Int32 -> Operands Int32
liftInt32 Int32
0))
                              then do
                                Operands e
c <- IntegralType Int32
-> IRArray (Vector e) -> Operands Int32 -> CodeGen PTX (Operands e)
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> CodeGen arch (Operands e)
readArray IntegralType Int32
TypeInt32 IRArray (Vector e)
carry (Int32 -> Operands Int32
liftInt32 Int32
0)
                                case Direction
dir of
                                  Direction
LeftToRight -> IRFun2 PTX aenv (e -> e -> e)
-> Operands e -> Operands e -> IRExp PTX aenv e
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 PTX aenv (e -> e -> e)
combine Operands e
Operands e
c Operands e
x
                                  Direction
RightToLeft -> IRFun2 PTX aenv (e -> e -> e)
-> Operands e -> Operands e -> IRExp PTX aenv e
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 PTX aenv (e -> e -> e)
combine Operands e
x Operands e
Operands e
c
                              else
                                Operands e -> IRExp PTX aenv e
forall (m :: * -> *) a. Monad m => a -> m a
return Operands e
x
                      Operands e
z <- Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
scanBlockSMem Direction
dir DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine Maybe (Operands Int32)
forall a. Maybe a
Nothing Operands e
y

                      -- Write results to the output array. Note that if we
                      -- align directly on the boundary of the array this is not
                      -- valid for the last thread.
                      case Direction
dir of
                        Direction
LeftToRight -> CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt  SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
j Operands Int
sup) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ IntegralType Int
-> IRArray (Array (sh, Int) e)
-> Operands Int
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Array (sh, Int) e)
arrOut Operands Int
j Operands e
Operands e
z
                        Direction
RightToLeft -> CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.gte SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
j Operands Int
inf) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ IntegralType Int
-> IRArray (Array (sh, Int) e)
-> Operands Int
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Array (sh, Int) e)
arrOut Operands Int
j Operands e
Operands e
z

                      -- Last thread of the block also saves its result as the
                      -- carry-in value
                      Operands Int32
bd1 <- NumType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Int32)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int32
forall a. IsNum a => NumType a
numType Operands Int32
bd (Int32 -> Operands Int32
liftInt32 Int32
1)
                      CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
tid Operands Int32
bd1) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$
                        IntegralType Int32
-> IRArray (Vector e)
-> Operands Int32
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int32
TypeInt32 IRArray (Vector e)
carry (Int32 -> Operands Int32
liftInt32 Int32
0) Operands e
z

                      Operands () -> CodeGen PTX (Operands ())
forall (m :: * -> *) a. Monad m => a -> m a
return (TupR ScalarType () -> () -> Operands ()
forall a. TypeR a -> a -> Operands a
lift TupR ScalarType ()
forall (s :: * -> *). TupR s ()
TupRunit ())

                    -- Only threads that are in bounds can participate. This is
                    -- the last iteration of the loop. The last active thread
                    -- still needs to store its value into the carry-in slot.
                    --
                    -- Note that all threads must call the block-wide scan.
                    -- SEE: [Synchronisation problems with SM_70 and greater]
                    else do
                      Operands e
x <- if (TypeR e
tp, SingleType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
tid' Operands Int
n)
                              then do
                                Operands e
x <- IROpenFun1 PTX () aenv (Int -> e)
-> Operands Int -> IRExp PTX aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed PTX aenv (Array (sh, Int) e)
-> IROpenFun1 PTX () aenv (Int -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed PTX aenv (Array (sh, Int) e)
arrIn) Operands Int
i
                                Operands e
y <- if (TypeR e
tp, SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
tid (Int32 -> Operands Int32
liftInt32 Int32
0))
                                        then do
                                          Operands e
c <- IntegralType Int32
-> IRArray (Vector e) -> Operands Int32 -> CodeGen PTX (Operands e)
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> CodeGen arch (Operands e)
readArray IntegralType Int32
TypeInt32 IRArray (Vector e)
carry (Int32 -> Operands Int32
liftInt32 Int32
0)
                                          case Direction
dir of
                                            Direction
LeftToRight -> IRFun2 PTX aenv (e -> e -> e)
-> Operands e -> Operands e -> IRExp PTX aenv e
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 PTX aenv (e -> e -> e)
combine Operands e
Operands e
c Operands e
x
                                            Direction
RightToLeft -> IRFun2 PTX aenv (e -> e -> e)
-> Operands e -> Operands e -> IRExp PTX aenv e
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 PTX aenv (e -> e -> e)
combine Operands e
x Operands e
Operands e
c
                                        else
                                          Operands e -> IRExp PTX aenv e
forall (m :: * -> *) a. Monad m => a -> m a
return Operands e
x
                                Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands e
y
                              else
                                let
                                    go :: TypeR a -> Operands a
                                    go :: TypeR a -> Operands a
go TypeR a
TupRunit       = Operands a
Operands ()
OP_Unit
                                    go (TupRpair TupR ScalarType a1
a TupR ScalarType b
b) = Operands a1 -> Operands b -> Operands (a1, b)
forall a b. Operands a -> Operands b -> Operands (a, b)
OP_Pair (TupR ScalarType a1 -> Operands a1
forall a. TypeR a -> Operands a
go TupR ScalarType a1
a) (TupR ScalarType b -> Operands b
forall a. TypeR a -> Operands a
go TupR ScalarType b
b)
                                    go (TupRsingle ScalarType a
t) = ScalarType a -> Operand a -> Operands a
forall (dict :: * -> *) a.
(IROP dict, HasCallStack) =>
dict a -> Operand a -> Operands a
ir ScalarType a
t (ScalarType a -> Operand a
forall a. ScalarType a -> Operand a
undef ScalarType a
t)
                                in
                                Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return (Operands e -> CodeGen PTX (Operands e))
-> Operands e -> CodeGen PTX (Operands e)
forall a b. (a -> b) -> a -> b
$ TypeR e -> Operands e
forall a. TypeR a -> Operands a
go TypeR e
tp

                      Operands Int32
l <- Operands Int -> CodeGen PTX (Operands Int32)
i32 Operands Int
n
                      Operands e
y <- Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
scanBlockSMem Direction
dir DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
IRFun2 PTX aenv (e -> e -> e)
combine (Operands Int32 -> Maybe (Operands Int32)
forall a. a -> Maybe a
Just Operands Int32
l) Operands e
x

                      Operands Int32
m <- NumType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Int32)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int32
forall a. IsNum a => NumType a
numType Operands Int32
l (Int32 -> Operands Int32
liftInt32 Int32
1)
                      CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
tid Operands Int32
m) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ IntegralType Int
-> IRArray (Array (sh, Int) e)
-> Operands Int
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt   IRArray (Array (sh, Int) e)
arrOut Operands Int
j            Operands e
Operands e
y
                      CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
tid Operands Int32
m) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ IntegralType Int32
-> IRArray (Vector e)
-> Operands Int32
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int32
TypeInt32 IRArray (Vector e)
carry (Int32 -> Operands Int32
liftInt32 Int32
0) Operands e
y

                      Operands () -> CodeGen PTX (Operands ())
forall (m :: * -> *) a. Monad m => a -> m a
return (TupR ScalarType () -> () -> Operands ()
forall a. TypeR a -> a -> Operands a
lift TupR ScalarType ()
forall (s :: * -> *). TupR s ()
TupRunit ())

            Operands Int
-> Operands Int -> Operands Int -> Operands (Tup3 Int Int Int)
forall a b c.
Operands a -> Operands b -> Operands c -> Operands (Tup3 a b c)
A.trip (Operands Int
 -> Operands Int -> Operands Int -> Operands (Tup3 Int Int Int))
-> CodeGen PTX (Operands Int)
-> CodeGen
     PTX (Operands Int -> Operands Int -> Operands (Tup3 Int Int Int))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> NumType Int
-> Operands Int -> Operands Int -> CodeGen PTX (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
n Operands Int
bd' CodeGen
  PTX (Operands Int -> Operands Int -> Operands (Tup3 Int Int Int))
-> CodeGen PTX (Operands Int)
-> CodeGen PTX (Operands Int -> Operands (Tup3 Int Int Int))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Operands Int -> CodeGen PTX (Operands Int)
next Operands Int
i CodeGen PTX (Operands Int -> Operands (Tup3 Int Int Int))
-> CodeGen PTX (Operands Int)
-> CodeGen PTX (Operands (Tup3 Int Int Int))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Operands Int -> CodeGen PTX (Operands Int)
next Operands Int
j)
          (Operands Int
-> Operands Int -> Operands Int -> Operands (Tup3 Int Int Int)
forall a b c.
Operands a -> Operands b -> Operands c -> Operands (Tup3 a b c)
A.trip Operands Int
n0 Operands Int
i0 Operands Int
j0)

        -- Wait for the carry-in value to be updated
        CodeGen PTX ()
__syncthreads

        -- Store the carry-in value to the separate final results array
        CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
tid (Int32 -> Operands Int32
liftInt32 Int32
0)) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$
          IntegralType Int
-> IRArray (Array sh e)
-> Operands Int
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Array sh e)
arrSum Operands Int
seg (Operands e -> CodeGen PTX ())
-> IRExp PTX aenv e -> CodeGen PTX ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< IntegralType Int32
-> IRArray (Vector e) -> Operands Int32 -> CodeGen PTX (Operands e)
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> CodeGen arch (Operands e)
readArray IntegralType Int32
TypeInt32 IRArray (Vector e)
carry (Int32 -> Operands Int32
liftInt32 Int32
0)

    CodeGen PTX ()
forall arch. HasCallStack => CodeGen arch ()
return_



-- Parallel scan, auxiliary
--
-- If this is an exclusive scan of an empty array, we just fill the result with
-- the seed element.
--
-- This is expressed as a 'mkGenerate' kernel over the given array
-- representation, whose element function ignores the index it is given and
-- always evaluates the seed expression.
--
mkScanFill
    :: Gamma aenv                                   -- ^ environment handed on to 'mkGenerate'
    -> ArrayR (Array sh e)                          -- ^ representation of the array to fill
    -> IRExp PTX aenv e                             -- ^ seed element
    -> CodeGen PTX (IROpenAcc PTX aenv (Array sh e))
mkScanFill aenv repr seed =
  mkGenerate aenv repr (IRFun1 (const seed))

-- As 'mkScanFill', but for the scan variant that yields a pair of arrays (the
-- scanned array together with a separate array of final results).
--
-- Only the reduced-rank array (shape @sh@, obtained via 'reduceRank') is
-- generated and filled with the seed element; the kernel is then re-typed to
-- the pair result with 'Safe.coerce'.
--
-- NOTE(review): the coercion presumably relies on the array type index of
-- 'IROpenAcc' being phantom, and on the innermost-dimension array of the pair
-- being empty so that nothing has to be written to it -- confirm against the
-- definition of 'IROpenAcc' and the call site.
--
mkScan'Fill
    :: Gamma aenv                                   -- ^ environment handed on to 'mkGenerate'
    -> ArrayR (Array (sh, Int) e)                   -- ^ representation of the full-rank array
    -> IRExp PTX aenv e                             -- ^ seed element
    -> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e, Array sh e))
mkScan'Fill aenv repr seed =
  Safe.coerce <$> mkGenerate aenv (reduceRank repr) (IRFun1 (const seed))


-- Block wide scan
-- ---------------

-- Efficient block-wide (inclusive) scan using the specified operator.
--
-- Each block requires (#warps * (1 + 1.5*warp size)) elements of dynamically
-- allocated shared memory.
--
-- Example: https://github.com/NVlabs/cub/blob/1.5.4/cub/block/specializations/block_scan_warp_scans.cuh
--
-- NOTE: [Synchronisation problems with SM_70 and greater]
--
-- This operation uses thread synchronisation. When calling this operation, it
-- is important that all active (that is, non-exited) threads of the thread
-- block participate. It seems that sm_70+ (devices with independent thread
-- scheduling) are stricter about the requirement that all non-exited threads
-- participate in every barrier.
--
-- See: https://github.com/AccelerateHS/accelerate/issues/436
--
scanBlockSMem
    :: forall aenv e.
       Direction
    -> DeviceProperties                             -- ^ properties of the target device
    -> TypeR e
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> Maybe (Operands Int32)                       -- ^ number of valid elements (may be less than block size)
    -> Operands e                                   -- ^ calling thread's input element
    -> CodeGen PTX (Operands e)
scanBlockSMem :: Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (Operands Int32)
-> Operands e
-> CodeGen PTX (Operands e)
scanBlockSMem Direction
dir DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
combine Maybe (Operands Int32)
nelem = Operands e -> CodeGen PTX (Operands e)
warpScan (Operands e -> CodeGen PTX (Operands e))
-> (Operands e -> CodeGen PTX (Operands e))
-> Operands e
-> CodeGen PTX (Operands e)
forall (m :: * -> *) a b c.
Monad m =>
(a -> m b) -> (b -> m c) -> a -> m c
>=> Operands e -> CodeGen PTX (Operands e)
warpPrefix
  where
    int32 :: Integral a => a -> Operands Int32
    int32 :: a -> Operands Int32
int32 = Int32 -> Operands Int32
liftInt32 (Int32 -> Operands Int32) -> (a -> Int32) -> a -> Operands Int32
forall b c a. (b -> c) -> (a -> b) -> a -> c
. a -> Int32
forall a b. (Integral a, Num b) => a -> b
P.fromIntegral

    -- Temporary storage required for each warp
    warp_smem_elems :: Int
warp_smem_elems = DeviceProperties -> Int
CUDA.warpSize DeviceProperties
dev Int -> Int -> Int
forall a. Num a => a -> a -> a
+ (DeviceProperties -> Int
CUDA.warpSize DeviceProperties
dev Int -> Int -> Int
forall a. Integral a => a -> a -> a
`P.quot` Int
2)
    warp_smem_bytes :: Int
warp_smem_bytes = Int
warp_smem_elems  Int -> Int -> Int
forall a. Num a => a -> a -> a
* TypeR e -> Int
forall e. TypeR e -> Int
bytesElt TypeR e
tp

    -- Step 1: Scan in every warp
    warpScan :: Operands e -> CodeGen PTX (Operands e)
    warpScan :: Operands e -> CodeGen PTX (Operands e)
warpScan Operands e
input = do
      -- Allocate (1.5 * warpSize) elements of shared memory for each warp
      -- (individually addressable by each warp)
      Operands Int32
wid   <- CodeGen PTX (Operands Int32)
warpId
      Operands Int32
skip  <- NumType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Int32)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.mul NumType Int32
forall a. IsNum a => NumType a
numType Operands Int32
wid (Int -> Operands Int32
forall a. Integral a => a -> Operands Int32
int32 Int
warp_smem_bytes)
      IRArray (Vector e)
smem  <- TypeR e
-> IntegralType Int32
-> Operands Int32
-> Operands Int32
-> CodeGen PTX (IRArray (Vector e))
forall e int.
TypeR e
-> IntegralType int
-> Operands int
-> Operands int
-> CodeGen PTX (IRArray (Vector e))
dynamicSharedMem TypeR e
tp IntegralType Int32
TypeInt32 (Int -> Operands Int32
forall a. Integral a => a -> Operands Int32
int32 Int
warp_smem_elems) Operands Int32
skip
      Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> IRArray (Vector e)
-> Operands e
-> CodeGen PTX (Operands e)
forall aenv e.
Direction
-> DeviceProperties
-> TypeR e
-> IRFun2 PTX aenv (e -> e -> e)
-> IRArray (Vector e)
-> Operands e
-> CodeGen PTX (Operands e)
scanWarpSMem Direction
dir DeviceProperties
dev TypeR e
tp IRFun2 PTX aenv (e -> e -> e)
combine IRArray (Vector e)
smem Operands e
input

    -- Step 2: Collect the aggregate results of each warp to compute the prefix
    -- values for each warp and combine with the partial result to compute each
    -- thread's final value.
    warpPrefix :: Operands e -> CodeGen PTX (Operands e)
    warpPrefix :: Operands e -> CodeGen PTX (Operands e)
warpPrefix Operands e
input = do
      -- Allocate #warps elements of shared memory
      Operands Int32
bd    <- CodeGen PTX (Operands Int32)
blockDim
      Operands Int32
warps <- IntegralType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Int32)
forall a arch.
IntegralType a
-> Operands a -> Operands a -> CodeGen arch (Operands a)
A.quot IntegralType Int32
forall a. IsIntegral a => IntegralType a
integralType Operands Int32
bd (Int -> Operands Int32
forall a. Integral a => a -> Operands Int32
int32 (DeviceProperties -> Int
CUDA.warpSize DeviceProperties
dev))
      Operands Int32
skip  <- NumType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Int32)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.mul NumType Int32
forall a. IsNum a => NumType a
numType Operands Int32
warps (Int -> Operands Int32
forall a. Integral a => a -> Operands Int32
int32 Int
warp_smem_bytes)
      IRArray (Vector e)
smem  <- TypeR e
-> IntegralType Int32
-> Operands Int32
-> Operands Int32
-> CodeGen PTX (IRArray (Vector e))
forall e int.
TypeR e
-> IntegralType int
-> Operands int
-> Operands int
-> CodeGen PTX (IRArray (Vector e))
dynamicSharedMem TypeR e
tp IntegralType Int32
TypeInt32 Operands Int32
warps Operands Int32
skip

      -- Share warp aggregates
      Operands Int32
wid   <- CodeGen PTX (Operands Int32)
warpId
      Operands Int32
lane  <- CodeGen PTX (Operands Int32)
laneId
      CodeGen PTX (Operands Bool) -> CodeGen PTX () -> CodeGen PTX ()
forall arch.
CodeGen arch (Operands Bool) -> CodeGen arch () -> CodeGen arch ()
when (SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
lane (Int -> Operands Int32
forall a. Integral a => a -> Operands Int32
int32 (DeviceProperties -> Int
CUDA.warpSize DeviceProperties
dev Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1))) (CodeGen PTX () -> CodeGen PTX ())
-> CodeGen PTX () -> CodeGen PTX ()
forall a b. (a -> b) -> a -> b
$ do
        IntegralType Int32
-> IRArray (Vector e)
-> Operands Int32
-> Operands e
-> CodeGen PTX ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int32
TypeInt32 IRArray (Vector e)
smem Operands Int32
wid Operands e
input

      -- Wait for each warp to finish its local scan and share the aggregate
      CodeGen PTX ()
__syncthreads

      -- Compute the prefix value for this warp and add to the partial result.
      -- This step is not required for the first warp, which has no carry-in.
      if (TypeR e
tp, SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
wid (Int32 -> Operands Int32
liftInt32 Int32
0))
        then Operands e -> CodeGen PTX (Operands e)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands e
input
        else do
          -- Every thread sequentially scans the warp aggregates to compute
          -- their prefix value. We do this sequentially, but could also have
          -- warp 0 do it cooperatively if we limit thread block sizes to
          -- (warp size ^ 2).
          Operands Int32
steps  <- case Maybe (Operands Int32)
nelem of
                      Maybe (Operands Int32)
Nothing -> Operands Int32 -> CodeGen PTX (Operands Int32)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int32
wid
                      Just Operands Int32
n  -> SingleType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Int32)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands a)
A.min SingleType Int32
forall a. IsSingle a => SingleType a
singleType Operands Int32
wid (Operands Int32 -> CodeGen PTX (Operands Int32))
-> CodeGen PTX (Operands Int32) -> CodeGen PTX (Operands Int32)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< IntegralType Int32
-> Operands Int32 -> Operands Int32 -> CodeGen PTX (Operands Int32)
forall a arch.
IntegralType a
-> Operands a -> Operands a -> CodeGen arch (Operands a)
A.quot IntegralType Int32
forall a. IsIntegral a => IntegralType a
integralType Operands Int32
n (Int -> Operands Int32
forall a. Integral a => a -> Operands Int32
int32 (DeviceProperties -> Int
CUDA.warpSize DeviceProperties
dev))

          Operands e
p0     <- IntegralType Int32
-> IRArray (Vector e) -> Operands Int32 -> CodeGen PTX (Operands e)
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> CodeGen arch (Operands e)
readArray IntegralType Int32
TypeInt32 IRArray (Vector e)
smem (Int32 -> Operands Int32
liftInt32 Int32
0)
          Operands e
prefix <- TypeR e
-> Operands Int32
-> Operands Int32
-> Operands Int32
-> Operands e
-> (Operands Int32 -> Operands e -> CodeGen PTX (Operands e))
-> CodeGen PTX (Operands e)
forall i a arch.
IsNum i =>
TypeR a
-> Operands i
-> Operands i
-> Operands i
-> Operands a
-> (Operands i -> Operands a -> CodeGen arch (Operands a))
-> CodeGen arch (Operands a)
iterFromStepTo TypeR e
tp (Int32 -> Operands Int32
liftInt32 Int32
1) (Int32 -> Operands Int32
liftInt32 Int32
1) Operands Int32
steps Operands e
p0 ((Operands Int32 -> Operands e -> CodeGen PTX (Operands e))
 -> CodeGen PTX (Operands e))
-> (Operands Int32 -> Operands e -> CodeGen PTX (Operands e))
-> CodeGen PTX (Operands e)
forall a b. (a -> b) -> a -> b
$ \Operands Int32
step Operands e
x -> do
                      Operands e
y <- IntegralType Int32
-> IRArray (Vector e) -> Operands Int32 -> CodeGen PTX (Operands e)
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> CodeGen arch (Operands e)
readArray IntegralType Int32
TypeInt32 IRArray (Vector e)
smem Operands Int32
step
                      case Direction
dir of
                        Direction
LeftToRight -> IRFun2 PTX aenv (e -> e -> e)
-> Operands e -> Operands e -> CodeGen PTX (Operands e)
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 PTX aenv (e -> e -> e)
combine Operands e
x Operands e
y
                        Direction
RightToLeft -> IRFun2 PTX aenv (e -> e -> e)
-> Operands e -> Operands e -> CodeGen PTX (Operands e)
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 PTX aenv (e -> e -> e)
combine Operands e
y Operands e
x

          case Direction
dir of
            Direction
LeftToRight -> IRFun2 PTX aenv (e -> e -> e)
-> Operands e -> Operands e -> CodeGen PTX (Operands e)
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 PTX aenv (e -> e -> e)
combine Operands e
prefix Operands e
input
            Direction
RightToLeft -> IRFun2 PTX aenv (e -> e -> e)
-> Operands e -> Operands e -> CodeGen PTX (Operands e)
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 PTX aenv (e -> e -> e)
combine Operands e
input Operands e
prefix


-- Warp-wide scan
-- --------------

-- Efficient warp-wide (inclusive) scan using the specified operator.
--
-- Each warp requires 48 (1.5 x warp size) elements of shared memory. The
-- routine assumes that this storage is allocated individually per-warp: every
-- lane publishes its partial result at index (lane + half warp size) of the
-- given array, and reads its neighbour's value from (lane + half warp size -
-- offset), so the array is addressed relative to the warp's own allocation.
--
-- The scan is unrolled at code generation time into floor(log2(warp size))
-- rounds. In round k each lane whose index is at least 2^k combines its
-- partial result with that of the lane 2^k positions before it. The operand
-- order passed to the combination function depends on the scan direction.
--
-- Example: https://github.com/NVlabs/cub/blob/1.5.4/cub/warp/specializations/warp_scan_smem.cuh
--
scanWarpSMem
    :: forall aenv e.
       Direction
    -> DeviceProperties                             -- ^ properties of the target device
    -> TypeR e
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> IRArray (Vector e)                           -- ^ temporary storage array in shared memory (1.5 x warp size elements)
    -> Operands e                                   -- ^ calling thread's input element
    -> CodeGen PTX (Operands e)
scanWarpSMem dir dev tp combine smem = scan 0
  where
    -- Number of threads in a warp on the target device
    lanesPerWarp :: Int
    lanesPerWarp = CUDA.warpSize dev

    -- Number of steps required to scan the warp, i.e. floor (log2 warp size).
    -- Computed by counting how many times the warp size can be halved while
    -- it is still at least two; this is exact integer arithmetic and avoids
    -- going through a floating point logarithm.
    steps :: Int
    steps = P.length (P.takeWhile (>= 2) (P.iterate (`P.quot` 2) lanesPerWarp))

    -- Offset into the shared memory buffer at which lane zero stores its value
    halfWarp :: Int32
    halfWarp = P.fromIntegral (lanesPerWarp `P.quot` 2)

    -- Unfold the scan as a recursive code generation function. The recursion
    -- happens in the code generator, so the emitted code is a straight-line
    -- sequence of 'steps' rounds.
    scan :: Int -> Operands e -> CodeGen PTX (Operands e)
    scan step x
      | step >= steps = return x
      | otherwise     = do
          -- distance (in lanes) to the partner element for this round: 2^step
          let offset = liftInt32 (1 `P.shiftL` step)

          -- share partial result through shared memory buffer
          lane <- laneId
          i    <- A.add numType lane (liftInt32 halfWarp)
          writeArray TypeInt32 smem i x

          -- all lanes must have stored their value before any lane reads
          __syncwarp

          -- update partial result if in range
          x'   <- if (tp, A.gte singleType lane offset)
                    then do
                      j <- A.sub numType i offset     -- lane + HALF_WARP - offset
                      y <- readArray TypeInt32 smem j
                      case dir of
                        LeftToRight -> app2 combine y x
                        RightToLeft -> app2 combine x y

                    else
                      return x

          -- all lanes must have read before the buffer is overwritten by the
          -- next round
          __syncwarp

          scan (step+1) x'


-- Utilities
-- ---------

-- Convert a generated 'Int' operand into an 'Int32' operand, using the
-- code-generating 'A.fromIntegral' conversion.
--
i32 :: Operands Int -> CodeGen PTX (Operands Int32)
i32 x = A.fromIntegral integralType numType x

-- Convert a generated 'Int32' operand into an 'Int' operand, using the
-- code-generating 'A.fromIntegral' conversion.
--
int :: Operands Int32 -> CodeGen PTX (Operands Int)
int x = A.fromIntegral integralType numType x