{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.CodeGen.ImpGen.GPU.Base
  ( KernelConstants (..),
    kernelGlobalThreadId,
    kernelLocalThreadId,
    kernelBlockId,
    threadOperations,
    keyWithEntryPoint,
    CallKernelGen,
    InKernelGen,
    Locks (..),
    HostEnv (..),
    Target (..),
    KernelEnv (..),
    blockReduce,
    blockScan,
    blockLoop,
    isActive,
    sKernel,
    sKernelThread,
    KernelAttrs (..),
    defKernelAttrs,
    lvlKernelAttrs,
    allocLocal,
    compileThreadResult,
    virtualiseBlocks,
    kernelLoop,
    blockCoverSpace,
    fenceForArrays,
    updateAcc,
    genZeroes,
    isPrimParam,
    kernelConstToExp,
    getChunkSize,
    getSize,

    -- * Host-level bulk operations
    sReplicate,
    sIota,

    -- * Atomics
    AtomicBinOp,
    atomicUpdateLocking,
    Locking (..),
    AtomicUpdate (..),
    DoAtomicUpdate,
    writeAtomic,
  )
where

import Control.Monad
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Transform.Rename
import Futhark.Util (dropLast, nubOrd, splitFromEnd)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)

-- | Which target are we ultimately generating code for?  While most
-- of the kernels code is the same, there are some cases where we
-- generate special code based on the ultimate low-level API we are
-- targeting.  The constructors correspond to the GPU backends
-- supported by the compiler.
data Target = CUDA | OpenCL | HIP

-- | Information about the locks available for accumulators.
data Locks = Locks
  { -- | The array containing the locks.
    locksArray :: VName,
    -- | How many locks the array contains.
    locksCount :: Int
  }

-- | The environment carried when generating host-level code.
data HostEnv = HostEnv
  { -- | How to compile atomic updates for binary operators.
    hostAtomics :: AtomicBinOp,
    -- | Which low-level API we are ultimately targeting.
    hostTarget :: Target,
    -- | Locks available for accumulators.
    hostLocks :: M.Map VName Locks
  }

-- | The environment carried when generating code inside a kernel.
data KernelEnv = KernelEnv
  { -- | How to compile atomic updates for binary operators.
    kernelAtomics :: AtomicBinOp,
    -- | Various useful kernel-wide constants and expressions.
    kernelConstants :: KernelConstants,
    -- | Locks available for accumulators.
    kernelLocks :: M.Map VName Locks
  }

-- | Code generation monad for host-level code (uses 'HostEnv' and
-- emits 'Imp.HostOp's such as kernel launches).
type CallKernelGen = ImpM GPUMem HostEnv Imp.HostOp

-- | Code generation monad for code inside a kernel (uses 'KernelEnv'
-- and emits 'Imp.KernelOp's).
type InKernelGen = ImpM GPUMem KernelEnv Imp.KernelOp

-- | Various useful variables and expressions that are available
-- inside a kernel.
data KernelConstants = KernelConstants
  { kernelGlobalThreadIdVar :: TV Int32,
    kernelLocalThreadIdVar :: TV Int32,
    kernelBlockIdVar :: TV Int32,
    kernelNumBlocksCount :: Count NumBlocks SubExp,
    kernelBlockSizeCount :: Count BlockSize SubExp,
    kernelNumBlocks :: Imp.TExp Int64,
    kernelBlockSize :: Imp.TExp Int64,
    kernelNumThreads :: Imp.TExp Int32,
    kernelWaveSize :: Imp.TExp Int32,
    -- | A mapping from dimensions of nested SegOps to already
    -- computed local thread IDs.  Only valid in non-virtualised case.
    kernelLocalIdMap :: M.Map [SubExp] [Imp.TExp Int32],
    -- | Mapping from dimensions of nested SegOps to how many
    -- iterations the virtualisation loop needs.
    kernelChunkItersMap :: M.Map [SubExp] (Imp.TExp Int32)
  }

-- | Accessors for the thread/block identity variables of
-- 'KernelConstants', as typed expressions.
kernelGlobalThreadId, kernelLocalThreadId, kernelBlockId :: KernelConstants -> Imp.TExp Int32
kernelGlobalThreadId = tvExp . kernelGlobalThreadIdVar
kernelLocalThreadId = tvExp . kernelLocalThreadIdVar
kernelBlockId = tvExp . kernelBlockIdVar

-- | Prefix a tuning key with the name of its entry point (if any),
-- separated by a dot, e.g. @entry.key@.
keyWithEntryPoint :: Maybe Name -> Name -> Name
keyWithEntryPoint fname key =
  nameFromString $
    maybe "" ((++ ".") . nameToString) fname ++ nameToString key

-- | Allocate a block of shared (thread-block-local) memory inside a
-- kernel.
allocLocal :: AllocCompiler GPUMem r Imp.KernelOp
allocLocal mem size =
  sOp $ Imp.SharedAlloc mem size

-- | \"Allocate\" memory at thread level.  Only 'ScalarSpace'
-- allocations are supported; anything else is a compiler limitation
-- or an internal error.
threadAlloc ::
  Pat LetDecMem ->
  SubExp ->
  Space ->
  InKernelGen ()
threadAlloc (Pat [_]) _ ScalarSpace {} =
  -- Handled by the declaration of the memory block, which is then
  -- translated to an actual scalar variable during C code generation.
  pure ()
threadAlloc (Pat [mem]) _ _ =
  compilerLimitationS $
    "Cannot allocate memory block " ++ prettyString mem ++ " in kernel thread."
threadAlloc dest _ _ =
  error $ "Invalid target for in-kernel allocation: " ++ show dest

-- | Compile an accumulator update inside a kernel.  With 'Safe'
-- safety, the write is guarded by a bounds check on the index;
-- otherwise it is performed unconditionally.  Updates with a
-- combining operator are performed atomically.
updateAcc :: Safety -> VName -> [SubExp] -> [SubExp] -> InKernelGen ()
updateAcc safety acc is vs = sComment "UpdateAcc" $ do
  -- See the ImpGen implementation of UpdateAcc for general notes.
  let is' = map pe64 is
  (c, space, arrs, dims, op) <- lookupAcc acc is'
  let boundsCheck =
        case safety of
          Safe -> sWhen (inBounds (Slice (map DimFix is')) dims)
          _ -> id
  boundsCheck $
    case op of
      -- No combining operator: the update is a plain write.
      Nothing ->
        forM_ (zip arrs vs) $ \(arr, v) ->
          copyDWIMFix arr is' v []
      -- A combining operator: perform the update atomically.
      Just lam -> do
        dLParams $ lambdaParams lam
        let (_x_params, y_params) =
              splitAt (length vs) $ map paramName $ lambdaParams lam
        -- Bind the second half of the lambda parameters to the
        -- provided values.
        forM_ (zip y_params vs) $ \(yp, v) -> copyDWIM yp [] v []
        atomics <- kernelAtomics <$> askEnv
        case atomicUpdateLocking atomics lam of
          AtomicPrim f -> f space arrs is'
          AtomicCAS f -> f space arrs is'
          AtomicLocking f -> do
            c_locks <- M.lookup c . kernelLocks <$> askEnv
            case c_locks of
              Just (Locks locks num_locks) -> do
                -- Map the flattened index into the lock array.
                let locking =
                      Locking locks 0 1 0 $
                        pure . (`rem` fromIntegral num_locks) . flattenIndex dims
                f locking space arrs is'
              Nothing ->
                error $ "Missing locks for " ++ prettyString acc

-- | Generate a constant device array of 32-bit integer zeroes with
-- the given number of elements.  Initialised with a replicate.
genZeroes :: String -> Int -> CallKernelGen VName
genZeroes desc n = genConstants $ do
  -- Four bytes per 32-bit element.
  counters_mem <- sAlloc (desc <> "_mem") (4 * fromIntegral n) (Space "device")
  let shape = Shape [intConst Int64 (fromIntegral n)]
  counters <- sArrayInMem desc int32 shape counters_mem
  sReplicate counters $ intConst Int32 0
  pure (namesFromList [counters_mem], counters)

-- | Expression compiler for thread-level code.  Handles the cases
-- that need special treatment inside kernels and defers everything
-- else to 'defCompileExp'.
compileThreadExp :: ExpCompiler GPUMem KernelEnv Imp.KernelOp
compileThreadExp (Pat [pe]) (BasicOp (Opaque _ se)) =
  -- Cannot print in GPU code.
  copyDWIM (patElemName pe) [] se []
compileThreadExp (Pat [dest]) (BasicOp (ArrayLit es _)) =
  -- Write each element of the literal individually.
  forM_ (zip [0 ..] es) $ \(i, e) ->
    copyDWIMFix (patElemName dest) [fromIntegral (i :: Int64)] e []
compileThreadExp _ (BasicOp (UpdateAcc safety acc is vs)) =
  updateAcc safety acc is vs
compileThreadExp dest e =
  defCompileExp dest e

-- | Assign iterations of a for-loop to all threads in the kernel.
-- The passed-in function is invoked with the (symbolic) iteration.
-- The body must contain thread-level code.  For multidimensional
-- loops, use 'blockCoverSpace'.
kernelLoop ::
  (IntExp t) =>
  Imp.TExp t ->
  Imp.TExp t ->
  Imp.TExp t ->
  (Imp.TExp t -> InKernelGen ()) ->
  InKernelGen ()
kernelLoop tid num_threads n f =
  localOps threadOperations $
    if n == num_threads
      then f tid
      else do
        -- Each thread handles n `divUp` num_threads iterations,
        -- strided by num_threads, guarded against overshoot.
        num_chunks <- dPrimVE "num_chunks" $ n `divUp` num_threads
        sFor "chunk_i" num_chunks $ \chunk_i -> do
          i <- dPrimVE "i" $ chunk_i * num_threads + tid
          sWhen (i .<. n) $ f i

-- | Assign iterations of a for-loop to threads in the threadblock.  The
-- passed-in function is invoked with the (symbolic) iteration.  For
-- multidimensional loops, use 'blockCoverSpace'.
blockLoop ::
  (IntExp t) =>
  Imp.TExp t ->
  (Imp.TExp t -> InKernelGen ()) ->
  InKernelGen ()
blockLoop n f = do
  constants <- kernelConstants <$> askEnv
  kernelLoop
    (kernelLocalThreadId constants `sExtAs` n)
    (kernelBlockSize constants `sExtAs` n)
    n
    f

-- | Iterate collectively though a multidimensional space, such that
-- all threads in the block participate.  The passed-in function is
-- invoked with a (symbolic) point in the index space.
blockCoverSpace ::
  (IntExp t) =>
  [Imp.TExp t] ->
  ([Imp.TExp t] -> InKernelGen ()) ->
  InKernelGen ()
blockCoverSpace ds f = do
  constants <- kernelConstants <$> askEnv
  let tblock_size = kernelBlockSize constants
  case splitFromEnd 1 ds of
    -- Optimise the case where the inner dimension of the space is
    -- equal to the block size.
    (ds', [last_d])
      | last_d == (tblock_size `sExtAs` last_d) -> do
          let ltid = kernelLocalThreadId constants `sExtAs` last_d
          sLoopSpace ds' $ \ds_is ->
            f $ ds_is ++ [ltid]
    _ ->
      blockLoop (product ds) $ f . unflattenIndex ds

-- Which fence do we need to protect shared access to this memory space?
fenceForSpace :: Space -> Imp.Fence
fenceForSpace (Space "shared") = Imp.FenceLocal
fenceForSpace _ = Imp.FenceGlobal

-- | If we are touching these arrays, which kind of fence do we need?
-- The strongest fence required by any of the arrays' memory spaces.
fenceForArrays :: [VName] -> InKernelGen Imp.Fence
fenceForArrays = fmap (L.foldl' max Imp.FenceLocal) . mapM need
  where
    -- Fence required for a single array, determined by the memory
    -- space of its backing memory block.
    need arr =
      fmap (fenceForSpace . entryMemSpace)
        . lookupMemory
        . memLocName
        . entryArrayLoc
        =<< lookupArray arr

-- | Does this parameter have a primitive (scalar) type?
isPrimParam :: (Typed p) => Param p -> Bool
isPrimParam = primType . paramType

-- | Turn a kernel-constant expression into an ordinary expression by
-- generating host code that reads the relevant sizes into fresh
-- variables.
kernelConstToExp :: Imp.KernelConstExp -> CallKernelGen Imp.Exp
kernelConstToExp = traverse f
  where
    f (Imp.SizeMaxConst c) = do
      v <- dPrimS (prettyString c) int64
      sOp $ Imp.GetSizeMax v c
      pure v
    f (Imp.SizeConst k c) = do
      v <- dPrimS (nameToString k) int64
      sOp $ Imp.GetSize v k c
      pure v

-- | Given available registers and a list of parameter types, compute
-- the largest available chunk size given the parameters for which we
-- want chunking and the available resources. Used in
-- 'SegScan.SinglePass.compileSegScan', and 'SegRed.compileSegRed'
-- (with primitive non-commutative operators only).
getChunkSize :: [Type] -> Imp.KernelConstExp
getChunkSize types =
  let max_tblock_size = Imp.SizeMaxConst SizeThreadBlock
      max_block_mem = Imp.SizeMaxConst SizeSharedMemory
      max_block_reg = Imp.SizeMaxConst SizeRegisters
      -- Per-thread shares of shared memory and registers.
      k_mem = le64 max_block_mem `quot` le64 max_tblock_size
      k_reg = le64 max_block_reg `quot` le64 max_tblock_size
      -- Only primitive-typed parameters participate in chunking.
      types' = map elemType $ filter primType types
      sizes = map primByteSize types'

      sum_sizes = sum sizes
      -- Register estimate: each element occupies at least one
      -- 4-byte register, hence the clamp to 4 and division by 4.
      sum_sizes' = sum (map (sMax64 4 . primByteSize) types') `quot` 4
      max_size = maximum sizes

      mem_constraint = max k_mem sum_sizes `quot` max_size
      reg_constraint = (k_reg - 1 - sum_sizes') `quot` (2 * sum_sizes')
   in -- Chunk size is the tighter of the two constraints, but never
      -- less than one element.
      untyped $ sMax64 1 $ sMin64 mem_constraint reg_constraint

-- | Perform a scan within each chunk of a thread block.  The chunk a
-- thread belongs to is determined by dividing its local thread id by
-- @block_size@.  Uses a doubling-stride ("Hillis-Steele" style) loop,
-- inserting barriers only once the stride exceeds @lockstep_width@
-- (or unconditionally when scanning array-typed values, which go
-- through global memory).
inChunkScan ::
  KernelConstants ->
  Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
  Imp.TExp Int64 ->
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  Imp.TExp Bool ->
  [VName] ->
  InKernelGen () ->
  Lambda GPUMem ->
  InKernelGen ()
inChunkScan constants seg_flag arrs_full_size lockstep_width block_size active arrs barrier scan_lam = everythingVolatile $ do
  skip_threads <- dPrim "skip_threads"
  let actual_params = lambdaParams scan_lam
      (x_params, y_params) =
        splitAt (length actual_params `div` 2) actual_params
      -- Copy primitive y-parameters into the corresponding x slots.
      y_to_x =
        forM_ (zip x_params y_params) $ \(x, y) ->
          when (isPrimParam x) $
            copyDWIM (paramName x) [] (Var (paramName y)) []

  -- Set initial y values
  sComment "read input for in-block scan" $
    sWhen active $ do
      zipWithM_ readInitial y_params arrs
      -- Since the final result is expected to be in x_params, we may
      -- need to copy it there for the first thread in the block.
      sWhen (in_block_id .==. 0) y_to_x

  when array_scan barrier

  let -- Apply the operator, storing results in x_params.  With a
      -- segment flag, threads whose operand crosses a segment
      -- boundary instead just copy y to x.
      op_to_x in_block_thread_active
        | Nothing <- seg_flag =
            localOps threadOperations
              . sWhen in_block_thread_active
              $ compileBody' x_params
              $ lambdaBody scan_lam
        | Just flag_true <- seg_flag = do
            inactive <-
              dPrimVE "inactive" $ flag_true (ltid32 - tvExp skip_threads) ltid32
            sWhen (in_block_thread_active .&&. inactive) $
              forM_ (zip x_params y_params) $ \(x, y) ->
                copyDWIM (paramName x) [] (Var (paramName y)) []
            -- The convoluted control flow is to ensure all threads
            -- hit this barrier (if applicable).
            when array_scan barrier
            localOps threadOperations
              . sWhen in_block_thread_active
              . sUnless inactive
              $ compileBody' x_params
              $ lambdaBody scan_lam

      -- Barrier is only needed once the stride exceeds the width
      -- within which threads execute in lockstep.
      maybeBarrier =
        sWhen
          (lockstep_width .<=. tvExp skip_threads)
          barrier

  sComment "in-block scan (hopefully no barriers needed)" $ do
    skip_threads <-- 1
    sWhile (tvExp skip_threads .<. block_size) $ do
      thread_active <-
        dPrimVE "thread_active" $
          tvExp skip_threads .<=. in_block_id .&&. active

      sWhen thread_active . sComment "read operands" $
        zipWithM_ (readParam (sExt64 $ tvExp skip_threads)) x_params arrs
      sComment "perform operation" $ op_to_x thread_active

      maybeBarrier

      sWhen thread_active . sComment "write result" $
        sequence_ $
          zipWith3 writeResult x_params y_params arrs

      maybeBarrier

      skip_threads <-- tvExp skip_threads * 2
  where
    -- Index of the chunk this thread belongs to, and its position
    -- within that chunk.
    block_id = ltid32 `quot` block_size
    in_block_id = ltid32 - block_id * block_size
    ltid32 = kernelLocalThreadId constants
    ltid = sExt64 ltid32
    gtid = sExt64 $ kernelGlobalThreadId constants
    -- Non-primitive results force the array-scan (global memory) path.
    array_scan = not $ all primType $ lambdaReturnType scan_lam

    -- Primitive operands live at the local index; array operands at
    -- the global index.
    readInitial p arr
      | isPrimParam p =
          copyDWIMFix (paramName p) [] (Var arr) [ltid]
      | otherwise =
          copyDWIMFix (paramName p) [] (Var arr) [gtid]

    -- Read the operand @behind@ positions back (offset into the
    -- second half of the buffer for array-typed operands).
    readParam behind p arr
      | isPrimParam p =
          copyDWIMFix (paramName p) [] (Var arr) [ltid - behind]
      | otherwise =
          copyDWIMFix (paramName p) [] (Var arr) [gtid - behind + arrs_full_size]

    -- Store the combined value back to the array (primitives only)
    -- and propagate it into the y-parameter for the next round.
    writeResult x y arr = do
      when (isPrimParam x) $
        copyDWIMFix arr [ltid] (Var $ paramName x) []
      copyDWIM (paramName y) [] (Var $ paramName x) []

blockScan ::
  Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Lambda GPUMem ->
  [VName] ->
  InKernelGen ()
blockScan :: Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> Lambda GPUMem
-> [VName]
-> InKernelGen ()
blockScan Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag TPrimExp Int64 VName
arrs_full_size TPrimExp Int64 VName
w Lambda GPUMem
lam [VName]
arrs = do
  KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
  Lambda GPUMem
renamed_lam <- Lambda GPUMem -> ImpM GPUMem KernelEnv KernelOp (Lambda GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
lam

  let ltid32 :: TExp Int32
ltid32 = KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
      ltid :: TPrimExp Int64 VName
ltid = TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
ltid32
      ([Param LetDecMem]
x_params, [Param LetDecMem]
y_params) = Int -> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
arrs) ([Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem]))
-> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam

  [LParam GPUMem] -> InKernelGen ()
forall rep (inner :: * -> *) r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams (Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam [Param LetDecMem] -> [Param LetDecMem] -> [Param LetDecMem]
forall a. [a] -> [a] -> [a]
++ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
renamed_lam)

  TExp Bool
ltid_in_bounds <- [Char] -> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"ltid_in_bounds" (TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool))
-> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
ltid TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 VName
w

  Fence
fence <- [VName] -> InKernelGen Fence
fenceForArrays [VName]
arrs

  -- The scan works by splitting the block into chunks, which are
  -- scanned separately. Typically, these chunks are at most the
  -- lockstep width, which enables barrier-free execution inside them.
  --
  -- We hardcode the chunk size here. The only requirement is that it
  -- should not be less than the square root of the block size. With
  -- 32, we will work on blocks of size 1024 or smaller, which fits
  -- every device Troels has seen. Still, it would be nicer if it were
  -- a runtime parameter. Some day.
  let chunk_size :: TExp Int32
chunk_size = TExp Int32
32
      simd_width :: TExp Int32
simd_width = KernelConstants -> TExp Int32
kernelWaveSize KernelConstants
constants
      chunk_id :: TExp Int32
chunk_id = TExp Int32
ltid32 TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int32
chunk_size
      in_chunk_id :: TExp Int32
in_chunk_id = TExp Int32
ltid32 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
chunk_id TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
chunk_size
      doInChunkScan :: Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Bool -> Lambda GPUMem -> InKernelGen ()
doInChunkScan Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag' TExp Bool
active =
        KernelConstants
-> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TPrimExp Int64 VName
-> TExp Int32
-> TExp Int32
-> TExp Bool
-> [VName]
-> InKernelGen ()
-> Lambda GPUMem
-> InKernelGen ()
inChunkScan
          KernelConstants
constants
          Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag'
          TPrimExp Int64 VName
arrs_full_size
          TExp Int32
simd_width
          TExp Int32
chunk_size
          TExp Bool
active
          [VName]
arrs
          InKernelGen ()
barrier
      array_scan :: Bool
array_scan = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (TypeBase Shape NoUniqueness -> Bool)
-> [TypeBase Shape NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([TypeBase Shape NoUniqueness] -> Bool)
-> [TypeBase Shape NoUniqueness] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda GPUMem
lam
      barrier :: InKernelGen ()
barrier
        | Bool
array_scan =
            KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceGlobal
        | Bool
otherwise =
            KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
fence

      errorsync :: InKernelGen ()
errorsync
        | Bool
array_scan =
            KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceGlobal
        | Bool
otherwise =
            KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal

      block_offset :: TPrimExp Int64 VName
block_offset = TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelBlockId KernelConstants
constants) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* KernelConstants -> TPrimExp Int64 VName
kernelBlockSize KernelConstants
constants

      writeBlockResult :: Param LetDecMem -> VName -> InKernelGen ()
writeBlockResult Param LetDecMem
p VName
arr
        | Param LetDecMem -> Bool
forall p. Typed p => Param p -> Bool
isPrimParam Param LetDecMem
p =
            VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
arr [TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_id] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) []
        | Bool
otherwise =
            VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
arr [TPrimExp Int64 VName
block_offset TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_id] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) []

      readPrevBlockResult :: Param LetDecMem -> VName -> InKernelGen ()
readPrevBlockResult Param LetDecMem
p VName
arr
        | Param LetDecMem -> Bool
forall p. Typed p => Param p -> Bool
isPrimParam Param LetDecMem
p =
            VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] (VName -> SubExp
Var VName
arr) [TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_id TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1]
        | Bool
otherwise =
            VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] (VName -> SubExp
Var VName
arr) [TPrimExp Int64 VName
block_offset TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_id TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1]

  Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Bool -> Lambda GPUMem -> InKernelGen ()
doInChunkScan Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag TExp Bool
ltid_in_bounds Lambda GPUMem
lam
  InKernelGen ()
barrier

  let is_first_block :: TExp Bool
is_first_block = TExp Int32
chunk_id TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0
  Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"save correct values for first block" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
is_first_block (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        [(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
x_params [VName]
arrs) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
x, VName
arr) ->
          Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (Param LetDecMem -> Bool
forall p. Typed p => Param p -> Bool
isPrimParam Param LetDecMem
x) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
arr [TPrimExp Int64 VName
arrs_full_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
block_offset TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
x) []

    InKernelGen ()
barrier

  let last_in_block :: TExp Bool
last_in_block = TExp Int32
in_chunk_id TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
chunk_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1
  Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"last thread of block 'i' writes its result to offset 'i'" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Bool
last_in_block TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool
ltid_in_bounds) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      InKernelGen () -> InKernelGen ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        (Param LetDecMem -> VName -> InKernelGen ())
-> [Param LetDecMem] -> [VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param LetDecMem -> VName -> InKernelGen ()
writeBlockResult [Param LetDecMem]
x_params [VName]
arrs

  InKernelGen ()
barrier

  let first_block_seg_flag :: Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
first_block_seg_flag = do
        TExp Int32 -> TExp Int32 -> TExp Bool
flag_true <- Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag
        (TExp Int32 -> TExp Int32 -> TExp Bool)
-> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
forall a. a -> Maybe a
Just ((TExp Int32 -> TExp Int32 -> TExp Bool)
 -> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool))
-> (TExp Int32 -> TExp Int32 -> TExp Bool)
-> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
forall a b. (a -> b) -> a -> b
$ \TExp Int32
from TExp Int32
to ->
          TExp Int32 -> TExp Int32 -> TExp Bool
flag_true (TExp Int32
from TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
chunk_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
chunk_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1) (TExp Int32
to TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
chunk_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
chunk_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1)
  Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment
    Text
"scan the first block, after which offset 'i' contains carry-in for block 'i+1'"
    (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Bool -> Lambda GPUMem -> InKernelGen ()
doInChunkScan Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
first_block_seg_flag (TExp Bool
is_first_block TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool
ltid_in_bounds) Lambda GPUMem
renamed_lam

  InKernelGen ()
errorsync

  Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"move correct values for first block back a block" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
is_first_block (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        [(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
x_params [VName]
arrs) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
x, VName
arr) ->
          Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (Param LetDecMem -> Bool
forall p. Typed p => Param p -> Bool
isPrimParam Param LetDecMem
x) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
              VName
arr
              [TPrimExp Int64 VName
arrs_full_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
block_offset TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
ltid]
              (VName -> SubExp
Var VName
arr)
              [TPrimExp Int64 VName
arrs_full_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
block_offset TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
ltid]

    InKernelGen ()
barrier

  TExp Bool
no_carry_in <- [Char] -> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"no_carry_in" (TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool))
-> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall a b. (a -> b) -> a -> b
$ TExp Bool
is_first_block TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TExp Bool
ltid_in_bounds

  let read_carry_in :: InKernelGen ()
read_carry_in = TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless TExp Bool
no_carry_in (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
        [(Param LetDecMem, Param LetDecMem)]
-> ((Param LetDecMem, Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [Param LetDecMem] -> [(Param LetDecMem, Param LetDecMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
x_params [Param LetDecMem]
y_params) (((Param LetDecMem, Param LetDecMem) -> InKernelGen ())
 -> InKernelGen ())
-> ((Param LetDecMem, Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
x, Param LetDecMem
y) ->
          VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
y) [] (VName -> SubExp
Var (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
x)) []
        (Param LetDecMem -> VName -> InKernelGen ())
-> [Param LetDecMem] -> [VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param LetDecMem -> VName -> InKernelGen ()
readPrevBlockResult [Param LetDecMem]
x_params [VName]
arrs

      op_to_x :: InKernelGen ()
op_to_x
        | Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
Nothing <- Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag =
            TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless TExp Bool
no_carry_in (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param LetDecMem] -> Body GPUMem -> InKernelGen ()
forall dec rep r op. [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' [Param LetDecMem]
x_params (Body GPUMem -> InKernelGen ()) -> Body GPUMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
lam
        | Just TExp Int32 -> TExp Int32 -> TExp Bool
flag_true <- Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag = do
            TExp Bool
inactive <-
              [Char] -> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"inactive" (TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool))
-> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int32 -> TExp Bool
flag_true (TExp Int32
chunk_id TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
chunk_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1) TExp Int32
ltid32
            TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless TExp Bool
no_carry_in (InKernelGen () -> InKernelGen ())
-> (((Param LetDecMem, Param LetDecMem) -> InKernelGen ())
    -> InKernelGen ())
-> ((Param LetDecMem, Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
inactive (InKernelGen () -> InKernelGen ())
-> (((Param LetDecMem, Param LetDecMem) -> InKernelGen ())
    -> InKernelGen ())
-> ((Param LetDecMem, Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [(Param LetDecMem, Param LetDecMem)]
-> ((Param LetDecMem, Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [Param LetDecMem] -> [(Param LetDecMem, Param LetDecMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
x_params [Param LetDecMem]
y_params) (((Param LetDecMem, Param LetDecMem) -> InKernelGen ())
 -> InKernelGen ())
-> ((Param LetDecMem, Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
x, Param LetDecMem
y) ->
              VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
x) [] (VName -> SubExp
Var (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
y)) []
            -- The convoluted control flow is to ensure all threads
            -- hit this barrier (if applicable).
            Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan InKernelGen ()
barrier
            TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless TExp Bool
no_carry_in (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless TExp Bool
inactive (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param LetDecMem] -> Body GPUMem -> InKernelGen ()
forall dec rep r op. [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' [Param LetDecMem]
x_params (Body GPUMem -> InKernelGen ()) -> Body GPUMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
lam

      write_final_result :: InKernelGen ()
write_final_result =
        [(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
x_params [VName]
arrs) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, VName
arr) ->
          Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Param LetDecMem -> Bool
forall p. Typed p => Param p -> Bool
isPrimParam Param LetDecMem
p) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
arr [TPrimExp Int64 VName
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) []

  Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"carry-in for every block except the first" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    Operations GPUMem KernelEnv KernelOp
-> InKernelGen () -> InKernelGen ()
forall rep r op a.
Operations rep r op -> ImpM rep r op a -> ImpM rep r op a
localOps Operations GPUMem KernelEnv KernelOp
threadOperations (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
      Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"read operands" InKernelGen ()
read_carry_in
      Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"perform operation" InKernelGen ()
op_to_x
      Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"write final result" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless TExp Bool
no_carry_in InKernelGen ()
write_final_result

  InKernelGen ()
barrier

  Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"restore correct values for first block" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Bool
is_first_block TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool
ltid_in_bounds) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      [(Param LetDecMem, Param LetDecMem, VName)]
-> ((Param LetDecMem, Param LetDecMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [Param LetDecMem]
-> [VName]
-> [(Param LetDecMem, Param LetDecMem, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Param LetDecMem]
x_params [Param LetDecMem]
y_params [VName]
arrs) (((Param LetDecMem, Param LetDecMem, VName) -> InKernelGen ())
 -> InKernelGen ())
-> ((Param LetDecMem, Param LetDecMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
x, Param LetDecMem
y, VName
arr) ->
        if Param LetDecMem -> Bool
forall p. Typed p => Param p -> Bool
isPrimParam Param LetDecMem
y
          then VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
arr [TPrimExp Int64 VName
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
y) []
          else VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
x) [] (VName -> SubExp
Var VName
arr) [TPrimExp Int64 VName
arrs_full_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
block_offset TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
ltid]

  InKernelGen ()
barrier

-- | Perform a block-wide reduction of the values in the given local
-- arrays, using the provided binary operator.  The first argument is
-- the number of participating threads ('w'); threads with a local id
-- at or beyond 'w' do not contribute an element.  On return, the
-- combined result resides at index 0 of each array, and the lambda's
-- accumulator parameters are bound to it for every thread.
--
-- This is a convenience wrapper that allocates a fresh @offset@
-- variable and delegates to 'blockReduceWithOffset'.
blockReduce ::
  Imp.TExp Int32 ->
  Lambda GPUMem ->
  [VName] ->
  InKernelGen ()
blockReduce w lam arrs = do
  -- Scratch variable tracking the current reduction stride; the
  -- offset-based worker mutates it as the tree reduction proceeds.
  offset <- dPrim "offset"
  blockReduceWithOffset offset w lam arrs

blockReduceWithOffset ::
  TV Int32 ->
  Imp.TExp Int32 ->
  Lambda GPUMem ->
  [VName] ->
  InKernelGen ()
blockReduceWithOffset :: TV Int32
-> TExp Int32 -> Lambda GPUMem -> [VName] -> InKernelGen ()
blockReduceWithOffset TV Int32
offset TExp Int32
w Lambda GPUMem
lam [VName]
arrs = do
  KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv

  let local_tid :: TExp Int32
local_tid = KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants

      barrier :: InKernelGen ()
barrier
        | (TypeBase Shape NoUniqueness -> Bool)
-> [TypeBase Shape NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([TypeBase Shape NoUniqueness] -> Bool)
-> [TypeBase Shape NoUniqueness] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda GPUMem
lam = KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
        | Bool
otherwise = KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceGlobal

      errorsync :: InKernelGen ()
errorsync
        | (TypeBase Shape NoUniqueness -> Bool)
-> [TypeBase Shape NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([TypeBase Shape NoUniqueness] -> Bool)
-> [TypeBase Shape NoUniqueness] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda GPUMem
lam = KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
        | Bool
otherwise = KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceGlobal

      readReduceArgument :: Param LetDecMem -> VName -> InKernelGen ()
readReduceArgument Param LetDecMem
param VName
arr = do
        let i :: TExp Int32
i = TExp Int32
local_tid TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
offset
        VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
param) [] (VName -> SubExp
Var VName
arr) [TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
i]

      writeReduceOpResult :: Param LetDecMem -> VName -> InKernelGen ()
writeReduceOpResult Param LetDecMem
param VName
arr =
        Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Param LetDecMem -> Bool
forall p. Typed p => Param p -> Bool
isPrimParam Param LetDecMem
param) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
arr [TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
local_tid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
param) []

      writeArrayOpResult :: Param LetDecMem -> VName -> InKernelGen ()
writeArrayOpResult Param LetDecMem
param VName
arr =
        Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (Param LetDecMem -> Bool
forall p. Typed p => Param p -> Bool
isPrimParam Param LetDecMem
param) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
arr [TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
local_tid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
param) []

  let ([Param LetDecMem]
reduce_acc_params, [Param LetDecMem]
reduce_arr_params) =
        Int -> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
arrs) ([Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem]))
-> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam

  TV Int32
skip_waves <- [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"skip_waves" (TExp Int32
1 :: Imp.TExp Int32)
  [LParam GPUMem] -> InKernelGen ()
forall rep (inner :: * -> *) r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam GPUMem] -> InKernelGen ())
-> [LParam GPUMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam

  TV Int32
offset TV Int32 -> TExp Int32 -> InKernelGen ()
forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- (TExp Int32
0 :: Imp.TExp Int32)

  Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"participating threads read initial accumulator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    Operations GPUMem KernelEnv KernelOp
-> InKernelGen () -> InKernelGen ()
forall rep r op a.
Operations rep r op -> ImpM rep r op a -> ImpM rep r op a
localOps Operations GPUMem KernelEnv KernelOp
threadOperations (InKernelGen () -> InKernelGen ())
-> (InKernelGen () -> InKernelGen ())
-> InKernelGen ()
-> InKernelGen ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
local_tid TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
w) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      (Param LetDecMem -> VName -> InKernelGen ())
-> [Param LetDecMem] -> [VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param LetDecMem -> VName -> InKernelGen ()
readReduceArgument [Param LetDecMem]
reduce_acc_params [VName]
arrs

  let do_reduce :: InKernelGen ()
do_reduce = Operations GPUMem KernelEnv KernelOp
-> InKernelGen () -> InKernelGen ()
forall rep r op a.
Operations rep r op -> ImpM rep r op a -> ImpM rep r op a
localOps Operations GPUMem KernelEnv KernelOp
threadOperations (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
        Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"read array element" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          (Param LetDecMem -> VName -> InKernelGen ())
-> [Param LetDecMem] -> [VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param LetDecMem -> VName -> InKernelGen ()
readReduceArgument [Param LetDecMem]
reduce_arr_params [VName]
arrs
        Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"apply reduction operation" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          [Param LetDecMem] -> Body GPUMem -> InKernelGen ()
forall dec rep r op. [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' [Param LetDecMem]
reduce_acc_params (Body GPUMem -> InKernelGen ()) -> Body GPUMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
lam
        Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"write result of operation" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          (Param LetDecMem -> VName -> InKernelGen ())
-> [Param LetDecMem] -> [VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param LetDecMem -> VName -> InKernelGen ()
writeReduceOpResult [Param LetDecMem]
reduce_acc_params [VName]
arrs
      in_wave_reduce :: InKernelGen ()
in_wave_reduce = InKernelGen () -> InKernelGen ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile InKernelGen ()
do_reduce

      wave_size :: TExp Int32
wave_size = KernelConstants -> TExp Int32
kernelWaveSize KernelConstants
constants
      tblock_size :: TPrimExp Int64 VName
tblock_size = KernelConstants -> TPrimExp Int64 VName
kernelBlockSize KernelConstants
constants
      wave_id :: TExp Int32
wave_id = TExp Int32
local_tid TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int32
wave_size
      in_wave_id :: TExp Int32
in_wave_id = TExp Int32
local_tid TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
wave_id TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
wave_size
      num_waves :: TExp Int32
num_waves = (TPrimExp Int64 VName -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
tblock_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
wave_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1) TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int32
wave_size
      arg_in_bounds :: TExp Bool
arg_in_bounds = TExp Int32
local_tid TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
offset TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
w

      doing_in_wave_reductions :: TExp Bool
doing_in_wave_reductions =
        TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
offset TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
wave_size
      apply_in_in_wave_iteration :: TExp Bool
apply_in_in_wave_iteration =
        (TExp Int32
in_wave_id TExp Int32 -> TExp Int32 -> TExp Int32
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.&. (TExp Int32
2 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
offset TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1)) TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0
      in_wave_reductions :: InKernelGen ()
in_wave_reductions = do
        TV Int32
offset TV Int32 -> TExp Int32 -> InKernelGen ()
forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- (TExp Int32
1 :: Imp.TExp Int32)
        TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhile TExp Bool
doing_in_wave_reductions (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
          TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen
            (TExp Bool
arg_in_bounds TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool
apply_in_in_wave_iteration)
            InKernelGen ()
in_wave_reduce
          TV Int32
offset TV Int32 -> TExp Int32 -> InKernelGen ()
forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
offset TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
2

      doing_cross_wave_reductions :: TExp Bool
doing_cross_wave_reductions =
        TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
skip_waves TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
num_waves
      is_first_thread_in_wave :: TExp Bool
is_first_thread_in_wave =
        TExp Int32
in_wave_id TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0
      wave_not_skipped :: TExp Bool
wave_not_skipped =
        (TExp Int32
wave_id TExp Int32 -> TExp Int32 -> TExp Int32
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.&. (TExp Int32
2 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
skip_waves TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1)) TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0
      apply_in_cross_wave_iteration :: TExp Bool
apply_in_cross_wave_iteration =
        TExp Bool
arg_in_bounds TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool
is_first_thread_in_wave TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool
wave_not_skipped
      cross_wave_reductions :: InKernelGen ()
cross_wave_reductions =
        TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhile TExp Bool
doing_cross_wave_reductions (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
          InKernelGen ()
barrier
          TV Int32
offset TV Int32 -> TExp Int32 -> InKernelGen ()
forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
skip_waves TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
wave_size
          TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
apply_in_cross_wave_iteration InKernelGen ()
do_reduce
          TV Int32
skip_waves TV Int32 -> TExp Int32 -> InKernelGen ()
forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
skip_waves TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
2

  InKernelGen ()
in_wave_reductions
  InKernelGen ()
cross_wave_reductions
  InKernelGen ()
errorsync

  Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ((Param LetDecMem -> Bool) -> [Param LetDecMem] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Param LetDecMem -> Bool
forall p. Typed p => Param p -> Bool
isPrimParam [Param LetDecMem]
reduce_acc_params) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Copy array-typed operands to result array" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
local_tid TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        Operations GPUMem KernelEnv KernelOp
-> InKernelGen () -> InKernelGen ()
forall rep r op a.
Operations rep r op -> ImpM rep r op a -> ImpM rep r op a
localOps Operations GPUMem KernelEnv KernelOp
threadOperations (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          (Param LetDecMem -> VName -> InKernelGen ())
-> [Param LetDecMem] -> [VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param LetDecMem -> VName -> InKernelGen ()
writeArrayOpResult [Param LetDecMem]
reduce_acc_params [VName]
arrs

compileThreadOp :: OpCompiler GPUMem KernelEnv Imp.KernelOp
compileThreadOp :: OpCompiler GPUMem KernelEnv KernelOp
compileThreadOp Pat (LetDec GPUMem)
pat (Alloc SubExp
size Space
space) =
  Pat LetDecMem -> SubExp -> Space -> InKernelGen ()
threadAlloc Pat (LetDec GPUMem)
Pat LetDecMem
pat SubExp
size Space
space
compileThreadOp Pat (LetDec GPUMem)
pat Op GPUMem
_ =
  [Char] -> InKernelGen ()
forall a. [Char] -> a
compilerBugS ([Char] -> InKernelGen ()) -> [Char] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char]
"compileThreadOp: cannot compile rhs of binding " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ Pat LetDecMem -> [Char]
forall a. Pretty a => a -> [Char]
prettyString Pat (LetDec GPUMem)
Pat LetDecMem
pat

-- | Perform a scalar write followed by a fence.
writeAtomic ::
  VName ->
  [Imp.TExp Int64] ->
  SubExp ->
  [Imp.TExp Int64] ->
  InKernelGen ()
writeAtomic :: VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
writeAtomic VName
dst [TPrimExp Int64 VName]
dst_is SubExp
src [TPrimExp Int64 VName]
src_is = do
  TypeBase Shape NoUniqueness
t <- Int -> TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
forall u. Int -> TypeBase Shape u -> TypeBase Shape u
stripArray ([TPrimExp Int64 VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TPrimExp Int64 VName]
dst_is) (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> ImpM GPUMem KernelEnv KernelOp (TypeBase Shape NoUniqueness)
-> ImpM GPUMem KernelEnv KernelOp (TypeBase Shape NoUniqueness)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName
-> ImpM GPUMem KernelEnv KernelOp (TypeBase Shape NoUniqueness)
forall rep (m :: * -> *).
HasScope rep m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
dst
  [TPrimExp Int64 VName]
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall {k} (t :: k) rep r op.
[TExp t] -> ([TExp t] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopSpace ((SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 (TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims TypeBase Shape NoUniqueness
t)) (([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ())
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
is -> do
    let pt :: PrimType
pt = TypeBase Shape NoUniqueness -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType TypeBase Shape NoUniqueness
t
    (VName
dst_mem, Space
dst_space, Count Elements (TPrimExp Int64 VName)
dst_offset) <- VName
-> [TPrimExp Int64 VName]
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     (VName, Space, Count Elements (TPrimExp Int64 VName))
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> ImpM
     rep r op (VName, Space, Count Elements (TPrimExp Int64 VName))
fullyIndexArray VName
dst ([TPrimExp Int64 VName]
dst_is [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
is)
    case [TPrimExp Int64 VName]
src_is [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
is of
      [] ->
        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ())
-> (AtomicOp -> KernelOp) -> AtomicOp -> InKernelGen ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Space -> AtomicOp -> KernelOp
Imp.Atomic Space
dst_space (AtomicOp -> InKernelGen ()) -> AtomicOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          PrimType
-> VName
-> Count Elements (TPrimExp Int64 VName)
-> Exp
-> AtomicOp
Imp.AtomicWrite PrimType
pt VName
dst_mem Count Elements (TPrimExp Int64 VName)
dst_offset (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
pt SubExp
src)
      [TPrimExp Int64 VName]
_ -> do
        TV Any
tmp <- [Char] -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall {k} rep r op (t :: k).
[Char] -> PrimType -> ImpM rep r op (TV t)
dPrimSV [Char]
"tmp" PrimType
pt
        VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV Any
tmp) [] SubExp
src ([TPrimExp Int64 VName]
src_is [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
is)
        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ())
-> (AtomicOp -> KernelOp) -> AtomicOp -> InKernelGen ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Space -> AtomicOp -> KernelOp
Imp.Atomic Space
dst_space (AtomicOp -> InKernelGen ()) -> AtomicOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          PrimType
-> VName
-> Count Elements (TPrimExp Int64 VName)
-> Exp
-> AtomicOp
Imp.AtomicWrite PrimType
pt VName
dst_mem Count Elements (TPrimExp Int64 VName)
dst_offset (TPrimExp Any VName -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TV Any -> TPrimExp Any VName
forall {k} (t :: k). TV t -> TExp t
tvExp TV Any
tmp))

-- | Locking strategy used for an atomic update.
data Locking = Locking
  { -- | Array containing the lock.
    Locking -> VName
lockingArray :: VName,
    -- | Value for us to consider the lock free.
    Locking -> TExp Int32
lockingIsUnlocked :: Imp.TExp Int32,
    -- | What to write when we lock it.
    Locking -> TExp Int32
lockingToLock :: Imp.TExp Int32,
    -- | What to write when we unlock it.
    Locking -> TExp Int32
lockingToUnlock :: Imp.TExp Int32,
    -- | A transformation from the logical lock index to the
    -- physical position in the array.  This can also be used
    -- to make the lock array smaller.
    Locking -> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
lockingMapping :: [Imp.TExp Int64] -> [Imp.TExp Int64]
  }

-- | A function for generating code for an atomic update.  Assumes
-- that the bucket is in-bounds.
type DoAtomicUpdate rep r =
  Space -> [VName] -> [Imp.TExp Int64] -> ImpM rep r Imp.KernelOp ()

-- | The mechanism that will be used for performing the atomic update.
-- Approximates how efficient it will be.  Ordered from most to least
-- efficient.
data AtomicUpdate rep r
  = -- | Supported directly by primitive.
    AtomicPrim (DoAtomicUpdate rep r)
  | -- | Can be done by efficient swaps.
    AtomicCAS (DoAtomicUpdate rep r)
  | -- | Requires explicit locking.
    AtomicLocking (Locking -> DoAtomicUpdate rep r)

-- | Is there an atomic t'BinOp' corresponding to this t'BinOp'?
type AtomicBinOp =
  BinOp ->
  Maybe (VName -> VName -> Count Imp.Elements (Imp.TExp Int64) -> Imp.Exp -> Imp.AtomicOp)

-- | Do an atomic update corresponding to a binary operator lambda.
atomicUpdateLocking ::
  AtomicBinOp ->
  Lambda GPUMem ->
  AtomicUpdate GPUMem KernelEnv
atomicUpdateLocking :: AtomicBinOp -> Lambda GPUMem -> AtomicUpdate GPUMem KernelEnv
atomicUpdateLocking AtomicBinOp
atomicBinOp Lambda GPUMem
lam
  | Just [(BinOp, PrimType, VName, VName)]
ops_and_ts <- Lambda GPUMem -> Maybe [(BinOp, PrimType, VName, VName)]
forall rep.
ASTRep rep =>
Lambda rep -> Maybe [(BinOp, PrimType, VName, VName)]
lamIsBinOp Lambda GPUMem
lam,
    ((BinOp, PrimType, VName, VName) -> Bool)
-> [(BinOp, PrimType, VName, VName)] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (\(BinOp
_, PrimType
t, VName
_, VName
_) -> PrimType -> Int
primBitSize PrimType
t Int -> [Int] -> Bool
forall a. Eq a => a -> [a] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [Int
32, Int
64]) [(BinOp, PrimType, VName, VName)]
ops_and_ts =
      [(BinOp, PrimType, VName, VName)]
-> DoAtomicUpdate GPUMem KernelEnv -> AtomicUpdate GPUMem KernelEnv
primOrCas [(BinOp, PrimType, VName, VName)]
ops_and_ts (DoAtomicUpdate GPUMem KernelEnv -> AtomicUpdate GPUMem KernelEnv)
-> DoAtomicUpdate GPUMem KernelEnv -> AtomicUpdate GPUMem KernelEnv
forall a b. (a -> b) -> a -> b
$ \Space
space [VName]
arrs [TPrimExp Int64 VName]
bucket ->
        -- If the operator is a vectorised binary operator on 32/64-bit
        -- values, we can use a particularly efficient
        -- implementation. If the operator has an atomic implementation
        -- we use that, otherwise it is still a binary operator which
        -- can be implemented by atomic compare-and-swap if 32/64 bits.
        [(VName, (BinOp, PrimType, VName, VName))]
-> ((VName, (BinOp, PrimType, VName, VName)) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName]
-> [(BinOp, PrimType, VName, VName)]
-> [(VName, (BinOp, PrimType, VName, VName))]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
arrs [(BinOp, PrimType, VName, VName)]
ops_and_ts) (((VName, (BinOp, PrimType, VName, VName)) -> InKernelGen ())
 -> InKernelGen ())
-> ((VName, (BinOp, PrimType, VName, VName)) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
a, (BinOp
op, PrimType
t, VName
x, VName
y)) -> do
          -- Common variables.
          VName
old <- [Char] -> PrimType -> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op. [Char] -> PrimType -> ImpM rep r op VName
dPrimS [Char]
"old" PrimType
t

          (VName
arr', Space
_a_space, Count Elements (TPrimExp Int64 VName)
bucket_offset) <- VName
-> [TPrimExp Int64 VName]
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     (VName, Space, Count Elements (TPrimExp Int64 VName))
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> ImpM
     rep r op (VName, Space, Count Elements (TPrimExp Int64 VName))
fullyIndexArray VName
a [TPrimExp Int64 VName]
bucket

          case Space
-> VName
-> VName
-> Count Elements (TPrimExp Int64 VName)
-> BinOp
-> Maybe (Exp -> KernelOp)
opHasAtomicSupport Space
space VName
old VName
arr' Count Elements (TPrimExp Int64 VName)
bucket_offset BinOp
op of
            Just Exp -> KernelOp
f -> KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Exp -> KernelOp
f (Exp -> KernelOp) -> Exp -> KernelOp
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
y PrimType
t
            Maybe (Exp -> KernelOp)
Nothing ->
              Space
-> PrimType
-> VName
-> VName
-> [TPrimExp Int64 VName]
-> VName
-> InKernelGen ()
-> InKernelGen ()
atomicUpdateCAS Space
space PrimType
t VName
a VName
old [TPrimExp Int64 VName]
bucket VName
x (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                VName
x VName -> Exp -> InKernelGen ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp BinOp
op (VName -> PrimType -> Exp
Imp.var VName
x PrimType
t) (VName -> PrimType -> Exp
Imp.var VName
y PrimType
t)
  where
    opHasAtomicSupport :: Space
-> VName
-> VName
-> Count Elements (TPrimExp Int64 VName)
-> BinOp
-> Maybe (Exp -> KernelOp)
opHasAtomicSupport Space
space VName
old VName
arr' Count Elements (TPrimExp Int64 VName)
bucket' BinOp
bop = do
      let atomic :: (VName
 -> VName
 -> Count Elements (TPrimExp Int64 VName)
 -> Exp
 -> AtomicOp)
-> Exp -> KernelOp
atomic VName
-> VName
-> Count Elements (TPrimExp Int64 VName)
-> Exp
-> AtomicOp
f = Space -> AtomicOp -> KernelOp
Imp.Atomic Space
space (AtomicOp -> KernelOp) -> (Exp -> AtomicOp) -> Exp -> KernelOp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName
-> VName
-> Count Elements (TPrimExp Int64 VName)
-> Exp
-> AtomicOp
f VName
old VName
arr' Count Elements (TPrimExp Int64 VName)
bucket'
      (VName
 -> VName
 -> Count Elements (TPrimExp Int64 VName)
 -> Exp
 -> AtomicOp)
-> Exp -> KernelOp
atomic ((VName
  -> VName
  -> Count Elements (TPrimExp Int64 VName)
  -> Exp
  -> AtomicOp)
 -> Exp -> KernelOp)
-> Maybe
     (VName
      -> VName
      -> Count Elements (TPrimExp Int64 VName)
      -> Exp
      -> AtomicOp)
-> Maybe (Exp -> KernelOp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> AtomicBinOp
atomicBinOp BinOp
bop

    primOrCas :: [(BinOp, PrimType, VName, VName)]
-> DoAtomicUpdate GPUMem KernelEnv -> AtomicUpdate GPUMem KernelEnv
primOrCas [(BinOp, PrimType, VName, VName)]
ops
      | ((BinOp, PrimType, VName, VName) -> Bool)
-> [(BinOp, PrimType, VName, VName)] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (BinOp, PrimType, VName, VName) -> Bool
isPrim [(BinOp, PrimType, VName, VName)]
ops = DoAtomicUpdate GPUMem KernelEnv -> AtomicUpdate GPUMem KernelEnv
forall rep r. DoAtomicUpdate rep r -> AtomicUpdate rep r
AtomicPrim
      | Bool
otherwise = DoAtomicUpdate GPUMem KernelEnv -> AtomicUpdate GPUMem KernelEnv
forall rep r. DoAtomicUpdate rep r -> AtomicUpdate rep r
AtomicCAS

    isPrim :: (BinOp, PrimType, VName, VName) -> Bool
isPrim (BinOp
op, PrimType
_, VName
_, VName
_) = Maybe
  (VName
   -> VName
   -> Count Elements (TPrimExp Int64 VName)
   -> Exp
   -> AtomicOp)
-> Bool
forall a. Maybe a -> Bool
isJust (Maybe
   (VName
    -> VName
    -> Count Elements (TPrimExp Int64 VName)
    -> Exp
    -> AtomicOp)
 -> Bool)
-> Maybe
     (VName
      -> VName
      -> Count Elements (TPrimExp Int64 VName)
      -> Exp
      -> AtomicOp)
-> Bool
forall a b. (a -> b) -> a -> b
$ AtomicBinOp
atomicBinOp BinOp
op

-- If the operator functions purely on single 32/64-bit values, we can
-- use an implementation based on CAS, no matter what the operator
-- does.
atomicUpdateLocking AtomicBinOp
_ Lambda GPUMem
op
  | [Prim PrimType
t] <- Lambda GPUMem -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda GPUMem
op,
    [LParam GPUMem
xp, LParam GPUMem
_] <- Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
op,
    PrimType -> Int
primBitSize PrimType
t Int -> [Int] -> Bool
forall a. Eq a => a -> [a] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [Int
32, Int
64] = DoAtomicUpdate GPUMem KernelEnv -> AtomicUpdate GPUMem KernelEnv
forall rep r. DoAtomicUpdate rep r -> AtomicUpdate rep r
AtomicCAS (DoAtomicUpdate GPUMem KernelEnv -> AtomicUpdate GPUMem KernelEnv)
-> DoAtomicUpdate GPUMem KernelEnv -> AtomicUpdate GPUMem KernelEnv
forall a b. (a -> b) -> a -> b
$ \Space
space [VName
arr] [TPrimExp Int64 VName]
bucket -> do
      VName
old <- [Char] -> PrimType -> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op. [Char] -> PrimType -> ImpM rep r op VName
dPrimS [Char]
"old" PrimType
t
      Space
-> PrimType
-> VName
-> VName
-> [TPrimExp Int64 VName]
-> VName
-> InKernelGen ()
-> InKernelGen ()
atomicUpdateCAS Space
space PrimType
t VName
arr VName
old [TPrimExp Int64 VName]
bucket (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName LParam GPUMem
Param LetDecMem
xp) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        [Param LetDecMem] -> Body GPUMem -> InKernelGen ()
forall dec rep r op. [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' [LParam GPUMem
Param LetDecMem
xp] (Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
op)
atomicUpdateLocking AtomicBinOp
_ Lambda GPUMem
op = (Locking -> DoAtomicUpdate GPUMem KernelEnv)
-> AtomicUpdate GPUMem KernelEnv
forall rep r.
(Locking -> DoAtomicUpdate rep r) -> AtomicUpdate rep r
AtomicLocking ((Locking -> DoAtomicUpdate GPUMem KernelEnv)
 -> AtomicUpdate GPUMem KernelEnv)
-> (Locking -> DoAtomicUpdate GPUMem KernelEnv)
-> AtomicUpdate GPUMem KernelEnv
forall a b. (a -> b) -> a -> b
$ \Locking
locking Space
space [VName]
arrs [TPrimExp Int64 VName]
bucket -> do
  TV Int32
old <- [Char] -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall {k} (t :: k) rep r op.
MkTV t =>
[Char] -> ImpM rep r op (TV t)
dPrim [Char]
"old"
  TV Bool
continue <- [Char]
-> PrimType
-> TExp Bool
-> ImpM GPUMem KernelEnv KernelOp (TV Bool)
forall {k} (t :: k) rep r op.
[Char] -> PrimType -> TExp t -> ImpM rep r op (TV t)
dPrimVol [Char]
"continue" PrimType
Bool TExp Bool
forall v. TPrimExp Bool v
true

  -- Correctly index into locks.
  (VName
locks', Space
_locks_space, Count Elements (TPrimExp Int64 VName)
locks_offset) <-
    VName
-> [TPrimExp Int64 VName]
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     (VName, Space, Count Elements (TPrimExp Int64 VName))
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> ImpM
     rep r op (VName, Space, Count Elements (TPrimExp Int64 VName))
fullyIndexArray (Locking -> VName
lockingArray Locking
locking) ([TPrimExp Int64 VName]
 -> ImpM
      GPUMem
      KernelEnv
      KernelOp
      (VName, Space, Count Elements (TPrimExp Int64 VName)))
-> [TPrimExp Int64 VName]
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     (VName, Space, Count Elements (TPrimExp Int64 VName))
forall a b. (a -> b) -> a -> b
$ Locking -> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
lockingMapping Locking
locking [TPrimExp Int64 VName]
bucket

  -- Critical section
  let try_acquire_lock :: InKernelGen ()
try_acquire_lock =
        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          Space -> AtomicOp -> KernelOp
Imp.Atomic Space
space (AtomicOp -> KernelOp) -> AtomicOp -> KernelOp
forall a b. (a -> b) -> a -> b
$
            PrimType
-> VName
-> VName
-> Count Elements (TPrimExp Int64 VName)
-> Exp
-> Exp
-> AtomicOp
Imp.AtomicCmpXchg
              PrimType
int32
              (TV Int32 -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV Int32
old)
              VName
locks'
              Count Elements (TPrimExp Int64 VName)
locks_offset
              (TExp Int32 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ Locking -> TExp Int32
lockingIsUnlocked Locking
locking)
              (TExp Int32 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ Locking -> TExp Int32
lockingToLock Locking
locking)
      lock_acquired :: TExp Bool
lock_acquired = TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
old TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. Locking -> TExp Int32
lockingIsUnlocked Locking
locking
      -- Even the releasing is done with an atomic rather than a
      -- simple write, for memory coherency reasons.
      release_lock :: InKernelGen ()
release_lock =
        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          Space -> AtomicOp -> KernelOp
Imp.Atomic Space
space (AtomicOp -> KernelOp) -> AtomicOp -> KernelOp
forall a b. (a -> b) -> a -> b
$
            PrimType
-> VName
-> VName
-> Count Elements (TPrimExp Int64 VName)
-> Exp
-> Exp
-> AtomicOp
Imp.AtomicCmpXchg
              PrimType
int32
              (TV Int32 -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV Int32
old)
              VName
locks'
              Count Elements (TPrimExp Int64 VName)
locks_offset
              (TExp Int32 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ Locking -> TExp Int32
lockingToLock Locking
locking)
              (TExp Int32 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ Locking -> TExp Int32
lockingToUnlock Locking
locking)
      break_loop :: InKernelGen ()
break_loop = TV Bool
continue TV Bool -> TExp Bool -> InKernelGen ()
forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TExp Bool
forall v. TPrimExp Bool v
false

  -- Preparing parameters. It is assumed that the caller has already
  -- filled the arr_params. We copy the current value to the
  -- accumulator parameters.
  --
  -- Note the use of 'everythingVolatile' when reading and writing the
  -- buckets.  This was necessary to ensure correct execution on a
  -- newer NVIDIA GPU (RTX 2080).  The 'volatile' modifiers likely
  -- make the writes pass through the (SM-local) L1 cache, which is
  -- necessary here, because we are really doing device-wide
  -- synchronisation without atomics (naughty!).
  let ([Param LetDecMem]
acc_params, [Param LetDecMem]
_arr_params) = Int -> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
arrs) ([Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem]))
-> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
op
      bind_acc_params :: InKernelGen ()
bind_acc_params =
        InKernelGen () -> InKernelGen ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"bind lhs" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            [(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
acc_params [VName]
arrs) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
acc_p, VName
arr) ->
              VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
acc_p) [] (VName -> SubExp
Var VName
arr) [TPrimExp Int64 VName]
bucket

  let op_body :: InKernelGen ()
op_body =
        Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"execute operation" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          [Param LetDecMem] -> Body GPUMem -> InKernelGen ()
forall dec rep r op. [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' [Param LetDecMem]
acc_params (Body GPUMem -> InKernelGen ()) -> Body GPUMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
op

      do_hist :: InKernelGen ()
do_hist =
        InKernelGen () -> InKernelGen ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"update global result" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            (VName -> SubExp -> InKernelGen ())
-> [VName] -> [SubExp] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ ([TPrimExp Int64 VName] -> VName -> SubExp -> InKernelGen ()
writeArray [TPrimExp Int64 VName]
bucket) [VName]
arrs ([SubExp] -> InKernelGen ()) -> [SubExp] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              (Param LetDecMem -> SubExp) -> [Param LetDecMem] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (Param LetDecMem -> VName) -> Param LetDecMem -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName) [Param LetDecMem]
acc_params

  -- While-loop: Try to insert your value
  TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhile (TV Bool -> TExp Bool
forall {k} (t :: k). TV t -> TExp t
tvExp TV Bool
continue) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    InKernelGen ()
try_acquire_lock
    TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
lock_acquired (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
      [LParam GPUMem] -> InKernelGen ()
forall rep (inner :: * -> *) r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams [LParam GPUMem]
[Param LetDecMem]
acc_params
      InKernelGen ()
bind_acc_params
      InKernelGen ()
op_body
      InKernelGen ()
do_hist
      InKernelGen ()
release_lock
      InKernelGen ()
break_loop
  where
    writeArray :: [TPrimExp Int64 VName] -> VName -> SubExp -> InKernelGen ()
writeArray [TPrimExp Int64 VName]
bucket VName
arr SubExp
val = VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
writeAtomic VName
arr [TPrimExp Int64 VName]
bucket SubExp
val []

atomicUpdateCAS ::
  Space ->
  PrimType ->
  VName ->
  VName ->
  [Imp.TExp Int64] ->
  VName ->
  InKernelGen () ->
  InKernelGen ()
atomicUpdateCAS :: Space
-> PrimType
-> VName
-> VName
-> [TPrimExp Int64 VName]
-> VName
-> InKernelGen ()
-> InKernelGen ()
atomicUpdateCAS Space
space PrimType
t VName
arr VName
old [TPrimExp Int64 VName]
bucket VName
x InKernelGen ()
do_op = do
  -- Code generation target:
  --
  -- old = d_his[idx];
  -- do {
  --   assumed = old;
  --   x = do_op(assumed, y);
  --   old = atomicCAS(&d_his[idx], assumed, tmp);
  -- } while(assumed != old);
  VName
assumed <- [Char] -> PrimType -> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op. [Char] -> PrimType -> ImpM rep r op VName
dPrimS [Char]
"assumed" PrimType
t
  TV Bool
run_loop <- [Char] -> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TV Bool)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"run_loop" TExp Bool
forall v. TPrimExp Bool v
true

  -- XXX: CUDA may generate really bad code if this is not a volatile
  -- read.  Unclear why.  The later reads are volatile, so maybe
  -- that's it.
  InKernelGen () -> InKernelGen ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
old [] (VName -> SubExp
Var VName
arr) [TPrimExp Int64 VName]
bucket

  (VName
arr', Space
_a_space, Count Elements (TPrimExp Int64 VName)
bucket_offset) <- VName
-> [TPrimExp Int64 VName]
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     (VName, Space, Count Elements (TPrimExp Int64 VName))
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> ImpM
     rep r op (VName, Space, Count Elements (TPrimExp Int64 VName))
fullyIndexArray VName
arr [TPrimExp Int64 VName]
bucket

  -- While-loop: Try to insert your value
  let (Exp -> Exp
toBits, Exp -> Exp
fromBits) =
        case PrimType
t of
          FloatType FloatType
Float16 ->
            ( \Exp
v -> [Char] -> [Exp] -> PrimType -> Exp
forall v. [Char] -> [PrimExp v] -> PrimType -> PrimExp v
Imp.FunExp [Char]
"to_bits16" [Exp
v] PrimType
int16,
              \Exp
v -> [Char] -> [Exp] -> PrimType -> Exp
forall v. [Char] -> [PrimExp v] -> PrimType -> PrimExp v
Imp.FunExp [Char]
"from_bits16" [Exp
v] PrimType
t
            )
          FloatType FloatType
Float32 ->
            ( \Exp
v -> [Char] -> [Exp] -> PrimType -> Exp
forall v. [Char] -> [PrimExp v] -> PrimType -> PrimExp v
Imp.FunExp [Char]
"to_bits32" [Exp
v] PrimType
int32,
              \Exp
v -> [Char] -> [Exp] -> PrimType -> Exp
forall v. [Char] -> [PrimExp v] -> PrimType -> PrimExp v
Imp.FunExp [Char]
"from_bits32" [Exp
v] PrimType
t
            )
          FloatType FloatType
Float64 ->
            ( \Exp
v -> [Char] -> [Exp] -> PrimType -> Exp
forall v. [Char] -> [PrimExp v] -> PrimType -> PrimExp v
Imp.FunExp [Char]
"to_bits64" [Exp
v] PrimType
int64,
              \Exp
v -> [Char] -> [Exp] -> PrimType -> Exp
forall v. [Char] -> [PrimExp v] -> PrimType -> PrimExp v
Imp.FunExp [Char]
"from_bits64" [Exp
v] PrimType
t
            )
          PrimType
_ -> (Exp -> Exp
forall a. a -> a
id, Exp -> Exp
forall a. a -> a
id)

      int :: PrimType
int
        | PrimType -> Int
primBitSize PrimType
t Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
16 = PrimType
int16
        | PrimType -> Int
primBitSize PrimType
t Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
32 = PrimType
int32
        | Bool
otherwise = PrimType
int64

  TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhile (TV Bool -> TExp Bool
forall {k} (t :: k). TV t -> TExp t
tvExp TV Bool
run_loop) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    VName
assumed VName -> Exp -> InKernelGen ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ VName -> PrimType -> Exp
Imp.var VName
old PrimType
t
    VName
x VName -> Exp -> InKernelGen ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ VName -> PrimType -> Exp
Imp.var VName
assumed PrimType
t
    InKernelGen ()
do_op
    VName
old_bits_v <- [Char] -> ImpM GPUMem KernelEnv KernelOp VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"old_bits"
    VName -> PrimType -> InKernelGen ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
old_bits_v PrimType
int
    let old_bits :: Exp
old_bits = VName -> PrimType -> Exp
Imp.var VName
old_bits_v PrimType
int
    KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ())
-> (AtomicOp -> KernelOp) -> AtomicOp -> InKernelGen ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Space -> AtomicOp -> KernelOp
Imp.Atomic Space
space (AtomicOp -> InKernelGen ()) -> AtomicOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      PrimType
-> VName
-> VName
-> Count Elements (TPrimExp Int64 VName)
-> Exp
-> Exp
-> AtomicOp
Imp.AtomicCmpXchg
        PrimType
int
        VName
old_bits_v
        VName
arr'
        Count Elements (TPrimExp Int64 VName)
bucket_offset
        (Exp -> Exp
toBits (VName -> PrimType -> Exp
Imp.var VName
assumed PrimType
t))
        (Exp -> Exp
toBits (VName -> PrimType -> Exp
Imp.var VName
x PrimType
t))
    VName
old VName -> Exp -> InKernelGen ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ Exp -> Exp
fromBits Exp
old_bits
    let won :: Exp
won = CmpOp -> Exp -> Exp -> Exp
forall v. CmpOp -> PrimExp v -> PrimExp v -> PrimExp v
CmpOpExp (PrimType -> CmpOp
CmpEq PrimType
int) (Exp -> Exp
toBits (VName -> PrimType -> Exp
Imp.var VName
assumed PrimType
t)) Exp
old_bits
    TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (Exp -> TExp Bool
forall v. PrimExp v -> TPrimExp Bool v
isBool Exp
won) (TV Bool
run_loop TV Bool -> TExp Bool -> InKernelGen ()
forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TExp Bool
forall v. TPrimExp Bool v
false)

computeKernelUses ::
  (FreeIn a) =>
  a ->
  [VName] ->
  CallKernelGen [Imp.KernelUse]
computeKernelUses :: forall a. FreeIn a => a -> [VName] -> CallKernelGen [KernelUse]
computeKernelUses a
kernel_body [VName]
bound_in_kernel = do
  let actually_free :: Names
actually_free = a -> Names
forall a. FreeIn a => a -> Names
freeIn a
kernel_body Names -> Names -> Names
`namesSubtract` [VName] -> Names
namesFromList [VName]
bound_in_kernel
  -- Compute the variables that we need to pass to the kernel.
  [KernelUse] -> [KernelUse]
forall a. Ord a => [a] -> [a]
nubOrd ([KernelUse] -> [KernelUse])
-> CallKernelGen [KernelUse] -> CallKernelGen [KernelUse]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Names -> CallKernelGen [KernelUse]
readsFromSet Names
actually_free

readsFromSet :: Names -> CallKernelGen [Imp.KernelUse]
readsFromSet :: Names -> CallKernelGen [KernelUse]
readsFromSet = ([Maybe KernelUse] -> [KernelUse])
-> ImpM GPUMem HostEnv HostOp [Maybe KernelUse]
-> CallKernelGen [KernelUse]
forall a b.
(a -> b)
-> ImpM GPUMem HostEnv HostOp a -> ImpM GPUMem HostEnv HostOp b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [Maybe KernelUse] -> [KernelUse]
forall a. [Maybe a] -> [a]
catMaybes (ImpM GPUMem HostEnv HostOp [Maybe KernelUse]
 -> CallKernelGen [KernelUse])
-> (Names -> ImpM GPUMem HostEnv HostOp [Maybe KernelUse])
-> Names
-> CallKernelGen [KernelUse]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> ImpM GPUMem HostEnv HostOp (Maybe KernelUse))
-> [VName] -> ImpM GPUMem HostEnv HostOp [Maybe KernelUse]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> ImpM GPUMem HostEnv HostOp (Maybe KernelUse)
forall {r} {op}. VName -> ImpM GPUMem r op (Maybe KernelUse)
f ([VName] -> ImpM GPUMem HostEnv HostOp [Maybe KernelUse])
-> (Names -> [VName])
-> Names
-> ImpM GPUMem HostEnv HostOp [Maybe KernelUse]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Names -> [VName]
namesToList
  where
    f :: VName -> ImpM GPUMem r op (Maybe KernelUse)
f VName
var = do
      TypeBase Shape NoUniqueness
t <- VName -> ImpM GPUMem r op (TypeBase Shape NoUniqueness)
forall rep (m :: * -> *).
HasScope rep m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
var
      VTable GPUMem
vtable <- ImpM GPUMem r op (VTable GPUMem)
forall rep r op. ImpM rep r op (VTable rep)
getVTable
      case TypeBase Shape NoUniqueness
t of
        Array {} -> Maybe KernelUse -> ImpM GPUMem r op (Maybe KernelUse)
forall a. a -> ImpM GPUMem r op a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe KernelUse
forall a. Maybe a
Nothing
        Acc {} -> Maybe KernelUse -> ImpM GPUMem r op (Maybe KernelUse)
forall a. a -> ImpM GPUMem r op a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe KernelUse
forall a. Maybe a
Nothing
        Mem (Space [Char]
"shared") -> Maybe KernelUse -> ImpM GPUMem r op (Maybe KernelUse)
forall a. a -> ImpM GPUMem r op a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe KernelUse
forall a. Maybe a
Nothing
        Mem {} -> Maybe KernelUse -> ImpM GPUMem r op (Maybe KernelUse)
forall a. a -> ImpM GPUMem r op a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe KernelUse -> ImpM GPUMem r op (Maybe KernelUse))
-> Maybe KernelUse -> ImpM GPUMem r op (Maybe KernelUse)
forall a b. (a -> b) -> a -> b
$ KernelUse -> Maybe KernelUse
forall a. a -> Maybe a
Just (KernelUse -> Maybe KernelUse) -> KernelUse -> Maybe KernelUse
forall a b. (a -> b) -> a -> b
$ VName -> KernelUse
Imp.MemoryUse VName
var
        Prim PrimType
bt ->
          VTable GPUMem -> Exp -> ImpM GPUMem r op (Maybe KernelConstExp)
forall rep r op.
VTable GPUMem -> Exp -> ImpM rep r op (Maybe KernelConstExp)
isConstExp VTable GPUMem
vtable (VName -> PrimType -> Exp
Imp.var VName
var PrimType
bt) ImpM GPUMem r op (Maybe KernelConstExp)
-> (Maybe KernelConstExp -> ImpM GPUMem r op (Maybe KernelUse))
-> ImpM GPUMem r op (Maybe KernelUse)
forall a b.
ImpM GPUMem r op a
-> (a -> ImpM GPUMem r op b) -> ImpM GPUMem r op b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
            Just KernelConstExp
ce -> Maybe KernelUse -> ImpM GPUMem r op (Maybe KernelUse)
forall a. a -> ImpM GPUMem r op a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe KernelUse -> ImpM GPUMem r op (Maybe KernelUse))
-> Maybe KernelUse -> ImpM GPUMem r op (Maybe KernelUse)
forall a b. (a -> b) -> a -> b
$ KernelUse -> Maybe KernelUse
forall a. a -> Maybe a
Just (KernelUse -> Maybe KernelUse) -> KernelUse -> Maybe KernelUse
forall a b. (a -> b) -> a -> b
$ VName -> KernelConstExp -> KernelUse
Imp.ConstUse VName
var KernelConstExp
ce
            Maybe KernelConstExp
Nothing -> Maybe KernelUse -> ImpM GPUMem r op (Maybe KernelUse)
forall a. a -> ImpM GPUMem r op a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe KernelUse -> ImpM GPUMem r op (Maybe KernelUse))
-> Maybe KernelUse -> ImpM GPUMem r op (Maybe KernelUse)
forall a b. (a -> b) -> a -> b
$ KernelUse -> Maybe KernelUse
forall a. a -> Maybe a
Just (KernelUse -> Maybe KernelUse) -> KernelUse -> Maybe KernelUse
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> KernelUse
Imp.ScalarUse VName
var PrimType
bt

isConstExp ::
  VTable GPUMem ->
  Imp.Exp ->
  ImpM rep r op (Maybe Imp.KernelConstExp)
isConstExp :: forall rep r op.
VTable GPUMem -> Exp -> ImpM rep r op (Maybe KernelConstExp)
isConstExp VTable GPUMem
vtable Exp
size = do
  Maybe Name
fname <- ImpM rep r op (Maybe Name)
forall rep r op. ImpM rep r op (Maybe Name)
askFunction
  let onLeaf :: VName -> PrimType -> Maybe KernelConstExp
onLeaf VName
name PrimType
_ = VName -> Maybe KernelConstExp
lookupConstExp VName
name
      lookupConstExp :: VName -> Maybe KernelConstExp
lookupConstExp VName
name =
        Exp GPUMem -> Maybe KernelConstExp
constExp (Exp GPUMem -> Maybe KernelConstExp)
-> Maybe (Exp GPUMem) -> Maybe KernelConstExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VarEntry GPUMem -> Maybe (Exp GPUMem)
forall {rep}. VarEntry rep -> Maybe (Exp rep)
hasExp (VarEntry GPUMem -> Maybe (Exp GPUMem))
-> Maybe (VarEntry GPUMem) -> Maybe (Exp GPUMem)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> VTable GPUMem -> Maybe (VarEntry GPUMem)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
name VTable GPUMem
vtable
      constExp :: Exp GPUMem -> Maybe KernelConstExp
constExp (Op (Inner (SizeOp (GetSize Name
key SizeClass
c)))) =
        KernelConstExp -> Maybe KernelConstExp
forall a. a -> Maybe a
Just (KernelConstExp -> Maybe KernelConstExp)
-> KernelConstExp -> Maybe KernelConstExp
forall a b. (a -> b) -> a -> b
$ KernelConst -> PrimType -> KernelConstExp
forall v. v -> PrimType -> PrimExp v
LeafExp (Name -> SizeClass -> KernelConst
Imp.SizeConst (Maybe Name -> Name -> Name
keyWithEntryPoint Maybe Name
fname Name
key) SizeClass
c) PrimType
int32
      constExp (Op (Inner (SizeOp (GetSizeMax SizeClass
c)))) =
        KernelConstExp -> Maybe KernelConstExp
forall a. a -> Maybe a
Just (KernelConstExp -> Maybe KernelConstExp)
-> KernelConstExp -> Maybe KernelConstExp
forall a b. (a -> b) -> a -> b
$ KernelConst -> PrimType -> KernelConstExp
forall v. v -> PrimType -> PrimExp v
LeafExp (SizeClass -> KernelConst
Imp.SizeMaxConst SizeClass
c) PrimType
int32
      constExp Exp GPUMem
e = (VName -> Maybe KernelConstExp)
-> Exp GPUMem -> Maybe KernelConstExp
forall (m :: * -> *) rep v.
(MonadFail m, RepTypes rep) =>
(VName -> m (PrimExp v)) -> Exp rep -> m (PrimExp v)
primExpFromExp VName -> Maybe KernelConstExp
lookupConstExp Exp GPUMem
e
  Maybe KernelConstExp -> ImpM rep r op (Maybe KernelConstExp)
forall a. a -> ImpM rep r op a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe KernelConstExp -> ImpM rep r op (Maybe KernelConstExp))
-> Maybe KernelConstExp -> ImpM rep r op (Maybe KernelConstExp)
forall a b. (a -> b) -> a -> b
$ (VName -> PrimType -> Maybe KernelConstExp)
-> Exp -> Maybe KernelConstExp
forall (m :: * -> *) a b.
Monad m =>
(a -> PrimType -> m (PrimExp b)) -> PrimExp a -> m (PrimExp b)
replaceInPrimExpM VName -> PrimType -> Maybe KernelConstExp
onLeaf Exp
size
  where
    hasExp :: VarEntry rep -> Maybe (Exp rep)
hasExp (ArrayVar Maybe (Exp rep)
e ArrayEntry
_) = Maybe (Exp rep)
e
    hasExp (AccVar Maybe (Exp rep)
e (VName, Shape, [TypeBase Shape NoUniqueness])
_) = Maybe (Exp rep)
e
    hasExp (ScalarVar Maybe (Exp rep)
e ScalarEntry
_) = Maybe (Exp rep)
e
    hasExp (MemVar Maybe (Exp rep)
e MemEntry
_) = Maybe (Exp rep)
e

kernelInitialisationSimple ::
  Count NumBlocks SubExp ->
  Count BlockSize SubExp ->
  CallKernelGen (KernelConstants, InKernelGen ())
kernelInitialisationSimple :: Count NumBlocks SubExp
-> Count BlockSize SubExp
-> CallKernelGen (KernelConstants, InKernelGen ())
kernelInitialisationSimple Count NumBlocks SubExp
num_tblocks Count BlockSize SubExp
tblock_size = do
  VName
global_tid <- [Char] -> CallKernelGen VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"global_tid"
  VName
local_tid <- [Char] -> CallKernelGen VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"local_tid"
  VName
tblock_id <- [Char] -> CallKernelGen VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"block_id"
  VName
wave_size <- [Char] -> CallKernelGen VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"wave_size"
  VName
inner_tblock_size <- [Char] -> CallKernelGen VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"tblock_size"
  let num_tblocks' :: TPrimExp Int64 VName
num_tblocks' = SubExp -> TPrimExp Int64 VName
Imp.pe64 (Count NumBlocks SubExp -> SubExp
forall {k} (u :: k) e. Count u e -> e
unCount Count NumBlocks SubExp
num_tblocks)
      tblock_size' :: TPrimExp Int64 VName
tblock_size' = SubExp -> TPrimExp Int64 VName
Imp.pe64 (Count BlockSize SubExp -> SubExp
forall {k} (u :: k) e. Count u e -> e
unCount Count BlockSize SubExp
tblock_size)
      constants :: KernelConstants
constants =
        KernelConstants
          { kernelGlobalThreadIdVar :: TV Int32
kernelGlobalThreadIdVar = VName -> TV Int32
forall {k} (t :: k). MkTV t => VName -> TV t
mkTV VName
global_tid,
            kernelLocalThreadIdVar :: TV Int32
kernelLocalThreadIdVar = VName -> TV Int32
forall {k} (t :: k). MkTV t => VName -> TV t
mkTV VName
local_tid,
            kernelBlockIdVar :: TV Int32
kernelBlockIdVar = VName -> TV Int32
forall {k} (t :: k). MkTV t => VName -> TV t
mkTV VName
tblock_id,
            kernelNumBlocksCount :: Count NumBlocks SubExp
kernelNumBlocksCount = Count NumBlocks SubExp
num_tblocks,
            kernelBlockSizeCount :: Count BlockSize SubExp
kernelBlockSizeCount = Count BlockSize SubExp
tblock_size,
            kernelNumBlocks :: TPrimExp Int64 VName
kernelNumBlocks = TPrimExp Int64 VName
num_tblocks',
            kernelBlockSize :: TPrimExp Int64 VName
kernelBlockSize = TPrimExp Int64 VName
tblock_size',
            kernelNumThreads :: TExp Int32
kernelNumThreads = TPrimExp Int64 VName -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName
tblock_size' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
num_tblocks'),
            kernelWaveSize :: TExp Int32
kernelWaveSize = VName -> TExp Int32
forall a. a -> TPrimExp Int32 a
Imp.le32 VName
wave_size,
            kernelLocalIdMap :: Map [SubExp] [TExp Int32]
kernelLocalIdMap = Map [SubExp] [TExp Int32]
forall a. Monoid a => a
mempty,
            kernelChunkItersMap :: Map [SubExp] (TExp Int32)
kernelChunkItersMap = Map [SubExp] (TExp Int32)
forall a. Monoid a => a
mempty
          }

  let set_constants :: InKernelGen ()
set_constants = do
        VName -> PrimType -> InKernelGen ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
local_tid PrimType
int32
        VName -> PrimType -> InKernelGen ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
inner_tblock_size PrimType
int32
        VName -> PrimType -> InKernelGen ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
wave_size PrimType
int32
        VName -> PrimType -> InKernelGen ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
tblock_id PrimType
int32

        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (VName -> Int -> KernelOp
Imp.GetLocalId VName
local_tid Int
0)
        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (VName -> Int -> KernelOp
Imp.GetLocalSize VName
inner_tblock_size Int
0)
        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (VName -> KernelOp
Imp.GetLockstepWidth VName
wave_size)
        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (VName -> Int -> KernelOp
Imp.GetBlockId VName
tblock_id Int
0)
        VName -> TExp Int32 -> InKernelGen ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
global_tid (TExp Int32 -> InKernelGen ()) -> TExp Int32 -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> TExp Int32
forall a. a -> TPrimExp Int32 a
le32 VName
tblock_id TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* VName -> TExp Int32
forall a. a -> TPrimExp Int32 a
le32 VName
inner_tblock_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
forall a. a -> TPrimExp Int32 a
le32 VName
local_tid

  (KernelConstants, InKernelGen ())
-> CallKernelGen (KernelConstants, InKernelGen ())
forall a. a -> ImpM GPUMem HostEnv HostOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (KernelConstants
constants, InKernelGen ()
set_constants)

isActive :: [(VName, SubExp)] -> Imp.TExp Bool
isActive :: [(VName, SubExp)] -> TExp Bool
isActive [(VName, SubExp)]
limit = case [TExp Bool]
actives of
  [] -> TExp Bool
forall v. TPrimExp Bool v
true
  TExp Bool
x : [TExp Bool]
xs -> (TExp Bool -> TExp Bool -> TExp Bool)
-> TExp Bool -> [TExp Bool] -> TExp Bool
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
(.&&.) TExp Bool
x [TExp Bool]
xs
  where
    ([VName]
is, [SubExp]
ws) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip [(VName, SubExp)]
limit
    actives :: [TExp Bool]
actives = (VName -> TPrimExp Int64 VName -> TExp Bool)
-> [VName] -> [TPrimExp Int64 VName] -> [TExp Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith VName -> TPrimExp Int64 VName -> TExp Bool
forall {v}. Eq v => v -> TPrimExp Int64 v -> TPrimExp Bool v
active [VName]
is ([TPrimExp Int64 VName] -> [TExp Bool])
-> [TPrimExp Int64 VName] -> [TExp Bool]
forall a b. (a -> b) -> a -> b
$ (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
ws
    active :: v -> TPrimExp Int64 v -> TPrimExp Bool v
active v
i = (v -> TPrimExp Int64 v
forall a. a -> TPrimExp Int64 a
Imp.le64 v
i .<.)

-- | Change every memory block to be in the global address space,
-- except those who are in the shared memory space.  This only affects
-- generated code - we still need to make sure that the memory is
-- actually present on the device (and declared as variables in the
-- kernel).
makeAllMemoryGlobal :: CallKernelGen a -> CallKernelGen a
makeAllMemoryGlobal :: forall a. CallKernelGen a -> CallKernelGen a
makeAllMemoryGlobal =
  Space
-> ImpM GPUMem HostEnv HostOp a -> ImpM GPUMem HostEnv HostOp a
forall rep r op a. Space -> ImpM rep r op a -> ImpM rep r op a
localDefaultSpace ([Char] -> Space
Imp.Space [Char]
"global") (ImpM GPUMem HostEnv HostOp a -> ImpM GPUMem HostEnv HostOp a)
-> (ImpM GPUMem HostEnv HostOp a -> ImpM GPUMem HostEnv HostOp a)
-> ImpM GPUMem HostEnv HostOp a
-> ImpM GPUMem HostEnv HostOp a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VTable GPUMem -> VTable GPUMem)
-> ImpM GPUMem HostEnv HostOp a -> ImpM GPUMem HostEnv HostOp a
forall rep r op a.
(VTable rep -> VTable rep) -> ImpM rep r op a -> ImpM rep r op a
localVTable ((VarEntry GPUMem -> VarEntry GPUMem)
-> VTable GPUMem -> VTable GPUMem
forall a b k. (a -> b) -> Map k a -> Map k b
M.map VarEntry GPUMem -> VarEntry GPUMem
forall {rep}. VarEntry rep -> VarEntry rep
globalMemory)
  where
    globalMemory :: VarEntry rep -> VarEntry rep
globalMemory (MemVar Maybe (Exp rep)
_ MemEntry
entry)
      | MemEntry -> Space
entryMemSpace MemEntry
entry Space -> Space -> Bool
forall a. Eq a => a -> a -> Bool
/= [Char] -> Space
Space [Char]
"shared" =
          Maybe (Exp rep) -> MemEntry -> VarEntry rep
forall rep. Maybe (Exp rep) -> MemEntry -> VarEntry rep
MemVar Maybe (Exp rep)
forall a. Maybe a
Nothing MemEntry
entry {entryMemSpace = Imp.Space "global"}
    globalMemory VarEntry rep
entry =
      VarEntry rep
entry

simpleKernelBlocks ::
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  CallKernelGen (Imp.TExp Int32, Count NumBlocks SubExp, Count BlockSize SubExp)
simpleKernelBlocks :: TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> CallKernelGen
     (TExp Int32, Count NumBlocks SubExp, Count BlockSize SubExp)
simpleKernelBlocks TPrimExp Int64 VName
max_num_tblocks TPrimExp Int64 VName
kernel_size = do
  TV Int64
tblock_size <- [Char] -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall {k} (t :: k) rep r op.
MkTV t =>
[Char] -> ImpM rep r op (TV t)
dPrim [Char]
"tblock_size"
  Maybe Name
fname <- ImpM GPUMem HostEnv HostOp (Maybe Name)
forall rep r op. ImpM rep r op (Maybe Name)
askFunction
  let tblock_size_key :: Name
tblock_size_key = Maybe Name -> Name -> Name
keyWithEntryPoint Maybe Name
fname (Name -> Name) -> Name -> Name
forall a b. (a -> b) -> a -> b
$ [Char] -> Name
nameFromString ([Char] -> Name) -> [Char] -> Name
forall a b. (a -> b) -> a -> b
$ VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString (VName -> [Char]) -> VName -> [Char]
forall a b. (a -> b) -> a -> b
$ TV Int64 -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV Int64
tblock_size
  HostOp -> CallKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> Name -> SizeClass -> HostOp
Imp.GetSize (TV Int64 -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV Int64
tblock_size) Name
tblock_size_key SizeClass
Imp.SizeThreadBlock
  TPrimExp Int64 VName
virt_num_tblocks <- [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"virt_num_tblocks" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
kernel_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TV Int64 -> TPrimExp Int64 VName
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
tblock_size
  TV Int64
num_tblocks <- [Char]
-> TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"num_tblocks" (TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
virt_num_tblocks TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
`sMin64` TPrimExp Int64 VName
max_num_tblocks
  (TExp Int32, Count NumBlocks SubExp, Count BlockSize SubExp)
-> CallKernelGen
     (TExp Int32, Count NumBlocks SubExp, Count BlockSize SubExp)
forall a. a -> ImpM GPUMem HostEnv HostOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (TPrimExp Int64 VName -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
virt_num_tblocks, SubExp -> Count NumBlocks SubExp
forall {k} (u :: k) e. e -> Count u e
Count (SubExp -> Count NumBlocks SubExp)
-> SubExp -> Count NumBlocks SubExp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV Int64
num_tblocks, SubExp -> Count BlockSize SubExp
forall {k} (u :: k) e. e -> Count u e
Count (SubExp -> Count BlockSize SubExp)
-> SubExp -> Count BlockSize SubExp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV Int64
tblock_size)

-- | Produce the 'KernelConstants' for a simple one-dimensional kernel
-- of the given flat size, along with a wrapper action that declares
-- and initialises the thread/block identifiers inside the kernel and
-- invokes the provided body with the (virtualised) global thread id.
simpleKernelConstants ::
  Imp.TExp Int64 ->
  String ->
  CallKernelGen
    ( (Imp.TExp Int64 -> InKernelGen ()) -> InKernelGen (),
      KernelConstants
    )
simpleKernelConstants kernel_size desc = do
  -- For performance reasons, codegen assumes that the thread count is
  -- never more than will fit in an i32.  This means we need to cap
  -- the number of blocks here.  The cap is set much higher than any
  -- GPU will possibly need.  Feel free to come back and laugh at me
  -- in the future.
  let max_num_tblocks = 1024 * 1024
  -- Fresh names for the per-thread identifier variables; they are
  -- declared later, inside the kernel body, by the wrapper below.
  thread_gtid <- newVName $ desc ++ "_gtid"
  thread_ltid <- newVName $ desc ++ "_ltid"
  tblock_id <- newVName $ desc ++ "_gid"
  inner_tblock_size <- newVName "tblock_size"
  (virt_num_tblocks, num_tblocks, tblock_size) <-
    simpleKernelBlocks max_num_tblocks kernel_size
  let tblock_size' = Imp.pe64 $ unCount tblock_size
      num_tblocks' = Imp.pe64 $ unCount num_tblocks

      constants =
        KernelConstants
          { kernelGlobalThreadIdVar = mkTV thread_gtid,
            kernelLocalThreadIdVar = mkTV thread_ltid,
            kernelBlockIdVar = mkTV tblock_id,
            kernelNumBlocksCount = num_tblocks,
            kernelBlockSizeCount = tblock_size,
            kernelNumBlocks = num_tblocks',
            kernelBlockSize = tblock_size',
            kernelNumThreads = sExt32 (tblock_size' * num_tblocks'),
            kernelWaveSize = 0,
            kernelLocalIdMap = mempty,
            kernelChunkItersMap = mempty
          }

      -- Declare the identifier variables, read them from the
      -- hardware, and run the body once per virtual block with the
      -- corresponding global thread id.
      wrapKernel :: (Imp.TExp Int64 -> InKernelGen ()) -> InKernelGen ()
      wrapKernel m = do
        dPrim_ thread_ltid int32
        dPrim_ inner_tblock_size int32
        dPrim_ tblock_id int32
        sOp (Imp.GetLocalId thread_ltid 0)
        sOp (Imp.GetLocalSize inner_tblock_size 0)
        sOp (Imp.GetBlockId tblock_id 0)
        dPrimV_ thread_gtid $
          le32 tblock_id * le32 inner_tblock_size + le32 thread_ltid
        virtualiseBlocks SegVirt virt_num_tblocks $ \virt_tblock_id -> do
          -- The global thread id is computed from the *virtual* block
          -- id, in 64-bit arithmetic to avoid overflow.
          global_tid <-
            dPrimVE "global_tid" $
              sExt64 virt_tblock_id * sExt64 (le32 inner_tblock_size)
                + sExt64 (kernelLocalThreadId constants)
          m global_tid

  pure (wrapKernel, constants)

-- | For many kernels, we may not have enough physical blocks to cover
-- the logical iteration space.  Some blocks thus have to perform
-- double duty; we put an outer loop to accomplish this.  The
-- advantage over just launching a bazillion threads is that the cost
-- of memory expansion should be proportional to the number of
-- *physical* threads (hardware parallelism), not the amount of
-- application parallelism.
virtualiseBlocks ::
  SegVirt ->
  Imp.TExp Int32 ->
  (Imp.TExp Int32 -> InKernelGen ()) ->
  InKernelGen ()
virtualiseBlocks SegVirt required_blocks m = do
  constants <- kernelConstants <$> askEnv
  phys_tblock_id <- dPrim "phys_tblock_id"
  sOp $ Imp.GetBlockId (tvVar phys_tblock_id) 0
  -- How many virtual blocks this physical block must service.
  iterations <-
    dPrimVE "iterations" $
      (required_blocks - tvExp phys_tblock_id)
        `divUp` sExt32 (kernelNumBlocks constants)

  sFor "i" iterations $ \i -> do
    -- Virtual block ids are strided across the physical blocks.
    m . tvExp
      =<< dPrimV
        "virt_tblock_id"
        (tvExp phys_tblock_id + i * sExt32 (kernelNumBlocks constants))
    -- Make sure the virtual block is actually done before we let
    -- another virtual block have its way with it.
    sOp $ Imp.ErrorSync Imp.FenceGlobal
virtualiseBlocks _ _ m =
  -- No virtualisation requested: run the body once with the physical
  -- block id.
  m . tvExp . kernelBlockIdVar . kernelConstants =<< askEnv

-- | Various extra configuration of the kernel being generated.
data KernelAttrs = KernelAttrs
  { -- | Can this kernel execute correctly even if previous kernels failed?
    kAttrFailureTolerant :: Bool,
    -- | Does whatever launch this kernel check for shared memory capacity itself?
    kAttrCheckSharedMemory :: Bool,
    -- | Number of blocks.
    kAttrNumBlocks :: Count NumBlocks SubExp,
    -- | Block size.
    kAttrBlockSize :: Count BlockSize SubExp,
    -- | Variables that are specially in scope inside the kernel.
    -- Operationally, these will be available at kernel compile time
    -- (which happens at run-time, with access to machine-specific
    -- information).
    kAttrConstExps :: M.Map VName Imp.KernelConstExp
  }

-- | The default kernel attributes: not failure-tolerant, shared
-- memory checked by the kernel itself, no compile-time constants.
defKernelAttrs ::
  Count NumBlocks SubExp ->
  Count BlockSize SubExp ->
  KernelAttrs
defKernelAttrs num_tblocks tblock_size =
  KernelAttrs
    { kAttrFailureTolerant = False,
      kAttrCheckSharedMemory = True,
      kAttrNumBlocks = num_tblocks,
      kAttrBlockSize = tblock_size,
      kAttrConstExps = mempty
    }

-- | Retrieve a size of the given size class and put it in a variable
-- with the given name.
getSize :: String -> SizeClass -> CallKernelGen (TV Int64)
getSize desc size_class = do
  v <- dPrim desc
  fname <- askFunction
  -- The size key is qualified by the entry point so that different
  -- entry points can be tuned independently.
  let v_key =
        keyWithEntryPoint fname . nameFromString . prettyString $ tvVar v
  sOp $ Imp.GetSize (tvVar v) v_key size_class
  pure v

-- | Compute kernel attributes from 'SegLevel'; including synthesising
-- block-size and thread count if no grid is provided.
lvlKernelAttrs :: SegLevel -> CallKernelGen KernelAttrs
lvlKernelAttrs lvl =
  case lvl of
    SegThread _ Nothing -> mkGrid
    SegThread _ (Just (KernelGrid num_tblocks tblock_size)) ->
      pure $ defKernelAttrs num_tblocks tblock_size
    SegBlock _ Nothing -> mkGrid
    SegBlock _ (Just (KernelGrid num_tblocks tblock_size)) ->
      pure $ defKernelAttrs num_tblocks tblock_size
    SegThreadInBlock {} ->
      -- Intra-block levels never give rise to a kernel of their own.
      error "lvlKernelAttrs: SegThreadInBlock"
  where
    -- Synthesise a grid from the tunable size classes.
    mkGrid = do
      tblock_size <- getSize "tblock_size" Imp.SizeThreadBlock
      num_tblocks <- getSize "num_tblocks" Imp.SizeGrid
      pure $
        defKernelAttrs
          (Count $ tvSize num_tblocks)
          (Count $ tvSize tblock_size)

-- | Generate a kernel with the given operations and attributes.  The
-- 'VName' is bound inside the kernel to the flat thread index
-- computed by the provided function from the kernel constants.
sKernel ::
  Operations GPUMem KernelEnv Imp.KernelOp ->
  (KernelConstants -> Imp.TExp Int64) ->
  String ->
  VName ->
  KernelAttrs ->
  InKernelGen () ->
  CallKernelGen ()
sKernel ops flatf name v attrs f = do
  (constants, set_constants) <-
    kernelInitialisationSimple (kAttrNumBlocks attrs) (kAttrBlockSize attrs)
  -- Make the kernel name unique by appending the tag of 'v'.
  name' <- nameForFun $ name ++ "_" ++ show (baseTag v)
  sKernelOp attrs constants ops name' $ do
    set_constants
    dPrimV_ v $ flatf constants
    f

-- | Generate a thread-level kernel: the bound variable is the global
-- thread id (widened to 64 bits), and thread-level operations are
-- used for compiling the body.
sKernelThread ::
  String ->
  VName ->
  KernelAttrs ->
  InKernelGen () ->
  CallKernelGen ()
sKernelThread =
  sKernel threadOperations (sExt64 . kernelGlobalThreadId)

-- | Compile the given kernel body in a sub-computation and emit a
-- 'Imp.CallKernel' host operation for it, recording the free
-- variables it uses and any compile-time constant expressions from
-- the attributes.
sKernelOp ::
  KernelAttrs ->
  KernelConstants ->
  Operations GPUMem KernelEnv Imp.KernelOp ->
  Name ->
  InKernelGen () ->
  CallKernelGen ()
sKernelOp attrs constants ops name m = do
  HostEnv atomics _ locks <- askEnv
  body <-
    makeAllMemoryGlobal $
      subImpM_ (KernelEnv atomics constants locks) ops m
  uses <- computeKernelUses body $ M.keys $ kAttrConstExps attrs
  tblock_size <- onBlockSize $ kernelBlockSize constants
  emit . Imp.Op . Imp.CallKernel $
    Imp.Kernel
      { Imp.kernelBody = body,
        Imp.kernelUses =
          uses <> map constToUse (M.toList (kAttrConstExps attrs)),
        Imp.kernelNumBlocks = [untyped $ kernelNumBlocks constants],
        Imp.kernelBlockSize = [tblock_size],
        Imp.kernelName = name,
        Imp.kernelFailureTolerant = kAttrFailureTolerant attrs,
        Imp.kernelCheckSharedMemory = kAttrCheckSharedMemory attrs
      }
  where
    -- Figure out if this expression actually corresponds to a
    -- KernelConst.
    onBlockSize e = do
      vtable <- getVTable
      x <- isConstExp vtable $ untyped e
      pure $ case x of
        Just kc -> Right kc
        _ -> Left $ untyped e

    constToUse (v, e) = Imp.ConstUse v e

-- | Emit a kernel with the default attributes derived from the
-- kernel constants, except with failure tolerance set as given.
sKernelFailureTolerant ::
  Bool ->
  Operations GPUMem KernelEnv Imp.KernelOp ->
  KernelConstants ->
  Name ->
  InKernelGen () ->
  CallKernelGen ()
sKernelFailureTolerant tol ops constants name m =
  sKernelOp attrs constants ops name m
  where
    attrs =
      ( defKernelAttrs
          (kernelNumBlocksCount constants)
          (kernelBlockSizeCount constants)
      )
        { kAttrFailureTolerant = tol
        }

-- | The operations used when compiling thread-level code inside a
-- kernel: thread-level op/expression compilers, LMAD-based copying,
-- and shared-memory allocation.
threadOperations :: Operations GPUMem KernelEnv Imp.KernelOp
threadOperations =
  (defaultOperations compileThreadOp)
    { opsCopyCompiler = lmadCopy,
      opsExpCompiler = compileThreadExp,
      opsStmsCompiler = \_ -> defCompileStms mempty,
      opsAllocCompilers =
        M.fromList [(Space "shared", allocLocal)]
    }

-- | Perform a Replicate with a kernel: launch a (failure-tolerant)
-- kernel where each virtualised thread writes the value @se@ to one
-- element of the destination array @arr@.  If @se@ is itself an
-- array, each thread copies the entire row, indexed by the leading
-- dimensions only.
sReplicateKernel :: VName -> SubExp -> CallKernelGen ()
sReplicateKernel arr se = do
  t <- subExpType se
  -- The leading dimensions of the destination are those not accounted
  -- for by the rank of the replicated value itself.
  ds <- dropLast (arrayRank t) . arrayDims <$> lookupType arr

  let dims = map pe64 $ ds ++ arrayDims t
  n <- dPrimVE "replicate_n" $ product $ map sExt64 dims
  (virtualise, constants) <- simpleKernelConstants n "replicate"

  fname <- askFunction
  let name =
        keyWithEntryPoint fname $
          nameFromString $
            -- The tag of the thread-id variable makes the kernel name unique.
            "replicate_"
              ++ show (baseTag $ tvVar $ kernelGlobalThreadIdVar constants)

  sKernelFailureTolerant True threadOperations constants name $
    virtualise $ \gtid -> do
      -- Unflatten the global thread id into a multi-dimensional index.
      is' <- dIndexSpace' "rep_i" dims gtid
      sWhen (gtid .<. n) $
        copyDWIMFix arr is' se $
          -- Only the indices beyond the destination's leading
          -- dimensions index into the source value.
          drop (length ds) is'

-- | The name of the builtin replicate function for a given element
-- type, e.g. @replicate_i32@.
replicateName :: PrimType -> String
replicateName bt = "replicate_" ++ prettyString bt

-- | Ensure that a builtin device-memory replicate function exists for
-- the given primitive type, generating it on first use, and return
-- its name.  The generated function takes a device memory block, an
-- element count, and the fill value as parameters.
replicateForType :: PrimType -> CallKernelGen Name
replicateForType bt = do
  let fname = nameFromString $ "builtin#" <> replicateName bt

  -- Only generate the function once per program.
  exists <- hasFunction fname
  unless exists $ do
    mem <- newVName "mem"
    num_elems <- newVName "num_elems"
    val <- newVName "val"

    let params =
          [ Imp.MemParam mem (Space "device"),
            Imp.ScalarParam num_elems int64,
            Imp.ScalarParam val bt
          ]
        shape = Shape [Var num_elems]
    function fname [] params $ do
      -- View the raw memory block as a one-dimensional direct array.
      arr <-
        sArray "arr" bt shape mem $
          LMAD.iota 0 $
            map pe64 $
              shapeDims shape
      sReplicateKernel arr $ Var val

  pure fname

-- | Detect whether a replicate is morally a fill/memset: a primitive
-- value written to a directly laid-out (row-major, offset 0) array.
-- If so, return an action that calls the shared builtin replicate
-- function instead of generating a fresh kernel; otherwise return
-- 'Nothing'.
replicateIsFill :: VName -> SubExp -> CallKernelGen (Maybe (CallKernelGen ()))
replicateIsFill arr v = do
  ArrayEntry (MemLoc arr_mem arr_shape arr_lmad) _ <- lookupArray arr
  v_t <- subExpType v
  case v_t of
    Prim v_t'
      | LMAD.isDirect arr_lmad ->
          pure $ Just $ do
            fname <- replicateForType v_t'
            emit $
              Imp.Call
                []
                fname
                [ Imp.MemArg arr_mem,
                  -- Total number of elements to fill.
                  Imp.ExpArg $ untyped $ product $ map pe64 arr_shape,
                  Imp.ExpArg $ toExp' v_t' v
                ]
    _ -> pure Nothing

-- | Perform a Replicate with a kernel.
sReplicate :: VName -> SubExp -> CallKernelGen ()
sReplicate arr se = do
  -- If the replicate is of a particularly common and simple form
  -- (morally a memset()/fill), then we use a common function.
  is_fill <- replicateIsFill arr se

  case is_fill of
    Just m -> m
    Nothing -> sReplicateKernel arr se

-- | Perform an Iota with a kernel: each virtualised thread computes
-- and writes one element @x + gtid * s@ (with wrap-around overflow
-- semantics) of the destination array, at integer type @et@.
sIotaKernel ::
  VName ->
  Imp.TExp Int64 ->
  Imp.Exp ->
  Imp.Exp ->
  IntType ->
  CallKernelGen ()
sIotaKernel arr n x s et = do
  destloc <- entryArrayLoc <$> lookupArray arr
  (virtualise, constants) <- simpleKernelConstants n "iota"

  fname <- askFunction
  let name =
        keyWithEntryPoint fname $
          nameFromString $
            -- Element type and the thread-id variable tag make the
            -- kernel name unique.
            "iota_"
              ++ prettyString et
              ++ "_"
              ++ show (baseTag $ tvVar $ kernelGlobalThreadIdVar constants)

  sKernelFailureTolerant True threadOperations constants name $
    virtualise $ \gtid ->
      sWhen (gtid .<. n) $ do
        (destmem, destspace, destidx) <- fullyIndexArray' destloc [gtid]

        emit $
          Imp.Write destmem destidx (IntType et) destspace Imp.Nonvolatile $
            -- x + sExt et gtid * s, with wrapping arithmetic.
            BinOpExp
              (Add et OverflowWrap)
              (BinOpExp (Mul et OverflowWrap) (Imp.sExt et $ untyped gtid) s)
              x

-- | The name of the builtin iota function for a given integer type,
-- e.g. @iota_i32@.
iotaName :: IntType -> String
iotaName bt = "iota_" ++ prettyString bt

-- | Ensure that a builtin device-memory iota function exists for the
-- given integer type, generating it on first use, and return its
-- name.  The generated function takes a device memory block, an
-- element count, a start value, and a stride as parameters.
iotaForType :: IntType -> CallKernelGen Name
iotaForType bt = do
  let fname = nameFromString $ "builtin#" <> iotaName bt

  -- Only generate the function once per program.
  exists <- hasFunction fname
  unless exists $ do
    mem <- newVName "mem"
    n <- newVName "n"
    x <- newVName "x"
    s <- newVName "s"

    let params =
          [ Imp.MemParam mem (Space "device"),
            Imp.ScalarParam n int64,
            Imp.ScalarParam x $ IntType bt,
            Imp.ScalarParam s $ IntType bt
          ]
        shape = Shape [Var n]
        n' = Imp.le64 n
        x' = Imp.var x $ IntType bt
        s' = Imp.var s $ IntType bt

    function fname [] params $ do
      -- View the raw memory block as a one-dimensional direct array.
      arr <-
        sArray "arr" (IntType bt) shape mem $
          LMAD.iota 0 (map pe64 (shapeDims shape))
      sIotaKernel arr n' x' s' bt

  pure fname

-- | Perform an Iota with a kernel.  If the destination array is
-- directly laid out, call the shared builtin iota function for the
-- element type; otherwise generate a dedicated kernel.
sIota ::
  VName ->
  Imp.TExp Int64 ->
  Imp.Exp ->
  Imp.Exp ->
  IntType ->
  CallKernelGen ()
sIota arr n x s et = do
  ArrayEntry (MemLoc arr_mem _ arr_lmad) _ <- lookupArray arr
  if LMAD.isDirect arr_lmad
    then do
      fname <- iotaForType et
      emit $
        Imp.Call
          []
          fname
          [ Imp.MemArg arr_mem,
            Imp.ExpArg $ untyped n,
            Imp.ExpArg x,
            Imp.ExpArg s
          ]
    else sIotaKernel arr n x s et

-- | Compile the result of a single thread in a kernel.
--
-- * 'Returns': copy the value into the pattern element at the
--   position given by the segmented space indices.
-- * 'WriteReturns': a scatter; each destination write is guarded by a
--   bounds check on the slice.
-- * 'TileReturns' is a compiler bug if it reaches this point, and
--   'RegTileReturns' is not yet handled.
compileThreadResult ::
  SegSpace ->
  PatElem LetDecMem ->
  KernelResult ->
  InKernelGen ()
compileThreadResult _ _ RegTileReturns {} =
  compilerLimitationS "compileThreadResult: RegTileReturns not yet handled."
compileThreadResult space pe (Returns _ _ what) = do
  let is = map (Imp.le64 . fst) $ unSegSpace space
  copyDWIMFix (patElemName pe) is what []
compileThreadResult _ pe (WriteReturns _ arr dests) = do
  arr_t <- lookupType arr
  let rws' = map pe64 $ arrayDims arr_t
  forM_ dests $ \(slice, e) -> do
    let slice' = fmap pe64 slice
        -- Only perform the write if the slice is in bounds.
        write = inBounds slice' rws'
    sWhen write $
      copyDWIM (patElemName pe) (unSlice slice') e []
compileThreadResult _ _ TileReturns {} =
  compilerBugS "compileThreadResult: TileReturns unhandled."