{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.CodeGen.ImpGen.GPU.Base
  ( KernelConstants (..),
    threadOperations,
    keyWithEntryPoint,
    CallKernelGen,
    InKernelGen,
    Locks (..),
    HostEnv (..),
    Target (..),
    KernelEnv (..),
    groupReduce,
    groupScan,
    groupLoop,
    isActive,
    sKernel,
    sKernelThread,
    KernelAttrs (..),
    defKernelAttrs,
    allocLocal,
    kernelAlloc,
    compileThreadResult,
    virtualiseGroups,
    kernelLoop,
    groupCoverSpace,
    fenceForArrays,
    updateAcc,

    -- * Host-level bulk operations
    sReplicate,
    sIota,
    sRotateKernel,
    sCopy,

    -- * Atomics
    AtomicBinOp,
    atomicUpdateLocking,
    Locking (..),
    AtomicUpdate (..),
    DoAtomicUpdate,
  )
where

import Control.Monad.Except
import Data.List (foldl')
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.MonadFreshNames
import Futhark.Transform.Rename
import Futhark.Util (dropLast, nubOrd, splitFromEnd)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)

-- | The low-level API that kernels are ultimately generated for.
-- The bulk of the generated kernel code is identical for both, but in
-- a few places target-specific code is emitted.
data Target
  = CUDA
  | OpenCL

-- | Information about the locks available for accumulators.
data Locks = Locks
  { -- | Name of the array used as the lock array.
    locksArray :: VName,
    -- | How many locks 'locksArray' holds; lock indices are reduced
    -- modulo this number when an accumulator update needs locking.
    locksCount :: Int
  }

-- | Environment available during host-level code generation
-- ('CallKernelGen').
data HostEnv = HostEnv
  { -- | Atomic-operation support ('AtomicBinOp') for this target.
    hostAtomics :: AtomicBinOp,
    -- | The low-level API we are generating code for.
    hostTarget :: Target,
    -- | Locks available for accumulators, keyed by name.
    hostLocks :: M.Map VName Locks
  }

-- | Environment available during in-kernel code generation
-- ('InKernelGen').
data KernelEnv = KernelEnv
  { -- | Atomic-operation support; handed to 'atomicUpdateLocking'
    -- when compiling accumulator updates.
    kernelAtomics :: AtomicBinOp,
    -- | Thread/group IDs and sizes of the kernel being generated.
    kernelConstants :: KernelConstants,
    -- | Locks available for accumulators, looked up by the name that
    -- 'lookupAcc' reports for an accumulator.
    kernelLocks :: M.Map VName Locks
  }

-- | Code generation for host code: the environment is a 'HostEnv' and
-- the emitted operations are 'Imp.HostOp's.
type CallKernelGen =
  ImpM GPUMem HostEnv Imp.HostOp

-- | Code generation for code running inside a kernel: the environment
-- is a 'KernelEnv' and the emitted operations are 'Imp.KernelOp's.
type InKernelGen =
  ImpM GPUMem KernelEnv Imp.KernelOp

-- | Per-kernel values (thread/group IDs and sizes) that in-kernel code
-- can query through 'kernelConstants'.
data KernelConstants = KernelConstants
  { -- | Global thread ID.
    kernelGlobalThreadId :: Imp.TExp Int32,
    -- | Thread ID within the workgroup.
    kernelLocalThreadId :: Imp.TExp Int32,
    -- | Workgroup ID.
    kernelGroupId :: Imp.TExp Int32,
    -- | Presumably the variable bound to 'kernelGlobalThreadId'.
    kernelGlobalThreadIdVar :: VName,
    -- | Presumably the variable bound to 'kernelLocalThreadId'.
    kernelLocalThreadIdVar :: VName,
    -- | Presumably the variable bound to 'kernelGroupId'.
    kernelGroupIdVar :: VName,
    -- | Number of workgroups, as a 'Count'.
    kernelNumGroupsCount :: Count NumGroups SubExp,
    -- | Workgroup size, as a 'Count'.
    kernelGroupSizeCount :: Count GroupSize SubExp,
    -- | Number of workgroups, as an expression.
    kernelNumGroups :: Imp.TExp Int64,
    -- | Workgroup size, as an expression.
    kernelGroupSize :: Imp.TExp Int64,
    -- | Total number of threads (presumably groups times group size).
    kernelNumThreads :: Imp.TExp Int32,
    -- | Wave size (presumably the hardware SIMD/warp width — confirm).
    kernelWaveSize :: Imp.TExp Int32,
    -- | A mapping from dimensions of nested SegOps to already
    -- computed local thread IDs.  Only valid in non-virtualised case.
    kernelLocalIdMap :: M.Map [SubExp] [Imp.TExp Int32],
    -- | Mapping from dimensions of nested SegOps to how many
    -- iterations the virtualisation loop needs.
    kernelChunkItersMap :: M.Map [SubExp] (Imp.TExp Int32)
  }

-- | Qualify a key with an optional entry point name.
--
-- @keyWithEntryPoint (Just f) k@ is the name @f.k@, while
-- @keyWithEntryPoint Nothing k@ is just @k@.
keyWithEntryPoint :: Maybe Name -> Name -> Name
keyWithEntryPoint fname key =
  nameFromString $ prefix ++ nameToString key
  where
    prefix = maybe "" (\f -> nameToString f ++ ".") fname

-- | An 'AllocCompiler' that allocates a memory block of the given size
-- in local memory by emitting an 'Imp.LocalAlloc' kernel operation.
allocLocal :: AllocCompiler GPUMem r Imp.KernelOp
allocLocal mem size =
  sOp $ Imp.LocalAlloc mem size

-- | Compile a memory allocation that occurs inside a kernel.
--
-- * Scalar-space allocations emit no code.
-- * Allocations in @"local"@ memory go through 'allocLocal'.
-- * Allocations in any other space are reported as a compiler
--   limitation.
-- * A pattern that does not bind exactly one memory block is an
--   internal error.
kernelAlloc ::
  Pat LetDecMem ->
  SubExp ->
  Space ->
  InKernelGen ()
kernelAlloc (Pat [_]) _ ScalarSpace {} =
  -- Handled by the declaration of the memory block, which is then
  -- translated to an actual scalar variable during C code generation.
  pure ()
kernelAlloc (Pat [mem]) size (Space "local") =
  allocLocal (patElemName mem) $ Imp.bytes $ pe64 size
kernelAlloc (Pat [mem]) _ _ =
  compilerLimitationS $
    "Cannot allocate memory block " ++ prettyString mem ++ " in kernel."
kernelAlloc dest _ _ =
  error $ "Invalid target for in-kernel allocation: " ++ show dest

-- | Compile an @UpdateAcc@ inside a kernel: store the values @vs@ into
-- accumulator @acc@ at index @is@.
--
-- Nothing happens when the index is out of bounds.  If the accumulator
-- has no combining operator, the values simply overwrite the array
-- elements.  Otherwise the operator is applied using an atomic
-- primitive, a CAS loop or locking, depending on what
-- 'atomicUpdateLocking' selects.  Locking requires an entry for the
-- accumulator in 'kernelLocks'; its absence is an internal error.
updateAcc :: VName -> [SubExp] -> [SubExp] -> InKernelGen ()
updateAcc acc is vs = sComment "UpdateAcc" $ do
  -- See the ImpGen implementation of UpdateAcc for general notes.
  let is' = map pe64 is
  (c, space, arrs, dims, op) <- lookupAcc acc is'
  sWhen (inBounds (Slice (map DimFix is')) dims) $
    case op of
      Nothing ->
        -- No operator: plain overwrite of each array element.
        forM_ (zip arrs vs) $ \(arr, v) -> copyDWIMFix arr is' v []
      Just lam -> do
        dLParams $ lambdaParams lam
        -- The second half of the operator's parameters receives the
        -- new values; the first half is filled in by the atomic update.
        let (_x_params, y_params) =
              splitAt (length vs) $ map paramName $ lambdaParams lam
        forM_ (zip y_params vs) $ \(yp, v) -> copyDWIM yp [] v []
        atomics <- kernelAtomics <$> askEnv
        case atomicUpdateLocking atomics lam of
          AtomicPrim f -> f space arrs is'
          AtomicCAS f -> f space arrs is'
          AtomicLocking f -> do
            c_locks <- M.lookup c . kernelLocks <$> askEnv
            case c_locks of
              Just (Locks locks num_locks) -> do
                -- Map the flattened index onto one of the available
                -- locks.
                let locking =
                      Locking locks 0 1 0 $
                        pure . (`rem` fromIntegral num_locks) . flattenIndex dims
                f locking space arrs is'
              Nothing ->
                error $ "Missing locks for " ++ prettyString acc

-- | Expression compiler for thread-level code inside kernels.
--
-- 'Opaque' becomes a plain copy of its argument.  Array literals are
-- copied element by element.  Accumulator updates are handled by
-- 'updateAcc'.  Every other expression goes to 'defCompileExp'.
compileThreadExp :: ExpCompiler GPUMem KernelEnv Imp.KernelOp
compileThreadExp (Pat [pe]) (BasicOp (Opaque _ se)) =
  -- Cannot print in GPU code.
  copyDWIM (patElemName pe) [] se []
compileThreadExp (Pat [dest]) (BasicOp (ArrayLit es _)) =
  forM_ (zip [0 ..] es) $ \(i, e) ->
    copyDWIMFix (patElemName dest) [fromIntegral (i :: Int64)] e []
compileThreadExp _ (BasicOp (UpdateAcc acc is vs)) =
  updateAcc acc is vs
compileThreadExp dest e =
  defCompileExp dest e

-- | Assign iterations of a for-loop to all threads in the kernel.
-- The passed-in function is invoked with the (symbolic) iteration.
-- The body must contain thread-level code.  For multidimensional
-- loops, use 'groupCoverSpace'.
--
-- If @n@ is syntactically the same expression as @num_threads@, each
-- thread runs the body exactly once with iteration @tid@ and no bounds
-- check.  Otherwise each thread executes iterations
-- @chunk_i * num_threads + tid@ for every chunk, skipping those that are
-- not below @n@.  The body is compiled with 'threadOperations'.
kernelLoop ::
  IntExp t =>
  Imp.TExp t ->
  Imp.TExp t ->
  Imp.TExp t ->
  (Imp.TExp t -> InKernelGen ()) ->
  InKernelGen ()
kernelLoop tid num_threads n f =
  localOps threadOperations $
    if n == num_threads
      then f tid
      else do
        num_chunks <- dPrimVE "num_chunks" $ n `divUp` num_threads
        sFor "chunk_i" num_chunks $ \chunk_i -> do
          i <- dPrimVE "i" $ chunk_i * num_threads + tid
          sWhen (i .<. n) $ f i

-- | Assign iterations of a for-loop to threads in the workgroup.  The
-- passed-in function is invoked with the (symbolic) iteration.  For
-- multidimensional loops, use 'groupCoverSpace'.
--
-- This is 'kernelLoop' with the local thread ID as the thread index and
-- the workgroup size as the thread count, both converted to the type
-- of @n@.
groupLoop ::
  IntExp t =>
  Imp.TExp t ->
  (Imp.TExp t -> InKernelGen ()) ->
  InKernelGen ()
groupLoop n f = do
  constants <- kernelConstants <$> askEnv
  kernelLoop
    (kernelLocalThreadId constants `sExtAs` n)
    (kernelGroupSize constants `sExtAs` n)
    n
    f

-- | Iterate collectively through a multidimensional space, such that
-- all threads in the group participate.  The passed-in function is
-- invoked with a (symbolic) point in the index space.
--
-- If the innermost dimension is syntactically equal to the group size,
-- each thread takes the innermost index equal to its local thread ID
-- and loops sequentially over the outer dimensions.  Otherwise the
-- flattened space is distributed over the group with 'groupLoop'.
groupCoverSpace ::
  IntExp t =>
  [Imp.TExp t] ->
  ([Imp.TExp t] -> InKernelGen ()) ->
  InKernelGen ()
groupCoverSpace ds f = do
  constants <- kernelConstants <$> askEnv
  let group_size = kernelGroupSize constants
  case splitFromEnd 1 ds of
    -- Optimise the case where the inner dimension of the space is
    -- equal to the group size.
    (ds', [last_d])
      | last_d == (group_size `sExtAs` last_d) -> do
          let ltid = kernelLocalThreadId constants `sExtAs` last_d
          sLoopSpace ds' $ \ds_is ->
            f $ ds_is ++ [ltid]
    _ ->
      groupLoop (product ds) $ f . unflattenIndex ds

-- | Which fence do we need to protect shared access to this memory
-- space?  @"local"@ memory needs only a local fence; every other space
-- gets a global fence.
fenceForSpace :: Space -> Imp.Fence
fenceForSpace (Space "local") = Imp.FenceLocal
fenceForSpace _ = Imp.FenceGlobal

-- | If we are touching these arrays, which kind of fence do we need?
--
-- Each array's memory space is mapped to a fence (local for @"local"@
-- memory, global otherwise).  The result is the maximum of these under
-- 'Imp.Fence''s 'Ord' instance, starting from 'Imp.FenceLocal', so an
-- empty list yields 'Imp.FenceLocal'.
fenceForArrays :: [VName] -> InKernelGen Imp.Fence
fenceForArrays = fmap (foldl' max Imp.FenceLocal) . mapM need
  where
    -- The fence needed for the memory block backing one array.
    need arr = do
      entry <- lookupArray arr
      mem <- lookupMemory $ memLocName $ entryArrayLoc entry
      pure $ fenceForSpace $ entryMemSpace mem

-- | Scan within blocks of @block_size@ consecutive threads of the
-- current workgroup, combining values with @scan_lam@.
--
-- The scan loop starts @skip_threads@ at 1 and doubles it every
-- iteration, stopping once it reaches @block_size@.  In each
-- iteration, a thread whose position inside its block is at least
-- @skip_threads@ reads the operand @skip_threads@ elements behind it
-- into the first half of @scan_lam@'s parameters (the \"x\" parameters).
-- It then combines that operand with its own value, held in the second
-- half (the \"y\" parameters).  The result ends up in the x parameters
-- and is copied into the y parameters.  For primitive-typed operands it
-- is also written to @arrs@ at the thread's local ID.
--
-- * @seg_flag@: if given, @flag_true (ltid - skip_threads) ltid@ is
--   evaluated each step; when it holds, the thread copies its y values
--   into x instead of applying the operator (presumably to implement
--   segmented scans).
-- * @arrs_full_size@: added to the global index when reading
--   non-primitive operands from @arrs@ inside the loop.
-- * @lockstep_width@: @barrier@ is only executed inside the loop once
--   @skip_threads@ is at least this large.
-- * @active@: inactive threads neither read input nor apply the
--   operator, but they still execute every barrier.
-- * @barrier@: the synchronisation action to use.
--
-- All generated code is wrapped in 'everythingVolatile'.
inBlockScan ::
  KernelConstants ->
  Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
  Imp.TExp Int64 ->
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  Imp.TExp Bool ->
  [VName] ->
  InKernelGen () ->
  Lambda GPUMem ->
  InKernelGen ()
inBlockScan constants seg_flag arrs_full_size lockstep_width block_size active arrs barrier scan_lam = everythingVolatile $ do
  skip_threads <- dPrim "skip_threads" int32
  let actual_params = lambdaParams scan_lam
      (x_params, y_params) =
        splitAt (length actual_params `div` 2) actual_params
      y_to_x =
        forM_ (zip x_params y_params) $ \(x, y) ->
          when (primType (paramType x)) $
            copyDWIM (paramName x) [] (Var (paramName y)) []

  -- Set initial y values
  sComment "read input for in-block scan" $
    sWhen active $ do
      zipWithM_ readInitial y_params arrs
      -- Since the final result is expected to be in x_params, we may
      -- need to copy it there for the first thread in the block.
      sWhen (in_block_id .==. 0) y_to_x

  when array_scan barrier

  let op_to_x in_block_thread_active
        | Nothing <- seg_flag =
            sWhen in_block_thread_active $
              compileBody' x_params $
                lambdaBody scan_lam
        | Just flag_true <- seg_flag = do
            inactive <-
              dPrimVE "inactive" $ flag_true (ltid32 - tvExp skip_threads) ltid32
            sWhen (in_block_thread_active .&&. inactive) $
              forM_ (zip x_params y_params) $ \(x, y) ->
                copyDWIM (paramName x) [] (Var (paramName y)) []
            -- The convoluted control flow is to ensure all threads
            -- hit this barrier (if applicable).
            when array_scan barrier
            sWhen in_block_thread_active . sUnless inactive $
              compileBody' x_params $
                lambdaBody scan_lam

      -- Barriers are only emitted once skip_threads reaches
      -- lockstep_width; presumably threads closer together than that
      -- execute in lockstep.
      maybeBarrier =
        sWhen
          (lockstep_width .<=. tvExp skip_threads)
          barrier

  sComment "in-block scan (hopefully no barriers needed)" $ do
    skip_threads <-- 1
    sWhile (tvExp skip_threads .<. block_size) $ do
      thread_active <-
        dPrimVE "thread_active" $ tvExp skip_threads .<=. in_block_id .&&. active

      sWhen thread_active . sComment "read operands" $
        zipWithM_ (readParam (sExt64 $ tvExp skip_threads)) x_params arrs
      sComment "perform operation" $ op_to_x thread_active

      maybeBarrier

      sWhen thread_active . sComment "write result" $
        sequence_ $
          zipWith3 writeResult x_params y_params arrs

      maybeBarrier

      skip_threads <-- tvExp skip_threads * 2
  where
    block_id = ltid32 `quot` block_size
    -- Position of this thread within its block.
    in_block_id = ltid32 - block_id * block_size
    ltid32 = kernelLocalThreadId constants
    ltid = sExt64 ltid32
    gtid = sExt64 $ kernelGlobalThreadId constants
    -- Extra barriers are needed when any operand is an array.
    array_scan = not $ all primType $ lambdaReturnType scan_lam

    -- Read this thread's own input: primitive operands at the local
    -- thread ID, non-primitive ones at the global thread ID.
    readInitial p arr
      | primType $ paramType p =
          copyDWIM (paramName p) [] (Var arr) [DimFix ltid]
      | otherwise =
          copyDWIM (paramName p) [] (Var arr) [DimFix gtid]

    -- Read the operand @behind@ elements before this thread's.
    readParam behind p arr
      | primType $ paramType p =
          copyDWIM (paramName p) [] (Var arr) [DimFix $ ltid - behind]
      | otherwise =
          copyDWIM (paramName p) [] (Var arr) [DimFix $ gtid - behind + arrs_full_size]

    -- Publish the combined value: primitive results are also written
    -- to the array; in all cases x is copied into y.
    writeResult x y arr
      | primType $ paramType x = do
          copyDWIM arr [DimFix ltid] (Var $ paramName x) []
          copyDWIM (paramName y) [] (Var $ paramName x) []
      | otherwise =
          copyDWIM (paramName y) [] (Var $ paramName x) []

-- | Scan the arrays @arrs@ in place across the threads of the current
-- workgroup, using the binary operator @lam@.
--
-- Only threads whose local thread ID is less than @w@ take part in the
-- scan ("in bounds").  Barriers are issued unconditionally (outside any
-- per-thread condition), so every thread of the group should reach this
-- code.
--
-- If @seg_flag@ is 'Just', the predicate is consulted with two thread
-- positions; when it holds, the operator is not applied across them
-- (presumably a segment boundary separates the two positions).
--
-- Arrays with primitive element types are addressed by local thread ID
-- (or block ID).  Arrays with non-primitive element types are addressed
-- relative to @group_offset@ (group ID times group size).  For those
-- arrays, the region starting at @arrs_full_size@ is used as scratch
-- space.
groupScan ::
  Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Lambda GPUMem ->
  [VName] ->
  InKernelGen ()
groupScan seg_flag arrs_full_size w lam arrs = do
  constants <- kernelConstants <$> askEnv
  -- A second copy of the operator with fresh names, used for the scan of
  -- the per-block results so its parameters do not clash with 'lam's.
  renamed_lam <- renameLambda lam

  let ltid32 :: Imp.TExp Int32
      ltid32 = kernelLocalThreadId constants
      ltid :: Imp.TExp Int64
      ltid = sExt64 ltid32
      -- The first half of the operator's parameters are the accumulators
      -- (one per array); the second half are the other operands.
      (x_params, y_params) = splitAt (length arrs) $ lambdaParams lam

  dLParams (lambdaParams lam ++ lambdaParams renamed_lam)

  ltid_in_bounds <- dPrimVE "ltid_in_bounds" $ ltid .<. w

  fence <- fenceForArrays arrs

  -- The scan works by splitting the group into blocks, which are
  -- scanned separately.  Typically, these blocks are smaller than
  -- the lockstep width, which enables barrier-free execution inside
  -- them.
  --
  -- We hardcode the block size here.  The only requirement is that
  -- it should not be less than the square root of the group size.
  -- With 32, we will work on groups of size 1024 or smaller, which
  -- fits every device Troels has seen.  Still, it would be nicer if
  -- it were a runtime parameter.  Some day.
  let block_size :: Imp.TExp Int32
      block_size = 32
      simd_width :: Imp.TExp Int32
      simd_width = kernelWaveSize constants
      block_id :: Imp.TExp Int32
      block_id = ltid32 `quot` block_size
      in_block_id :: Imp.TExp Int32
      in_block_id = ltid32 - block_id * block_size

      -- Scan within each block, with the given segment predicate,
      -- activity condition and operator.
      doInBlockScan ::
        Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
        Imp.TExp Bool ->
        Lambda GPUMem ->
        InKernelGen ()
      doInBlockScan seg_flag' active =
        inBlockScan
          constants
          seg_flag'
          arrs_full_size
          simd_width
          block_size
          active
          arrs
          barrier

      -- True when any operand is an array rather than a scalar.
      array_scan :: Bool
      array_scan = not $ all primType $ lambdaReturnType lam

      -- Non-primitive operands are read and written in global memory,
      -- so they need a global fence; otherwise use the fence matching
      -- the arrays' memory space.
      barrier :: InKernelGen ()
      barrier
        | array_scan =
            sOp $ Imp.Barrier Imp.FenceGlobal
        | otherwise =
            sOp $ Imp.Barrier fence

      group_offset :: Imp.TExp Int64
      group_offset = sExt64 (kernelGroupId constants) * kernelGroupSize constants

      -- Store the accumulator @p@ at the slot of the current block.
      writeBlockResult p arr
        | primType $ paramType p =
            copyDWIM arr [DimFix $ sExt64 block_id] (Var $ paramName p) []
        | otherwise =
            copyDWIM arr [DimFix $ group_offset + sExt64 block_id] (Var $ paramName p) []

      -- Load the result of the preceding block into the accumulator @p@.
      readPrevBlockResult p arr
        | primType $ paramType p =
            copyDWIM (paramName p) [] (Var arr) [DimFix $ sExt64 block_id - 1]
        | otherwise =
            copyDWIM (paramName p) [] (Var arr) [DimFix $ group_offset + sExt64 block_id - 1]

  -- Phase 1: every block scans its own elements independently.
  doInBlockScan seg_flag ltid_in_bounds lam
  barrier

  let is_first_block :: Imp.TExp Bool
      is_first_block = block_id .==. 0
  -- For non-primitive operands, stash the first block's accumulators in
  -- scratch space beyond @arrs_full_size@.  Presumably this protects
  -- them from the block-result writes and scan below; they are restored
  -- at the end.
  when array_scan $ do
    sComment "save correct values for first block" $
      sWhen is_first_block $
        forM_ (zip x_params arrs) $ \(x, arr) ->
          unless (primType $ paramType x) $
            copyDWIM
              arr
              [DimFix $ arrs_full_size + group_offset + sExt64 block_size + ltid]
              (Var $ paramName x)
              []

    barrier

  -- Phase 2: publish each block's total.
  let last_in_block :: Imp.TExp Bool
      last_in_block = in_block_id .==. block_size - 1
  sComment "last thread of block 'i' writes its result to offset 'i'" $
    sWhen (last_in_block .&&. ltid_in_bounds) $
      everythingVolatile $
        zipWithM_ writeBlockResult x_params arrs

  barrier

  -- Phase 3: scan the block totals using the first block's threads.  The
  -- segment predicate is lifted from thread positions to block positions
  -- by looking at the last element of each block.
  let first_block_seg_flag = do
        flag_true <- seg_flag
        Just $ \from to ->
          flag_true (from * block_size + block_size - 1) (to * block_size + block_size - 1)
  comment
    "scan the first block, after which offset 'i' contains carry-in for block 'i+1'"
    $ doInBlockScan first_block_seg_flag (is_first_block .&&. ltid_in_bounds) renamed_lam

  barrier

  when array_scan $ do
    sComment "move correct values for first block back a block" $
      sWhen is_first_block $
        forM_ (zip x_params arrs) $ \(x, arr) ->
          unless (primType $ paramType x) $
            copyDWIM
              arr
              [DimFix $ arrs_full_size + group_offset + ltid]
              (Var arr)
              [DimFix $ arrs_full_size + group_offset + sExt64 block_size + ltid]

    barrier

  -- Phase 4: apply the carry-in of the previous block to every element
  -- outside the first block.
  no_carry_in <- dPrimVE "no_carry_in" $ is_first_block .||. bNot ltid_in_bounds

  let -- Move this thread's own value into the second operand, then load
      -- the carry-in into the first operand.
      read_carry_in :: InKernelGen ()
      read_carry_in = sUnless no_carry_in $ do
        forM_ (zip x_params y_params) $ \(x, y) ->
          copyDWIM (paramName y) [] (Var (paramName x)) []
        zipWithM_ readPrevBlockResult x_params arrs

      op_to_x :: InKernelGen ()
      op_to_x
        | Nothing <- seg_flag =
            sUnless no_carry_in $ compileBody' x_params $ lambdaBody lam
        | Just flag_true <- seg_flag = do
            -- If the predicate holds between the last element of the
            -- previous block and this thread, the carry-in is discarded
            -- and the thread keeps its own value.
            inactive <-
              dPrimVE "inactive" $ flag_true (block_id * block_size - 1) ltid32
            sUnless no_carry_in . sWhen inactive . forM_ (zip x_params y_params) $ \(x, y) ->
              copyDWIM (paramName x) [] (Var (paramName y)) []
            -- The convoluted control flow is to ensure all threads
            -- hit this barrier (if applicable).
            when array_scan barrier
            sUnless no_carry_in $ sUnless inactive $ compileBody' x_params $ lambdaBody lam

      -- Only primitive results are written back here; non-primitive
      -- results are not written by this helper.
      write_final_result :: InKernelGen ()
      write_final_result =
        forM_ (zip x_params arrs) $ \(p, arr) ->
          when (primType $ paramType p) $
            copyDWIM arr [DimFix ltid] (Var $ paramName p) []

  sComment "carry-in for every block except the first" $ do
    sComment "read operands" read_carry_in
    sComment "perform operation" op_to_x
    sComment "write final result" $ sUnless no_carry_in write_final_result

  barrier

  -- Phase 5: undo the use of the first block's storage in phase 3.
  sComment "restore correct values for first block" $
    sWhen (is_first_block .&&. ltid_in_bounds) $
      forM_ (zip3 x_params y_params arrs) $ \(x, y, arr) ->
        if primType (paramType y)
          then copyDWIM arr [DimFix ltid] (Var $ paramName y) []
          else copyDWIM (paramName x) [] (Var arr) [DimFix $ arrs_full_size + group_offset + ltid]

  barrier

-- | Reduce the arrays @arrs@ across the threads of the current workgroup
-- with the operator @lam@, considering only the first @w@ threads'
-- elements.  This declares a fresh 32-bit offset variable and delegates
-- to 'groupReduceWithOffset'.
groupReduce ::
  Imp.TExp Int32 ->
  Lambda GPUMem ->
  [VName] ->
  InKernelGen ()
groupReduce w lam arrs = do
  offset <- dPrim "offset" int32
  groupReduceWithOffset offset w lam arrs

-- | Tree-reduce the arrays @arrs@ across the threads of the current
-- workgroup with the operator @lam@, considering only the first @w@
-- elements.
--
-- @offset@ is used as a mutable loop variable.  It is reset to 0 on
-- entry, overwritten during the reduction, and left modified on return.
--
-- Elements of arrays with primitive element types are indexed by local
-- thread ID, and the operator's result is written back at that index.
-- Elements of other arrays are indexed by global thread ID, and nothing
-- is written back for them here.  Presumably the operator updates them
-- in place; confirm against callers.
--
-- The reduction has two phases.  First, an in-wave phase runs without
-- barriers.  Second, a cross-wave phase combines wave results and issues
-- a barrier each iteration.
groupReduceWithOffset ::
  TV Int32 ->
  Imp.TExp Int32 ->
  Lambda GPUMem ->
  [VName] ->
  InKernelGen ()
groupReduceWithOffset offset w lam arrs = do
  constants <- kernelConstants <$> askEnv

  let local_tid, global_tid :: Imp.TExp Int32
      local_tid = kernelLocalThreadId constants
      global_tid = kernelGlobalThreadId constants

      -- Scalar results only need a local fence; array results need a
      -- global one.
      barrier :: InKernelGen ()
      barrier
        | all primType $ lambdaReturnType lam = sOp $ Imp.Barrier Imp.FenceLocal
        | otherwise = sOp $ Imp.Barrier Imp.FenceGlobal

      -- Load the operand at distance 'offset' from this thread into @param@.
      readReduceArgument param arr
        | Prim _ <- paramType param = do
            let i = local_tid + tvExp offset
            copyDWIMFix (paramName param) [] (Var arr) [sExt64 i]
        | otherwise = do
            let i = global_tid + tvExp offset
            copyDWIMFix (paramName param) [] (Var arr) [sExt64 i]

      -- Store the accumulator back at this thread's own position
      -- (primitive element types only).
      writeReduceOpResult param arr
        | Prim _ <- paramType param =
            copyDWIMFix arr [sExt64 local_tid] (Var $ paramName param) []
        | otherwise =
            pure ()

  -- The first half of the operator's parameters are the accumulators;
  -- the second half are the operands read from the arrays.
  let (reduce_acc_params, reduce_arr_params) = splitAt (length arrs) $ lambdaParams lam

  skip_waves <- dPrimV "skip_waves" (1 :: Imp.TExp Int32)
  dLParams $ lambdaParams lam

  offset <-- (0 :: Imp.TExp Int32)

  comment "participating threads read initial accumulator" $
    sWhen (local_tid .<. w) $
      zipWithM_ readReduceArgument reduce_acc_params arrs

  let -- One reduction step: combine the accumulator with the element at
      -- distance 'offset' and store the result.
      do_reduce :: InKernelGen ()
      do_reduce = do
        comment "read array element" $
          zipWithM_ readReduceArgument reduce_arr_params arrs
        comment "apply reduction operation" $
          compileBody' reduce_acc_params $
            lambdaBody lam
        comment "write result of operation" $
          zipWithM_ writeReduceOpResult reduce_acc_params arrs
      -- Within a wave no barrier is issued, so the memory accesses are
      -- made volatile.
      in_wave_reduce :: InKernelGen ()
      in_wave_reduce = everythingVolatile do_reduce

      wave_size :: Imp.TExp Int32
      wave_size = kernelWaveSize constants
      group_size :: Imp.TExp Int64
      group_size = kernelGroupSize constants
      wave_id :: Imp.TExp Int32
      wave_id = local_tid `quot` wave_size
      in_wave_id :: Imp.TExp Int32
      in_wave_id = local_tid - wave_id * wave_size
      -- Number of waves, rounding up.
      num_waves :: Imp.TExp Int32
      num_waves = (sExt32 group_size + wave_size - 1) `quot` wave_size
      arg_in_bounds :: Imp.TExp Bool
      arg_in_bounds = local_tid + tvExp offset .<. w

      -- In-wave phase: offset doubles each round (1, 2, 4, ...) until it
      -- reaches the wave size.  A thread participates in a round only if
      -- its in-wave index is a multiple of 2 * offset.
      doing_in_wave_reductions :: Imp.TExp Bool
      doing_in_wave_reductions =
        tvExp offset .<. wave_size
      apply_in_in_wave_iteration :: Imp.TExp Bool
      apply_in_in_wave_iteration =
        (in_wave_id .&. (2 * tvExp offset - 1)) .==. 0
      in_wave_reductions :: InKernelGen ()
      in_wave_reductions = do
        offset <-- (1 :: Imp.TExp Int32)
        sWhile doing_in_wave_reductions $ do
          sWhen
            (arg_in_bounds .&&. apply_in_in_wave_iteration)
            in_wave_reduce
          offset <-- tvExp offset * 2

      -- Cross-wave phase: the first thread of each surviving wave combines
      -- its result with the wave 'skip_waves' waves ahead.  'skip_waves'
      -- doubles each round.
      doing_cross_wave_reductions :: Imp.TExp Bool
      doing_cross_wave_reductions =
        tvExp skip_waves .<. num_waves
      is_first_thread_in_wave :: Imp.TExp Bool
      is_first_thread_in_wave =
        in_wave_id .==. 0
      wave_not_skipped :: Imp.TExp Bool
      wave_not_skipped =
        (wave_id .&. (2 * tvExp skip_waves - 1)) .==. 0
      apply_in_cross_wave_iteration :: Imp.TExp Bool
      apply_in_cross_wave_iteration =
        arg_in_bounds .&&. is_first_thread_in_wave .&&. wave_not_skipped
      cross_wave_reductions :: InKernelGen ()
      cross_wave_reductions =
        sWhile doing_cross_wave_reductions $ do
          barrier
          offset <-- tvExp skip_waves * wave_size
          sWhen
            apply_in_cross_wave_iteration
            do_reduce
          skip_waves <-- tvExp skip_waves * 2

  in_wave_reductions
  cross_wave_reductions

-- | Compile an operation that occurs at thread level inside a kernel.
-- Only allocations are supported; any other operation is a compiler bug.
compileThreadOp :: OpCompiler GPUMem KernelEnv Imp.KernelOp
compileThreadOp pat (Alloc size space) =
  kernelAlloc pat size space
compileThreadOp pat _ =
  compilerBugS $ "compileThreadOp: cannot compile rhs of binding " ++ prettyString pat

-- | Locking strategy used for an atomic update.
data Locking = Locking
  { -- | Array containing the lock.
    lockingArray :: VName,
    -- | Value for us to consider the lock free.
    lockingIsUnlocked :: Imp.TExp Int32,
    -- | What to write when we lock it.
    lockingToLock :: Imp.TExp Int32,
    -- | What to write when we unlock it.
    lockingToUnlock :: Imp.TExp Int32,
    -- | A transformation from the logical lock index to the
    -- physical position in the array.  This can also be used
    -- to make the lock array smaller.
    lockingMapping :: [Imp.TExp Int64] -> [Imp.TExp Int64]
  }

-- | Code generator for a single atomic update.  It is given a memory
-- 'Space', a list of array names, and a bucket index (one 64-bit index
-- per dimension).  The bucket is assumed to be in bounds; no bounds check
-- is expected here.
type DoAtomicUpdate rep r =
  Space ->
  [VName] ->
  [Imp.TExp Int64] ->
  ImpM rep r Imp.KernelOp ()

-- | Which mechanism an atomic update will use.  This doubles as a
-- rough cost estimate: the constructors are listed from the most to
-- the least efficient.
data AtomicUpdate rep env
  = -- | Maps directly onto a primitive atomic operation.
    AtomicPrim (DoAtomicUpdate rep env)
  | -- | Implemented with an (efficient) compare-and-swap loop.
    AtomicCAS (DoAtomicUpdate rep env)
  | -- | Needs an explicit lock, described by the supplied 'Locking'.
    AtomicLocking (Locking -> DoAtomicUpdate rep env)

-- | Is there an atomic t'BinOp' corresponding to this t'BinOp'?  When
-- there is, the result builds the atomic operation from the variable
-- that receives the old value, the array, the element offset, and the
-- operand expression.
type AtomicBinOp =
  BinOp ->
  Maybe
    ( VName ->
      VName ->
      Count Imp.Elements (Imp.TExp Int64) ->
      Imp.Exp ->
      Imp.AtomicOp
    )

-- | Do an atomic update corresponding to a binary operator lambda.
--
-- Three strategies are tried in order:
--
-- 1. The lambda is a (vectorised) binary operator on 32/64-bit values:
--    use a primitive atomic where 'AtomicBinOp' provides one, and a
--    compare-and-swap loop otherwise.
--
-- 2. The lambda operates on a single 32/64-bit value: use a
--    compare-and-swap loop around the compiled lambda body.
--
-- 3. Otherwise: guard the update with an explicit lock.
atomicUpdateLocking ::
  AtomicBinOp ->
  Lambda GPUMem ->
  AtomicUpdate GPUMem KernelEnv
atomicUpdateLocking atomicBinOp lam
  | Just ops_and_ts <- lamIsBinOp lam,
    all (\(_, t, _, _) -> primBitSize t `elem` [32, 64]) ops_and_ts =
      primOrCas ops_and_ts $ \space arrs bucket ->
        -- If the operator is a vectorised binary operator on 32/64-bit
        -- values, we can use a particularly efficient
        -- implementation. If the operator has an atomic implementation
        -- we use that, otherwise it is still a binary operator which
        -- can be implemented by atomic compare-and-swap if 32/64 bits.
        forM_ (zip arrs ops_and_ts) $ \(a, (op, t, x, y)) -> do
          -- Common variables.
          old <- dPrim "old" t

          (arr', _a_space, bucket_offset) <- fullyIndexArray a bucket

          case opHasAtomicSupport space (tvVar old) arr' bucket_offset op of
            Just f -> sOp $ f $ Imp.var y t
            Nothing ->
              atomicUpdateCAS space t a (tvVar old) bucket x $
                x <~~ Imp.BinOpExp op (Imp.var x t) (Imp.var y t)
  where
    -- Build the primitive atomic for this operator, if one exists.
    opHasAtomicSupport space old arr' bucket' bop =
      (\f -> Imp.Atomic space . f old arr' bucket') <$> atomicBinOp bop

    -- Only report 'AtomicPrim' when every component has a primitive
    -- atomic; a single CAS-only component makes the whole update CAS.
    primOrCas ops
      | all isPrim ops = AtomicPrim
      | otherwise = AtomicCAS

    isPrim (op, _, _, _) = isJust $ atomicBinOp op

-- If the operator functions purely on single 32/64-bit values, we can
-- use an implementation based on CAS, no matter what the operator
-- does.
atomicUpdateLocking _ op
  | [Prim t] <- lambdaReturnType op,
    [xp, _] <- lambdaParams op,
    primBitSize t `elem` [32, 64] = AtomicCAS $ \space arrs bucket ->
      case arrs of
        [arr] -> do
          old <- dPrim "old" t
          atomicUpdateCAS space t arr (tvVar old) bucket (paramName xp) $
            compileBody' [xp] $
              lambdaBody op
        -- The lambda returns a single value, so exactly one destination
        -- array is expected; anything else is a caller error.
        _ ->
          compilerBugS $
            "atomicUpdateLocking: expected exactly one array for CAS update, got "
              ++ show (length arrs)
atomicUpdateLocking _ op = AtomicLocking $ \locking space arrs bucket -> do
  old <- dPrim "old" int32
  continue <- dPrimVol "continue" Bool true

  -- Correctly index into locks.
  (locks', _locks_space, locks_offset) <-
    fullyIndexArray (lockingArray locking) $ lockingMapping locking bucket

  -- Critical section
  let try_acquire_lock =
        sOp $
          Imp.Atomic space $
            Imp.AtomicCmpXchg
              int32
              (tvVar old)
              locks'
              locks_offset
              (untyped $ lockingIsUnlocked locking)
              (untyped $ lockingToLock locking)
      -- The CAS leaves the previous lock value in 'old'; we own the lock
      -- iff it was the unlocked value.
      lock_acquired = tvExp old .==. lockingIsUnlocked locking
      -- Even the releasing is done with an atomic rather than a
      -- simple write, for memory coherency reasons.
      release_lock =
        sOp $
          Imp.Atomic space $
            Imp.AtomicCmpXchg
              int32
              (tvVar old)
              locks'
              locks_offset
              (untyped $ lockingToLock locking)
              (untyped $ lockingToUnlock locking)
      break_loop = continue <-- false

  -- Preparing parameters. It is assumed that the caller has already
  -- filled the arr_params. We copy the current value to the
  -- accumulator parameters.
  --
  -- Note the use of 'everythingVolatile' when reading and writing the
  -- buckets.  This was necessary to ensure correct execution on a
  -- newer NVIDIA GPU (RTX 2080).  The 'volatile' modifiers likely
  -- make the writes pass through the (SM-local) L1 cache, which is
  -- necessary here, because we are really doing device-wide
  -- synchronisation without atomics (naughty!).
  let (acc_params, _arr_params) = splitAt (length arrs) $ lambdaParams op
      bind_acc_params =
        everythingVolatile $
          sComment "bind lhs" $
            forM_ (zip acc_params arrs) $ \(acc_p, arr) ->
              copyDWIMFix (paramName acc_p) [] (Var arr) bucket

  let op_body =
        sComment "execute operation" $
          compileBody' acc_params $
            lambdaBody op

      do_hist =
        everythingVolatile $
          sComment "update global result" $
            zipWithM_ (writeArray bucket) arrs $
              map (Var . paramName) acc_params

      fence = sOp $ Imp.MemFence $ fenceForSpace space

  -- While-loop: Try to insert your value
  sWhile (tvExp continue) $ do
    try_acquire_lock
    sWhen lock_acquired $ do
      dLParams acc_params
      bind_acc_params
      op_body
      do_hist
      fence
      release_lock
      break_loop
    fence
  where
    writeArray bucket arr val = copyDWIMFix arr bucket val []

-- | Emit a compare-and-swap loop that atomically applies @do_op@ to
-- the element of @arr@ at @bucket@.
--
-- Arguments, in order: the memory space, the element type @t@, the
-- array, a variable that receives the element's previous value, the
-- bucket index, the variable @x@ that @do_op@ must leave its result in
-- (it is initialised from the currently assumed value before @do_op@
-- runs), and the code for the operator itself.
--
-- Floating-point values are compared and swapped via their integer
-- bit patterns (@to_bits*@/@from_bits*@).
atomicUpdateCAS ::
  Space ->
  PrimType ->
  VName ->
  VName ->
  [Imp.TExp Int64] ->
  VName ->
  InKernelGen () ->
  InKernelGen ()
atomicUpdateCAS space t arr old bucket x do_op = do
  -- Code generation target:
  --
  -- old = d_his[idx];
  -- do {
  --   assumed = old;
  --   x = do_op(assumed, y);
  --   old = atomicCAS(&d_his[idx], assumed, x);
  -- } while(assumed != old);
  assumed <- tvVar <$> dPrim "assumed" t
  run_loop <- dPrimV "run_loop" true

  -- XXX: CUDA may generate really bad code if this is not a volatile
  -- read.  Unclear why.  The later reads are volatile, so maybe
  -- that's it.
  everythingVolatile $ copyDWIMFix old [] (Var arr) bucket

  (arr', _a_space, bucket_offset) <- fullyIndexArray arr bucket

  -- While-loop: Try to insert your value
  let (toBits, fromBits) =
        case t of
          FloatType Float16 ->
            ( \v -> Imp.FunExp "to_bits16" [v] int16,
              \v -> Imp.FunExp "from_bits16" [v] t
            )
          FloatType Float32 ->
            ( \v -> Imp.FunExp "to_bits32" [v] int32,
              \v -> Imp.FunExp "from_bits32" [v] t
            )
          FloatType Float64 ->
            ( \v -> Imp.FunExp "to_bits64" [v] int64,
              \v -> Imp.FunExp "from_bits64" [v] t
            )
          _ -> (id, id)

      -- Integer type of the same width, used for the CAS itself.
      int
        | primBitSize t == 16 = int16
        | primBitSize t == 32 = int32
        | otherwise = int64

  sWhile (tvExp run_loop) $ do
    assumed <~~ Imp.var old t
    x <~~ Imp.var assumed t
    do_op
    old_bits_v <- newVName "old_bits"
    dPrim_ old_bits_v int
    let old_bits = Imp.var old_bits_v int
    sOp . Imp.Atomic space $
      Imp.AtomicCmpXchg
        int
        old_bits_v
        arr'
        bucket_offset
        (toBits (Imp.var assumed t))
        (toBits (Imp.var x t))
    old <~~ fromBits old_bits
    -- Compare bit patterns rather than values, so that e.g. NaN does
    -- not keep the loop spinning forever.
    let won = CmpOpExp (CmpEq int) (toBits (Imp.var assumed t)) old_bits
    sWhen (isBool won) (run_loop <-- false)

-- | Compute the kernel uses for a kernel body: every variable free in
-- the body, minus those in @bound_in_kernel@, classified by
-- 'readsFromSet', with duplicates removed.
computeKernelUses ::
  FreeIn a =>
  a ->
  [VName] ->
  CallKernelGen [Imp.KernelUse]
computeKernelUses kernel_body bound_in_kernel =
  -- Compute the variables that we need to pass to the kernel.
  nubOrd <$> readsFromSet actually_free
  where
    -- Variables bound inside the kernel need not be passed in.
    actually_free =
      freeIn kernel_body `namesSubtract` namesFromList bound_in_kernel

-- | Classify each variable in the set as a kernel use:
--
-- * arrays, accumulators and memory in the @\"local\"@ space are
--   skipped;
-- * other memory blocks become 'Imp.MemoryUse';
-- * scalars become 'Imp.ConstUse' when 'isConstExp' can express them
--   as a kernel constant, and 'Imp.ScalarUse' otherwise.
readsFromSet :: Names -> CallKernelGen [Imp.KernelUse]
readsFromSet free = do
  -- Nothing below modifies the symbol table, so fetch it once instead
  -- of once per variable.
  vtable <- getVTable
  catMaybes <$> mapM (classify vtable) (namesToList free)
  where
    classify vtable var = do
      t <- lookupType var
      case t of
        Array {} -> pure Nothing
        Acc {} -> pure Nothing
        Mem (Space "local") -> pure Nothing
        Mem {} -> pure $ Just $ Imp.MemoryUse var
        Prim bt ->
          isConstExp vtable (Imp.var var bt) >>= \case
            Just ce -> pure $ Just $ Imp.ConstUse var ce
            Nothing -> pure $ Just $ Imp.ScalarUse var bt

-- | Try to rewrite an expression so that it refers only to kernel
-- constants.  Each variable leaf is looked up in the symbol table:
--
-- * a variable bound to a 'GetSize' becomes an 'Imp.SizeConst' keyed
--   by the current entry point;
-- * a variable bound to any other expression is rewritten recursively
--   via 'primExpFromExp'.
--
-- Returns 'Nothing' if some leaf cannot be expressed this way.
isConstExp ::
  VTable GPUMem ->
  Imp.Exp ->
  ImpM rep r op (Maybe Imp.KernelConstExp)
isConstExp vtable size = do
  fname <- askFunction
  let onLeaf name _ = lookupConstExp name
      lookupConstExp name =
        constExp =<< hasExp =<< M.lookup name vtable
      constExp (Op (Inner (SizeOp (GetSize key _)))) =
        Just $ LeafExp (Imp.SizeConst $ keyWithEntryPoint fname key) int32
      constExp e = primExpFromExp lookupConstExp e
  pure $ replaceInPrimExpM onLeaf size
  where
    -- The expression (if any) recorded for a symbol-table entry.
    hasExp (ArrayVar e _) = e
    hasExp (AccVar e _) = e
    hasExp (ScalarVar e _) = e
    hasExp (MemVar e _) = e

-- | Create fresh names for the basic kernel constants (thread and group
-- ids, wave size, group size).  Returns the 'KernelConstants' referring
-- to them, together with the in-kernel code that declares and
-- initialises them.  That code must run at the start of the kernel
-- before the constants are used.
kernelInitialisationSimple ::
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  CallKernelGen (KernelConstants, InKernelGen ())
kernelInitialisationSimple num_groups group_size = do
  global_tid <- newVName "global_tid"
  local_tid <- newVName "local_tid"
  group_id <- newVName "group_tid"
  wave_size <- newVName "wave_size"
  inner_group_size <- newVName "group_size"
  let num_groups' = Imp.pe64 (unCount num_groups)
      group_size' = Imp.pe64 (unCount group_size)
      constants =
        KernelConstants
          { kernelGlobalThreadId = Imp.le32 global_tid,
            kernelLocalThreadId = Imp.le32 local_tid,
            kernelGroupId = Imp.le32 group_id,
            kernelGlobalThreadIdVar = global_tid,
            kernelLocalThreadIdVar = local_tid,
            kernelNumGroupsCount = num_groups,
            kernelGroupSizeCount = group_size,
            kernelGroupIdVar = group_id,
            kernelNumGroups = num_groups',
            kernelGroupSize = group_size',
            kernelNumThreads = sExt32 (group_size' * num_groups'),
            kernelWaveSize = Imp.le32 wave_size,
            kernelLocalIdMap = mempty,
            kernelChunkItersMap = mempty
          }

  let set_constants = do
        dPrim_ local_tid int32
        dPrim_ inner_group_size int64
        dPrim_ wave_size int32
        dPrim_ group_id int32

        sOp (Imp.GetLocalId local_tid 0)
        sOp (Imp.GetLocalSize inner_group_size 0)
        sOp (Imp.GetLockstepWidth wave_size)
        sOp (Imp.GetGroupId group_id 0)
        -- 'inner_group_size' is declared as int64 above, so read it at
        -- that type and convert explicitly instead of mislabelling the
        -- leaf as int32.
        dPrimV_ global_tid $
          le32 group_id * sExt32 (Imp.le64 inner_group_size) + le32 local_tid

  pure (constants, set_constants)

-- | Conjunction of @i < w@ over every @(i, w)@ pair, where @i@ is
-- read as a 64-bit variable and @w@ as a 64-bit size.  An empty list
-- yields 'true'.
isActive :: [(VName, SubExp)] -> Imp.TExp Bool
isActive limit = case actives of
  [] -> true
  -- Strict left fold: avoid building a chain of thunks for long lists.
  x : xs -> foldl' (.&&.) x xs
  where
    actives = [Imp.le64 i .<. pe64 w | (i, w) <- limit]

-- | Change every memory block to be in the global address space,
-- except those that are in the local memory space.  The default space
-- is also set to global for the enclosed action.  This only affects
-- generated code - we still need to make sure that the memory is
-- actually present on the device (and declared as variables in the
-- kernel).  Any expression recorded for a rewritten memory block is
-- dropped.
makeAllMemoryGlobal :: CallKernelGen a -> CallKernelGen a
makeAllMemoryGlobal =
  localDefaultSpace (Imp.Space "global") . localVTable (M.map globalMemory)
  where
    globalMemory (MemVar _ entry)
      | entryMemSpace entry /= Space "local" =
          MemVar Nothing entry {entryMemSpace = Imp.Space "global"}
    globalMemory entry =
      entry

-- | Compute a one-dimensional group layout for @kernel_size@ threads.
--
-- The group size is queried at run time through 'Imp.GetSize' (size
-- class 'Imp.SizeGroup'), keyed by the current entry point.  Returns the
-- number of virtual groups needed to cover @kernel_size@ (as an i32),
-- the number of physical groups (the virtual count capped at
-- @max_num_groups@), and the group size.
simpleKernelGroups ::
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  CallKernelGen (Imp.TExp Int32, Count NumGroups SubExp, Count GroupSize SubExp)
simpleKernelGroups max_num_groups kernel_size = do
  group_size <- dPrim "group_size" int64
  let size_var = tvVar group_size
  fname <- askFunction
  let size_key =
        keyWithEntryPoint fname . nameFromString . prettyString $ size_var
  sOp $ Imp.GetSize size_var size_key Imp.SizeGroup

  -- Enough groups to cover every element, rounding up.
  virt_num_groups <-
    dPrimVE "virt_num_groups" $ kernel_size `divUp` tvExp group_size
  -- Never launch more than the cap; 'virtualiseGroups' makes up the rest.
  num_groups <-
    dPrimV "num_groups" $ sMin64 virt_num_groups max_num_groups

  pure
    ( sExt32 virt_num_groups,
      Count (tvSize num_groups),
      Count (tvSize group_size)
    )

-- | Set up a one-dimensional kernel covering (at least) @kernel_size@
-- threads, using the group layout computed by 'simpleKernelGroups'.
--
-- Returns the 'KernelConstants' for the kernel together with a wrapper
-- for the kernel body.  Inside the kernel, the wrapper:
--
-- * declares and initialises the local id, group size and group id;
-- * defines the global thread id from them;
-- * runs its argument once per virtual group (via 'virtualiseGroups'),
--   passing the i64 global thread id of the current virtual thread.
--
-- @desc@ is used as a prefix for the generated thread/group id names.
simpleKernelConstants ::
  Imp.TExp Int64 ->
  String ->
  CallKernelGen
    ( (Imp.TExp Int64 -> InKernelGen ()) -> InKernelGen (),
      KernelConstants
    )
simpleKernelConstants kernel_size desc = do
  -- For performance reasons, codegen assumes that the thread count is
  -- never more than will fit in an i32.  This means we need to cap
  -- the number of groups here.  The cap is set much higher than any
  -- GPU will possibly need.  Feel free to come back and laugh at me
  -- in the future.
  let max_num_groups :: Imp.TExp Int64
      max_num_groups = 1024 * 1024
  thread_gtid <- newVName $ desc ++ "_gtid"
  thread_ltid <- newVName $ desc ++ "_ltid"
  group_id <- newVName $ desc ++ "_gid"
  inner_group_size <- newVName "group_size"
  (virt_num_groups, num_groups, group_size) <-
    simpleKernelGroups max_num_groups kernel_size
  let group_size' = Imp.pe64 $ unCount group_size
      num_groups' = Imp.pe64 $ unCount num_groups

      constants =
        KernelConstants
          { kernelGlobalThreadId = Imp.le32 thread_gtid,
            kernelLocalThreadId = Imp.le32 thread_ltid,
            kernelGroupId = Imp.le32 group_id,
            kernelGlobalThreadIdVar = thread_gtid,
            kernelLocalThreadIdVar = thread_ltid,
            kernelGroupIdVar = group_id,
            kernelNumGroupsCount = num_groups,
            kernelGroupSizeCount = group_size,
            kernelNumGroups = num_groups',
            kernelGroupSize = group_size',
            kernelNumThreads = sExt32 (group_size' * num_groups'),
            kernelWaveSize = 0,
            kernelLocalIdMap = mempty,
            kernelChunkItersMap = mempty
          }

      -- The group size variable is declared as an i64 below, so it must
      -- also be read as an i64.  Reading it with 'le32' would give the
      -- leaf the wrong type in the typed IR.
      inner_group_size_64 :: Imp.TExp Int64
      inner_group_size_64 = Imp.le64 inner_group_size

      wrapKernel :: (Imp.TExp Int64 -> InKernelGen ()) -> InKernelGen ()
      wrapKernel m = do
        dPrim_ thread_ltid int32
        dPrim_ inner_group_size int64
        dPrim_ group_id int32
        sOp (Imp.GetLocalId thread_ltid 0)
        sOp (Imp.GetLocalSize inner_group_size 0)
        sOp (Imp.GetGroupId group_id 0)
        -- Physical global thread id; fits in an i32 because of the group
        -- cap above.
        dPrimV_ thread_gtid $
          le32 group_id * sExt32 inner_group_size_64 + le32 thread_ltid
        virtualiseGroups SegVirt virt_num_groups $ \virt_group_id -> do
          global_tid <-
            dPrimVE "global_tid" $
              sExt64 virt_group_id * inner_group_size_64
                + sExt64 (kernelLocalThreadId constants)
          m global_tid

  pure (wrapKernel, constants)

-- | For many kernels, we may not have enough physical groups to cover
-- the logical iteration space.  Some groups thus have to perform
-- double duty; we put an outer loop around the body to accomplish
-- this.  The advantage over just launching a bazillion threads is that
-- the cost of memory expansion should be proportional to the number of
-- *physical* threads (hardware parallelism), not the amount of
-- application parallelism.
--
-- With 'SegVirt', the body is run for virtual group ids @phys_id@,
-- @phys_id + num_groups@, @phys_id + 2 * num_groups@, ... below
-- @required_groups@, with a global barrier after each iteration.  With
-- any other virtualisation mode, the body is run once with the physical
-- group id.
virtualiseGroups ::
  SegVirt ->
  Imp.TExp Int32 ->
  (Imp.TExp Int32 -> InKernelGen ()) ->
  InKernelGen ()
virtualiseGroups virt required_groups m =
  case virt of
    SegVirt -> do
      num_phys_groups <- sExt32 . kernelNumGroups . kernelConstants <$> askEnv
      phys_group_id <- dPrim "phys_group_id" int32
      sOp $ Imp.GetGroupId (tvVar phys_group_id) 0
      -- How many virtual groups this physical group is responsible for.
      iterations <-
        dPrimVE "iterations" $
          (required_groups - tvExp phys_group_id) `divUp` num_phys_groups
      sFor "i" iterations $ \i -> do
        virt_group_id <-
          dPrimV "virt_group_id" $ tvExp phys_group_id + i * num_phys_groups
        m $ tvExp virt_group_id
        -- Make sure the virtual group is actually done before we let
        -- another virtual group have its way with it.
        sOp $ Imp.Barrier Imp.FenceGlobal
    _ ->
      askEnv >>= m . Imp.le32 . kernelGroupIdVar . kernelConstants

-- | Various extra configuration of the kernel being generated, passed
-- along with the kernel body when it is compiled.
data KernelAttrs = KernelAttrs
  { -- | Can this kernel execute correctly even if previous kernels
    -- failed?
    kAttrFailureTolerant :: Bool,
    -- | Does whatever launches this kernel check for local memory
    -- capacity itself?
    kAttrCheckLocalMemory :: Bool,
    -- | Number of groups to launch.
    kAttrNumGroups :: Count NumGroups SubExp,
    -- | Number of threads per group.
    kAttrGroupSize :: Count GroupSize SubExp
  }

-- | The default kernel attributes for the given group count and group
-- size.  The kernel is not failure tolerant, and local memory capacity
-- is checked.
defKernelAttrs ::
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  KernelAttrs
defKernelAttrs groups size =
  KernelAttrs
    { kAttrNumGroups = groups,
      kAttrGroupSize = size,
      kAttrFailureTolerant = False,
      kAttrCheckLocalMemory = True
    }

-- | Generate and launch a kernel.
--
-- The kernel is named after @name@ and the tag of @v@.  Its body first
-- sets up the kernel constants (from 'kernelInitialisationSimple', using
-- the group count and size in @attrs@), then binds @v@ to @flatf
-- constants@, and finally runs @body@.  The body is compiled with @ops@.
sKernel ::
  Operations GPUMem KernelEnv Imp.KernelOp ->
  (KernelConstants -> Imp.TExp Int32) ->
  String ->
  VName ->
  KernelAttrs ->
  InKernelGen () ->
  CallKernelGen ()
sKernel ops flatf name v attrs body = do
  (constants, set_constants) <-
    kernelInitialisationSimple (kAttrNumGroups attrs) (kAttrGroupSize attrs)
  kernel_name <- nameForFun $ concat [name, "_", show $ baseTag v]
  let kernel_body = do
        set_constants
        dPrimV_ v $ flatf constants
        body
  sKernelOp attrs constants ops kernel_name kernel_body

-- | Like 'sKernel', with two choices fixed:
--
-- * the body is compiled with 'threadOperations';
-- * the given variable is bound to each thread's global thread id
--   ('kernelGlobalThreadId').
sKernelThread ::
  String ->
  VName ->
  KernelAttrs ->
  InKernelGen () ->
  CallKernelGen ()
sKernelThread name v attrs body =
  sKernel threadOperations kernelGlobalThreadId name v attrs body

-- | Compile a kernel body and emit a launch of it under the given name.
--
-- The body @m@ is compiled with @ops@ in a 'KernelEnv' built from the
-- host environment's atomics and locks plus @constants@.  All non-local
-- memory is treated as global during compilation (see
-- 'makeAllMemoryGlobal').  The launch takes its group count and group
-- size from @constants@, and its failure-tolerance and local-memory-check
-- flags from @attrs@.
sKernelOp ::
  KernelAttrs ->
  KernelConstants ->
  Operations GPUMem KernelEnv Imp.KernelOp ->
  Name ->
  InKernelGen () ->
  CallKernelGen ()
sKernelOp attrs constants ops name m = do
  host_env <- askEnv
  let kernel_env =
        KernelEnv (hostAtomics host_env) constants (hostLocks host_env)
  body <- makeAllMemoryGlobal $ subImpM_ kernel_env ops m
  uses <- computeKernelUses body mempty
  -- Note: the group count/size fields of @attrs@ are not consulted here;
  -- the launch dimensions come from @constants@.
  let kernel =
        Imp.Kernel
          { Imp.kernelBody = body,
            Imp.kernelUses = uses,
            Imp.kernelNumGroups = [untyped $ kernelNumGroups constants],
            Imp.kernelGroupSize = [untyped $ kernelGroupSize constants],
            Imp.kernelName = name,
            Imp.kernelFailureTolerant = kAttrFailureTolerant attrs,
            Imp.kernelCheckLocalMemory = kAttrCheckLocalMemory attrs
          }
  emit $ Imp.Op $ Imp.CallKernel kernel

-- | Like 'sKernelOp', but the attributes are built internally:
--
-- * the group count and size come from the kernel constants
--   ('defKernelAttrs');
-- * failure tolerance is set to the given flag.
sKernelFailureTolerant ::
  Bool ->
  Operations GPUMem KernelEnv Imp.KernelOp ->
  KernelConstants ->
  Name ->
  InKernelGen () ->
  CallKernelGen ()
sKernelFailureTolerant tol ops constants name m =
  sKernelOp attrs constants ops name m
  where
    base_attrs =
      defKernelAttrs
        (kernelNumGroupsCount constants)
        (kernelGroupSizeCount constants)
    attrs = base_attrs {kAttrFailureTolerant = tol}

-- | The operations used to compile thread-level kernel code (as used by
-- 'sKernelThread').
--
-- Starts from 'defaultOperations' with 'compileThreadOp' and overrides:
--
-- * copies are done element-wise;
-- * expressions are compiled with 'compileThreadExp';
-- * statements are compiled with 'defCompileStms', ignoring the names
--   passed in;
-- * memory in the @local@ space is allocated with 'allocLocal'.
threadOperations :: Operations GPUMem KernelEnv Imp.KernelOp
threadOperations =
  base
    { opsCopyCompiler = copyElementWise,
      opsExpCompiler = compileThreadExp,
      opsStmsCompiler = const $ defCompileStms mempty,
      opsAllocCompilers = M.singleton (Space "local") allocLocal
    }
  where
    base = defaultOperations compileThreadOp

-- | Perform a Replicate with a dedicated kernel.
--
-- The iteration space is the outer (replicated) dimensions of @arr@
-- followed by the dimensions of @se@.  Each thread with a flat index
-- below the total element count copies the element of @se@ at its inner
-- indices into @arr@ at its full index.
sReplicateKernel :: VName -> SubExp -> CallKernelGen ()
sReplicateKernel arr se = do
  se_t <- subExpType se
  outer_dims <- dropLast (arrayRank se_t) . arrayDims <$> lookupType arr
  let dims = map pe64 $ outer_dims ++ arrayDims se_t
  n <- dPrimVE "replicate_n" $ product dims
  (virtualise, constants) <- simpleKernelConstants n "replicate"

  fname <- askFunction
  let tag = baseTag $ kernelGlobalThreadIdVar constants
      name =
        keyWithEntryPoint fname . nameFromString $ "replicate_" ++ show tag

  sKernelFailureTolerant True threadOperations constants name . virtualise $
    \gtid -> do
      is' <- dIndexSpace' "rep_i" dims gtid
      sWhen (gtid .<. n) $
        copyDWIMFix arr is' se (drop (length outer_dims) is')

-- | Base name of the builtin replicate (fill) function for a given
-- element type.
replicateName :: PrimType -> String
replicateName = ("replicate_" <>) . prettyString

-- | Return the name of the builtin function that fills an array of
-- element type @bt@ with a single value.  If that function has not been
-- defined yet, it is generated first.
--
-- The function takes a device memory block, an i64 element count, and
-- the value to write; its body is a 'sReplicateKernel' over a linear
-- array in that memory.
replicateForType :: PrimType -> CallKernelGen Name
replicateForType bt = do
  already_defined <- hasFunction fname
  unless already_defined defineFill
  pure fname
  where
    fname = nameFromString $ "builtin#" <> replicateName bt

    defineFill = do
      mem <- newVName "mem"
      num_elems <- newVName "num_elems"
      val <- newVName "val"
      let shape = Shape [Var num_elems]
          params =
            [ Imp.MemParam mem (Space "device"),
              Imp.ScalarParam num_elems int64,
              Imp.ScalarParam val bt
            ]
      function fname [] params $ do
        arr <-
          sArray "arr" bt shape mem . IxFun.iota . map pe64 $ shapeDims shape
        sReplicateKernel arr (Var val)

-- | Check whether replicating @v@ into @arr@ is a simple fill.  That is
-- the case when @v@ is a scalar and @arr@ has a linear index function.
-- If so, return an action that calls the shared builtin from
-- 'replicateForType' on the whole memory block of @arr@; otherwise
-- return 'Nothing'.
replicateIsFill :: VName -> SubExp -> CallKernelGen (Maybe (CallKernelGen ()))
replicateIsFill arr v = do
  ArrayEntry (MemLoc arr_mem arr_shape arr_ixfun) _ <- lookupArray arr
  v_t <- subExpType v
  let num_elems = untyped $ product $ map pe64 arr_shape
      fill pt = do
        fname <- replicateForType pt
        emit $
          Imp.Call
            []
            fname
            [ Imp.MemArg arr_mem,
              Imp.ExpArg num_elems,
              Imp.ExpArg $ toExp' pt v
            ]
  pure $ case v_t of
    Prim pt | IxFun.isLinear arr_ixfun -> Just $ fill pt
    _ -> Nothing

-- | Perform a Replicate, filling @arr@ with copies of @se@.
--
-- If the replicate is of a particularly common and simple form (morally
-- a memset()/fill, as recognised by 'replicateIsFill'), a shared builtin
-- function is called.  Otherwise a dedicated kernel is generated with
-- 'sReplicateKernel'.
sReplicate :: VName -> SubExp -> CallKernelGen ()
sReplicate arr se =
  fromMaybe (sReplicateKernel arr se) =<< replicateIsFill arr se

-- | Perform an Iota with a dedicated kernel.
--
-- For every index @i < n@, write @i * s + x@ into @arr@.  The arithmetic
-- is wrapping and done in integer type @et@.
sIotaKernel ::
  VName ->
  Imp.TExp Int64 ->
  Imp.Exp ->
  Imp.Exp ->
  IntType ->
  CallKernelGen ()
sIotaKernel arr n x s et = do
  destloc <- entryArrayLoc <$> lookupArray arr
  (virtualise, constants) <- simpleKernelConstants n "iota"

  fname <- askFunction
  let tag = baseTag $ kernelGlobalThreadIdVar constants
      name =
        keyWithEntryPoint fname . nameFromString $
          concat ["iota_", prettyString et, "_", show tag]

      -- The value stored at flat index @gtid@: @gtid * s + x@.
      elemValue :: Imp.TExp Int64 -> Imp.Exp
      elemValue gtid =
        BinOpExp
          (Add et OverflowWrap)
          (BinOpExp (Mul et OverflowWrap) (Imp.sExt et $ untyped gtid) s)
          x

  sKernelFailureTolerant True threadOperations constants name . virtualise $
    \gtid -> sWhen (gtid .<. n) $ do
      (destmem, destspace, destidx) <- fullyIndexArray' destloc [gtid]
      emit . Imp.Write destmem destidx (IntType et) destspace Imp.Nonvolatile $
        elemValue gtid

-- | Base name of the builtin iota function for a given integer type.
iotaName :: IntType -> String
iotaName = ("iota_" <>) . prettyString

-- | Return the name of the builtin function that writes an iota of
-- integer type @bt@ into a memory block, generating the function first
-- if it has not been defined yet.  Element @i@ receives @i * s + x@.
--
-- The function's parameters are the device memory block, the i64
-- element count @n@, the start value @x@ and the stride @s@.  The last
-- two have type @bt@.
iotaForType :: IntType -> CallKernelGen Name
iotaForType bt = do
  let fname = nameFromString $ "builtin#" <> iotaName bt

  exists <- hasFunction fname
  unless exists $ do
    mem <- newVName "mem"
    n <- newVName "n"
    x <- newVName "x"
    s <- newVName "s"

    let params =
          [ Imp.MemParam mem (Space "device"),
            -- The element count must be an i64: it is read as an i64
            -- below, and 'sIota' passes an i64 expression for it.
            -- Declaring it as i32 would truncate the count of large
            -- arrays.  This matches 'replicateForType'.
            Imp.ScalarParam n int64,
            Imp.ScalarParam x $ IntType bt,
            Imp.ScalarParam s $ IntType bt
          ]
        shape = Shape [Var n]
        n' = Imp.le64 n
        x' = Imp.var x $ IntType bt
        s' = Imp.var s $ IntType bt

    function fname [] params $ do
      arr <-
        sArray "arr" (IntType bt) shape mem $
          IxFun.iota $
            map pe64 $
              shapeDims shape
      sIotaKernel arr n' x' s' bt

  pure fname

-- | Perform an Iota with a kernel.
--
-- If the destination array's index function is linear, the work is
-- delegated to the shared builtin obtained from 'iotaForType'.  That
-- builtin is called with the array's memory block, the element count,
-- the start value and the stride.  Otherwise an iota kernel is emitted
-- directly via 'sIotaKernel'.
sIota ::
  VName ->
  Imp.TExp Int64 ->
  Imp.Exp ->
  Imp.Exp ->
  IntType ->
  CallKernelGen ()
sIota arr n x s et = do
  ArrayEntry (MemLoc arr_mem _ arr_ixfun) _ <- lookupArray arr
  if not (IxFun.isLinear arr_ixfun)
    then sIotaKernel arr n x s et
    else do
      iota_fun <- iotaForType et
      let call_args =
            [ Imp.MemArg arr_mem,
              Imp.ExpArg (untyped n),
              Imp.ExpArg x,
              Imp.ExpArg s
            ]
      emit $ Imp.Call [] iota_fun call_args

-- | Copy the contents of one memory location to another by launching a
-- kernel whose (virtualised) iteration space covers every element.
--
-- Each in-range thread does the following:
--
--   * computes its multi-dimensional index,
--   * reads the element at that index from the source into a
--     temporary,
--   * writes the temporary to the same index in the destination.
sCopy :: CopyCompiler GPUMem HostEnv Imp.HostOp
sCopy pt destloc@(MemLoc destmem _ _) srcloc@(MemLoc srcmem srcdims _) = do
  -- Source and destination necessarily have the same shape, so the
  -- source dimensions alone determine the iteration space.
  let dims = map pe64 srcdims
      num_elems = product dims

  (virtualise, constants) <- simpleKernelConstants num_elems "copy"

  entry <- askFunction
  let tag = baseTag $ kernelGlobalThreadIdVar constants
      kernel_name = keyWithEntryPoint entry $ nameFromString $ "copy_" ++ show tag

  sKernelFailureTolerant True threadOperations constants kernel_name . virtualise $ \gtid -> do
    elem_is <- dIndexSpace' "copy_i" dims gtid

    (_, destspace, destidx) <- fullyIndexArray' destloc elem_is
    (_, srcspace, srcidx) <- fullyIndexArray' srcloc elem_is

    -- Guard threads whose global id lies past the last element.
    sWhen (gtid .<. num_elems) $ do
      tmp <- tvVar <$> dPrim "tmp" pt
      emit $ Imp.Read tmp srcmem srcidx pt srcspace Imp.Nonvolatile
      emit . Imp.Write destmem destidx pt destspace Imp.Nonvolatile $ Imp.var tmp pt

-- | Perform a Rotate with a kernel.
--
-- The parameters are:
--
--   * @dest@: the destination array,
--   * @rs@: one rotation amount per dimension of @src@,
--   * @src@: the source array.
--
-- Every in-range thread computes its index into @dest@.  It rotates
-- each component with 'rotateIndex' against the matching dimension and
-- amount, and copies the corresponding element of @src@ to that index.
sRotateKernel :: VName -> [Imp.TExp Int64] -> VName -> CallKernelGen ()
sRotateKernel dest rs src = do
  src_t <- lookupType src
  let dims = pe64 <$> arrayDims src_t
  n <- dPrimVE "rotate_n" $ product dims
  (virtualise, constants) <- simpleKernelConstants n "rotate"

  entry <- askFunction
  let tag = baseTag $ kernelGlobalThreadIdVar constants
      kernel_name = keyWithEntryPoint entry $ nameFromString $ "rotate_" ++ show tag

  sKernelFailureTolerant True threadOperations constants kernel_name . virtualise $ \gtid ->
    sWhen (gtid .<. n) $ do
      dest_is <- dIndexSpace' "rep_i" dims gtid
      -- One rotated source index per dimension, bound left to right.
      src_is <-
        traverse
          (\(d, r, i) -> dPrimVE "rot_i" $ rotateIndex d r i)
          (zip3 dims rs dest_is)
      copyDWIMFix dest dest_is (Var src) src_is

-- | Store a thread-level kernel result into the given pattern element.
--
-- * 'Returns': copies the value to the position given by the thread's
--   indices in the segment space.
-- * 'WriteReturns': for each (slice, value) pair, copies the value into
--   the slice, but only when the slice is within the bounds given by
--   the result's shape.
-- * 'TileReturns': reported as a compiler bug.
-- * 'RegTileReturns': reported as a compiler limitation.
compileThreadResult ::
  SegSpace ->
  PatElem LetDecMem ->
  KernelResult ->
  InKernelGen ()
compileThreadResult space pe result =
  case result of
    Returns _ _ what -> do
      let thread_is = map (Imp.le64 . fst) $ unSegSpace space
      copyDWIMFix pe_name thread_is what []
    WriteReturns _ (Shape rws) _arr dests -> do
      let bounds = map pe64 rws
      forM_ dests $ \(slice, v) -> do
        let slice' = pe64 <$> slice
        sWhen (inBounds slice' bounds) $
          copyDWIM pe_name (unSlice slice') v []
    TileReturns {} ->
      compilerBugS "compileThreadResult: TileReturns unhandled."
    RegTileReturns {} ->
      compilerLimitationS "compileThreadResult: RegTileReturns not yet handled."
  where
    pe_name = patElemName pe