{-# LANGUAGE TypeFamilies #-}

-- | Compile a 'GPUMem' program to imperative code with kernels.
-- This is mostly (but not entirely) the same process no matter if we
-- are targeting OpenCL or CUDA.  The important distinctions (the host
-- level code) are introduced later.
module Futhark.CodeGen.ImpGen.GPU
  ( compileProgOpenCL,
    compileProgCUDA,
    Warnings,
  )
where

import Control.Monad
import Data.List (foldl')
import Data.Map qualified as M
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen hiding (compileProg)
import Futhark.CodeGen.ImpGen qualified
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.SegHist
import Futhark.CodeGen.ImpGen.GPU.SegMap
import Futhark.CodeGen.ImpGen.GPU.SegRed
import Futhark.CodeGen.ImpGen.GPU.SegScan
import Futhark.CodeGen.ImpGen.GPU.Transpose
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.MonadFreshNames
import Futhark.Util.IntegralExp (IntegralExp, divUp, quot, rem)
import Prelude hiding (quot, rem)

-- The host-level 'Operations' record used when compiling a 'GPUMem'
-- program.  Expressions, copies and ops are routed to the
-- GPU-specific compilers defined in this module; statements use the
-- default statement compiler, and no custom allocation compilers are
-- registered.
callKernelOperations :: Operations GPUMem HostEnv Imp.HostOp
callKernelOperations =
  Operations
    { opsExpCompiler = expCompiler,
      opsCopyCompiler = callKernelCopy,
      opsOpCompiler = opCompiler,
      opsStmsCompiler = defCompileStms,
      opsAllocCompilers = mempty
    }

-- Lookup functions from a binary operator to the constructor of the
-- matching atomic operation.  The OpenCL table holds the 32-bit and
-- 64-bit integer operations listed below; the CUDA table is the
-- OpenCL table extended with floating-point addition for 'Float32'
-- and 'Float64'.  Operators not present in a table yield 'Nothing'.
openclAtomics, cudaAtomics :: AtomicBinOp
(openclAtomics, cudaAtomics) = (flip lookup opencl, flip lookup cuda)
  where
    opencl64 =
      [ (Add Int64 OverflowUndef, Imp.AtomicAdd Int64),
        (SMax Int64, Imp.AtomicSMax Int64),
        (SMin Int64, Imp.AtomicSMin Int64),
        (UMax Int64, Imp.AtomicUMax Int64),
        (UMin Int64, Imp.AtomicUMin Int64),
        (And Int64, Imp.AtomicAnd Int64),
        (Or Int64, Imp.AtomicOr Int64),
        (Xor Int64, Imp.AtomicXor Int64)
      ]
    opencl32 =
      [ (Add Int32 OverflowUndef, Imp.AtomicAdd Int32),
        (SMax Int32, Imp.AtomicSMax Int32),
        (SMin Int32, Imp.AtomicSMin Int32),
        (UMax Int32, Imp.AtomicUMax Int32),
        (UMin Int32, Imp.AtomicUMin Int32),
        (And Int32, Imp.AtomicAnd Int32),
        (Or Int32, Imp.AtomicOr Int32),
        (Xor Int32, Imp.AtomicXor Int32)
      ]
    -- 32-bit entries come first, so 'lookup' consults them first.
    opencl = opencl32 ++ opencl64
    cuda =
      opencl
        ++ [ (FAdd Float32, Imp.AtomicFAdd Float32),
             (FAdd Float64, Imp.AtomicFAdd Float64)
           ]

-- Shared driver for both backends: run the generic ImpGen program
-- compiler with the given host environment, the host-level
-- operations of this module, and the @"device"@ memory space as the
-- space argument.
compileProg ::
  MonadFreshNames m =>
  HostEnv ->
  Prog GPUMem ->
  m (Warnings, Imp.Program)
compileProg env =
  Futhark.CodeGen.ImpGen.compileProg env callKernelOperations device_space
  where
    device_space = Imp.Space "device"

-- | Compile a 'GPUMem' program to low-level parallel code, with
-- either CUDA or OpenCL characteristics.  The two variants differ
-- only in the initial 'HostEnv': which table of atomic operations is
-- available and which target is recorded.  Both start with an empty
-- lock map.
compileProgOpenCL,
  compileProgCUDA ::
    MonadFreshNames m => Prog GPUMem -> m (Warnings, Imp.Program)
compileProgOpenCL = compileProg $ HostEnv openclAtomics OpenCL mempty
compileProgCUDA = compileProg $ HostEnv cudaAtomics CUDA mempty

-- Compile a host-level 'Op'.  Handles allocations, the various
-- 'SizeOp's, segmented operations (delegated to 'segOpCompiler') and
-- 'GPUBody'.  Any other combination of pattern and operation is
-- reported as a compiler bug.
opCompiler ::
  Pat LetDecMem ->
  Op GPUMem ->
  CallKernelGen ()
opCompiler dest (Alloc e space) =
  compileAlloc dest e space
opCompiler (Pat [pe]) (Inner (SizeOp (GetSize key size_class))) = do
  -- Size keys (and any keys inside the size class) are qualified
  -- with the enclosing function, when there is one.
  fname <- askFunction
  sOp $
    Imp.GetSize (patElemName pe) (keyWithEntryPoint fname key) $
      sizeClassWithEntryPoint fname size_class
opCompiler (Pat [pe]) (Inner (SizeOp (CmpSizeLe key size_class x))) = do
  fname <- askFunction
  let size_class' = sizeClassWithEntryPoint fname size_class
  sOp . Imp.CmpSizeLe (patElemName pe) (keyWithEntryPoint fname key) size_class'
    =<< toExp x
opCompiler (Pat [pe]) (Inner (SizeOp (GetSizeMax size_class))) =
  sOp $ Imp.GetSizeMax (patElemName pe) size_class
opCompiler (Pat [pe]) (Inner (SizeOp (CalcNumGroups w64 max_num_groups_key group_size))) = do
  fname <- askFunction
  max_num_groups :: TV Int32 <- dPrim "max_num_groups" int32
  sOp $
    Imp.GetSize (tvVar max_num_groups) (keyWithEntryPoint fname max_num_groups_key) $
      sizeClassWithEntryPoint fname SizeNumGroups

  -- If 'w' is small, we launch fewer groups than we normally would.
  -- We don't want any idle groups.
  --
  -- The calculations are done with 64-bit integers to avoid overflow
  -- issues.
  let num_groups_maybe_zero =
        sMin64 (pe64 w64 `divUp` pe64 group_size) $
          sExt64 (tvExp max_num_groups)
  -- We also don't want zero groups.
  let num_groups = sMax64 1 num_groups_maybe_zero
  mkTV (patElemName pe) int32 <-- sExt32 num_groups
opCompiler dest (Inner (SegOp op)) =
  segOpCompiler dest op
opCompiler (Pat pes) (Inner (GPUBody _ (Body _ stms res))) = do
  tid <- newVName "tid"
  -- The kernel is launched with both the number of groups and the
  -- group size set to 1.
  let one = Count (intConst Int64 1)
  sKernelThread "gpuseq" tid (defKernelAttrs one one) $
    compileStms (freeIn res) stms $
      -- Each result is written to index 0 of its pattern element.
      forM_ (zip pes res) $ \(pe, SubExpRes _ se) ->
        copyDWIM (patElemName pe) [DimFix 0] se []
opCompiler pat e =
  compilerBugS $
    "opCompiler: Invalid pattern\n  "
      ++ prettyString pat
      ++ "\nfor expression\n  "
      ++ prettyString e

-- Qualify the names occurring in the path of a 'Imp.SizeThreshold'
-- with the given function name (via 'keyWithEntryPoint').  All other
-- size classes are returned unchanged.
sizeClassWithEntryPoint :: Maybe Name -> Imp.SizeClass -> Imp.SizeClass
sizeClassWithEntryPoint fname (Imp.SizeThreshold path def) =
  Imp.SizeThreshold (map f path) def
  where
    f (name, x) = (keyWithEntryPoint fname name, x)
sizeClassWithEntryPoint _ size_class = size_class

-- Dispatch a segmented operation to its dedicated code generator.
-- 'SegMap' is accepted at any level; 'SegRed', 'SegScan' and
-- 'SegHist' are only accepted at 'SegThread' level.  Anything else
-- is reported as a compiler bug.
segOpCompiler ::
  Pat LetDecMem ->
  SegOp SegLevel GPUMem ->
  CallKernelGen ()
segOpCompiler pat (SegMap lvl space _ kbody) =
  compileSegMap pat lvl space kbody
segOpCompiler pat (SegRed lvl@(SegThread _ _) space reds _ kbody) =
  compileSegRed pat lvl space reds kbody
segOpCompiler pat (SegScan lvl@(SegThread _ _) space scans _ kbody) =
  compileSegScan pat lvl space scans kbody
segOpCompiler pat (SegHist lvl@(SegThread _ _) space ops _ kbody) =
  compileSegHist pat lvl space ops kbody
segOpCompiler pat segop =
  compilerBugS $ "segOpCompiler: unexpected " ++ prettyString (segLevel segop) ++ " for rhs of pattern " ++ prettyString pat

-- Create boolean expression that checks whether all kernels in the
-- enclosed code do not use more local memory than we have available.
-- We look at *all* the kernels here, even those that might be
-- otherwise protected by their own multi-versioning branches deeper
-- down.  Currently the compiler will not generate multi-versioning
-- that makes this a problem, but it might in the future.
--
-- Returns 'Nothing' when some allocation size mentions a name that
-- is not in scope at this point, in which case no check can be
-- constructed.
checkLocalMemoryReqs :: Imp.HostCode -> CallKernelGen (Maybe (Imp.TExp Bool))
checkLocalMemoryReqs code = do
  scope <- askScope
  -- One entry per kernel: the sum of its (padded) local allocations.
  let alloc_sizes = map (sum . map alignedSize . localAllocSizes . Imp.kernelBody) $ getGPU code

  -- If any of the sizes involve a variable that is not known at this
  -- point, then we cannot check the requirements.
  if any (`M.notMember` scope) (namesToList $ freeIn alloc_sizes)
    then pure Nothing
    else do
      local_memory_capacity :: TV Int32 <- dPrim "local_memory_capacity" int32
      sOp $ Imp.GetSizeMax (tvVar local_memory_capacity) SizeLocalMemory

      let local_memory_capacity_64 =
            sExt64 $ tvExp local_memory_capacity
          fits size =
            unCount size .<=. local_memory_capacity_64
      pure $ Just $ foldl' (.&&.) true (map fits alloc_sizes)
  where
    -- Only kernels that ask for the check are considered.
    getGPU = foldMap getKernel
    getKernel (Imp.CallKernel k) | Imp.kernelCheckLocalMemory k = [k]
    getKernel _ = []

    localAllocSizes = foldMap localAllocSize
    localAllocSize (Imp.LocalAlloc _ size) = [size]
    localAllocSize _ = []

    -- These allocations will actually be padded to an 8-byte aligned
    -- size, so we should take that into account when checking whether
    -- they fit.
    alignedSize x = x + ((8 - (x `rem` 8)) `rem` 8)

-- Compile a 'WithAcc' expression.  The inputs are paired
-- positionally with the names of the lambda's parameters.  For each
-- input whose operator lambda is classified as 'AtomicLocking' by
-- 'atomicUpdateLocking', an array is created with 'genZeroes' and
-- registered as the locks for that parameter name in the host
-- environment.  Once all inputs have been processed, the expression
-- itself is compiled with 'defCompileExp' in the extended
-- environment.
withAcc ::
  Pat LetDecMem ->
  [(Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))] ->
  Lambda GPUMem ->
  CallKernelGen ()
withAcc pat inputs lam = do
  atomics <- hostAtomics <$> askEnv
  locksForInputs atomics $ zip accs inputs
  where
    accs = map paramName $ lambdaParams lam
    locksForInputs _ [] =
      defCompileExp pat $ WithAcc inputs lam
    locksForInputs atomics ((c, (_, _, op)) : inputs')
      | Just (op_lam, _) <- op,
        AtomicLocking _ <- atomicUpdateLocking atomics op_lam = do
          -- Fixed number of locks per accumulator.
          let num_locks = 100151
          locks_arr <- genZeroes "withacc_locks" num_locks
          let locks = Locks locks_arr num_locks
              extend env = env {hostLocks = M.insert c locks $ hostLocks env}
          localEnv extend $ locksForInputs atomics inputs'
      | otherwise =
          -- No operator, or one that does not need locking.
          locksForInputs atomics inputs'

-- Host-level expression compiler.  Special-cases 'Iota',
-- 'Replicate', allocations in the "local" space, 'WithAcc' and
-- multi-versioning 'Match' expressions; everything else is handed to
-- 'defCompileExp'.
expCompiler :: ExpCompiler GPUMem HostEnv Imp.HostOp
-- We generate a simple kernel for iota and replicate.
expCompiler (Pat [pe]) (BasicOp (Iota n x s et)) = do
  x' <- toExp x
  s' <- toExp s

  sIota (patElemName pe) (pe64 n) x' s' et
expCompiler (Pat [pe]) (BasicOp (Replicate _ se))
  -- Nothing is emitted when the result is an accumulator.
  | Acc {} <- patElemType pe = pure ()
  | otherwise =
      sReplicate (patElemName pe) se
-- Allocation in the "local" space is just a placeholder.
expCompiler _ (Op (Alloc _ (Space "local"))) =
  pure ()
expCompiler pat (WithAcc inputs lam) =
  withAcc pat inputs lam
-- This is a multi-versioning Match created by incremental flattening.
-- We need to augment the conditional with a check that any local
-- memory requirements in tbranch are compatible with the hardware.
-- We do not check anything for defbody, as we assume that it will
-- always be safe (and what would we do if none of the branches would
-- work?).
expCompiler dest (Match cond (first_case : cases) defbranch sort@(MatchDec _ MatchEquiv)) = do
  tcode <- collect $ compileBody dest $ caseBody first_case
  -- The remaining cases are compiled recursively as a smaller Match.
  fcode <- collect $ expCompiler dest $ Match cond cases defbranch sort
  check <- checkLocalMemoryReqs tcode
  let matches = caseMatch cond (casePat first_case)
  emit $ case check of
    -- No check could be constructed: use only the remaining versions.
    Nothing -> fcode
    Just ok -> Imp.If (matches .&&. ok) tcode fcode
expCompiler dest e =
  defCompileExp dest e

-- Host-level copy compiler.  Three strategies are tried in order:
--
-- 1. If 'isMapTransposeCopy' recognises the copy, call the
--    map-transpose function for the element type.
--
-- 2. If both index functions are linear with an offset (per
--    'IxFun.linearWithOffset'), emit a direct 'sCopy' of the product
--    of the source shape dimensions many elements.
--
-- 3. Otherwise fall back to 'sCopyKernel'.
callKernelCopy :: CopyCompiler GPUMem HostEnv Imp.HostOp
callKernelCopy bt destloc@(MemLoc destmem _ destIxFun) srcloc@(MemLoc srcmem srcshape srcIxFun)
  | Just (destoffset, srcoffset, num_arrays, size_x, size_y) <-
      isMapTransposeCopy bt destloc srcloc = do
      fname <- mapTransposeForType bt
      emit $
        Imp.Call
          []
          fname
          [ Imp.MemArg destmem,
            Imp.ExpArg $ untyped destoffset,
            Imp.MemArg srcmem,
            Imp.ExpArg $ untyped srcoffset,
            Imp.ExpArg $ untyped num_arrays,
            Imp.ExpArg $ untyped size_x,
            Imp.ExpArg $ untyped size_y
          ]
  | bt_size <- primByteSize bt,
    Just destoffset <- IxFun.linearWithOffset destIxFun bt_size,
    Just srcoffset <- IxFun.linearWithOffset srcIxFun bt_size = do
      let num_elems = Imp.elements $ product $ map pe64 srcshape
      srcspace <- entryMemSpace <$> lookupMemory srcmem
      destspace <- entryMemSpace <$> lookupMemory destmem
      sCopy
        destmem
        (sExt64 destoffset)
        destspace
        srcmem
        (sExt64 srcoffset)
        srcspace
        num_elems
        bt
  | otherwise = sCopyKernel bt destloc srcloc

-- | Return the name of the builtin transposition function for
-- elements of the given primitive type, emitting the function first
-- if it has not already been emitted.
mapTransposeForType :: PrimType -> CallKernelGen Name
mapTransposeForType bt = do
  already_emitted <- hasFunction fname
  unless already_emitted $
    emitFunction fname (mapTransposeFunction bt)
  pure fname
  where
    -- Name under which the function is registered.
    fname = nameFromString ("builtin#" <> mapTransposeName bt)

-- | The base name used for the transposition function and kernels
-- operating on elements of the given primitive type.
mapTransposeName :: PrimType -> String
mapTransposeName = ("gpu_map_transpose_" <>) . prettyString

-- | Construct the host-level function that transposes @num_arrays@
-- arrays of @x_elems@ by @y_elems@ elements of type @bt@.  The
-- function prints its arguments as debug output, does nothing for
-- empty input, and otherwise picks between a plain copy and one of
-- the transposition kernel variants, each available with 32-bit or
-- 64-bit index arithmetic (see Note [32-bit transpositions]).
mapTransposeFunction :: PrimType -> Imp.Function Imp.HostOp
mapTransposeFunction bt =
  Imp.Function Nothing [] params $
    mconcat
      [ Imp.DebugPrint ("\n# Transpose " <> prettyString bt) Nothing,
        debug_val "Number of arrays  " num_arrays,
        debug_val "X elements        " x,
        debug_val "Y elements        " y,
        debug_val "Source      offset" srcoffset,
        debug_val "Destination offset" destoffset,
        transpose_code,
        Imp.DebugPrint "" Nothing
      ]
  where
    -- Debug-print a label together with the value of a 64-bit variable.
    debug_val :: String -> VName -> Imp.Code Imp.HostOp
    debug_val label v =
      Imp.DebugPrint label (Just (untyped (Imp.le64 v)))

    -- Formal parameters of the generated function, in call order.
    params :: [Imp.Param]
    params =
      [ memparam destmem,
        intparam destoffset,
        memparam srcmem,
        intparam srcoffset,
        intparam num_arrays,
        intparam x,
        intparam y
      ]

    space :: Space
    space = Space "device"

    memparam v = Imp.MemParam v space
    intparam v = Imp.ScalarParam v (IntType Int64)

    -- Fixed variable names used inside the generated function.
    named_var :: String -> Int -> VName
    named_var s = VName (nameFromString s)

    destmem = named_var "destmem" 0
    destoffset = named_var "destoffset" 1
    srcmem = named_var "srcmem" 2
    srcoffset = named_var "srcoffset" 3
    num_arrays = named_var "num_arrays" 4
    x = named_var "x_elems" 5
    y = named_var "y_elems" 6
    -- The following is only used for low width/height
    -- transpose kernels
    mulx = named_var "mulx" 7
    muly = named_var "muly" 8
    block = named_var "block" 9
    use_32b = named_var "use_32b" 10

    block_dim_int :: Integer
    block_dim_int = 16

    -- Same value as 'block_dim_int', usable at any expression type.
    block_dim :: IntegralExp a => a
    block_dim = fromInteger block_dim_int

    -- When an input array has either width==1 or height==1, performing a
    -- transpose will be the same as performing a copy.
    can_use_copy :: Imp.TExp Bool
    can_use_copy =
      (Imp.le64 num_arrays .==. 1)
        .&&. ((Imp.le64 x .==. 1) .||. (Imp.le64 y .==. 1))

    -- Does the given offset plus the total number of elements fit in
    -- a signed 32-bit integer?
    fits_in_32b offset =
      (le64 offset + le64 num_arrays * le64 x * le64 y)
        .<=. 2 ^ (31 :: Int) - 1

    transpose_code :: Imp.Code Imp.HostOp
    transpose_code =
      Imp.If input_is_empty mempty $
        mconcat
          [ Imp.DeclareScalar muly Imp.Nonvolatile (IntType Int64),
            Imp.SetScalar muly (untyped (block_dim `quot` Imp.le64 x)),
            Imp.DeclareScalar mulx Imp.Nonvolatile (IntType Int64),
            Imp.SetScalar mulx (untyped (block_dim `quot` Imp.le64 y)),
            Imp.DeclareScalar use_32b Imp.Nonvolatile Bool,
            Imp.SetScalar use_32b . untyped $
              fits_in_32b destoffset .&&. fits_in_32b srcoffset,
            pick_strategy
          ]

    -- Try each special case in order; fall back to the normal
    -- transposition kernel if none of them applies.
    pick_strategy =
      foldr
        (\(cond, code) otherwise_code -> Imp.If cond code otherwise_code)
        (callTransposeKernel TransposeNormal)
        [ (can_use_copy, copy_code),
          (should_use_lowwidth, callTransposeKernel TransposeLowWidth),
          (should_use_lowheight, callTransposeKernel TransposeLowHeight),
          (should_use_small, callTransposeKernel TransposeSmall)
        ]

    is_zero v = Imp.le64 v .==. 0
    at_most_half_block v = Imp.le64 v .<=. (block_dim `quot` 2)
    exceeds_block v = block_dim .<. Imp.le64 v

    input_is_empty =
      is_zero num_arrays .||. is_zero x .||. is_zero y

    should_use_small =
      at_most_half_block x .&&. at_most_half_block y

    should_use_lowwidth =
      at_most_half_block x .&&. exceeds_block y

    should_use_lowheight =
      at_most_half_block y .&&. exceeds_block x

    -- A direct copy of @x * y@ elements from source to destination.
    copy_code =
      Imp.Copy
        bt
        destmem
        (Imp.Count (Imp.le64 destoffset))
        space
        srcmem
        (Imp.Count (Imp.le64 srcoffset))
        space
        (Imp.Count num_bytes)
      where
        num_bytes = sExt64 $ Imp.le64 x * Imp.le64 y * primByteSize bt

    -- Dispatch at runtime on the @use_32b@ flag computed in
    -- 'transpose_code'.
    callTransposeKernel which =
      Imp.If
        (isBool (LeafExp use_32b Bool))
        (Imp.DebugPrint "Using 32-bit indexing" Nothing <> callTransposeKernel32 which)
        (Imp.DebugPrint "Using 64-bit indexing" Nothing <> callTransposeKernel64 which)

    callTransposeKernel64 which =
      Imp.Op . Imp.CallKernel $
        mapTransposeKernel
          (int64, le64)
          (mapTransposeName bt)
          block_dim_int
          ( destmem,
            le64 destoffset,
            srcmem,
            le64 srcoffset,
            le64 x,
            le64 y,
            le64 mulx,
            le64 muly,
            le64 num_arrays,
            block
          )
          bt
          which

    callTransposeKernel32 which =
      Imp.Op . Imp.CallKernel $
        mapTransposeKernel
          (int32, le32)
          (mapTransposeName bt)
          block_dim_int
          ( destmem,
            sExt32 (le64 destoffset),
            srcmem,
            sExt32 (le64 srcoffset),
            sExt32 (le64 x),
            sExt32 (le64 y),
            sExt32 (le64 mulx),
            sExt32 (le64 muly),
            sExt32 (le64 num_arrays),
            block
          )
          bt
          which

-- Note [32-bit transpositions]
--
-- Transposition kernels are much slower when they have to use 64-bit
-- arithmetic.  I observed about 0.67x slowdown on an A100 GPU when
-- transposing four-byte elements (much less when transposing 8-byte
-- elements).  Unfortunately, 64-bit arithmetic is a requirement for
-- large arrays (see #1953 for what happens otherwise).  We generate
-- both 32- and 64-bit index arithmetic versions of transpositions,
-- and dynamically pick between them at runtime.  This is an
-- unfortunate code bloat, and it would be preferable if we could
-- simply optimise the 64-bit version to make this distinction
-- unnecessary.  Fortunately these kernels are quite small.