{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

-- | Compile a 'GPUMem' program to imperative code with kernels.
-- This is mostly (but not entirely) the same process no matter if we
-- are targeting OpenCL or CUDA.  The important distinctions (the host
-- level code) are introduced later.
module Futhark.CodeGen.ImpGen.GPU
  ( compileProgOpenCL,
    compileProgCUDA,
    Warnings,
  )
where

import Control.Monad.Except
import Data.Bifunctor (second)
import Data.List (foldl')
import qualified Data.Map as M
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU (bytes)
import qualified Futhark.CodeGen.ImpCode.GPU as Imp
import Futhark.CodeGen.ImpGen hiding (compileProg)
import qualified Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.SegHist
import Futhark.CodeGen.ImpGen.GPU.SegMap
import Futhark.CodeGen.ImpGen.GPU.SegRed
import Futhark.CodeGen.ImpGen.GPU.SegScan
import Futhark.CodeGen.ImpGen.GPU.Transpose
import Futhark.CodeGen.SetDefaultSpace
import Futhark.Error
import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Util.IntegralExp (IntegralExp, divUp, quot, rem)
import Prelude hiding (quot, rem)

-- | The operations used when compiling host-level code: expressions,
-- copies, and ops are handled by the GPU-specific compilers defined in
-- this module, statements use the default compiler, and no custom
-- allocation compilers are registered.
callKernelOperations :: Operations GPUMem HostEnv Imp.HostOp
callKernelOperations =
  Operations
    { opsExpCompiler = expCompiler,
      opsCopyCompiler = callKernelCopy,
      opsOpCompiler = opCompiler,
      opsStmsCompiler = defCompileStms,
      opsAllocCompilers = mempty
    }

-- | Mappings from binary operators to hardware-supported atomic
-- operations, for OpenCL and CUDA respectively.  OpenCL supports the
-- integer atomics at 32 and 64 bits; CUDA additionally supports
-- floating-point atomic addition.
openclAtomics, cudaAtomics :: AtomicBinOp
(openclAtomics, cudaAtomics) = (flip lookup opencl, flip lookup cuda)
  where
    opencl64 =
      [ (Add Int64 OverflowUndef, Imp.AtomicAdd Int64),
        (SMax Int64, Imp.AtomicSMax Int64),
        (SMin Int64, Imp.AtomicSMin Int64),
        (UMax Int64, Imp.AtomicUMax Int64),
        (UMin Int64, Imp.AtomicUMin Int64),
        (And Int64, Imp.AtomicAnd Int64),
        (Or Int64, Imp.AtomicOr Int64),
        (Xor Int64, Imp.AtomicXor Int64)
      ]
    opencl32 =
      [ (Add Int32 OverflowUndef, Imp.AtomicAdd Int32),
        (SMax Int32, Imp.AtomicSMax Int32),
        (SMin Int32, Imp.AtomicSMin Int32),
        (UMax Int32, Imp.AtomicUMax Int32),
        (UMin Int32, Imp.AtomicUMin Int32),
        (And Int32, Imp.AtomicAnd Int32),
        (Or Int32, Imp.AtomicOr Int32),
        (Xor Int32, Imp.AtomicXor Int32)
      ]
    opencl = opencl32 ++ opencl64
    cuda =
      opencl
        ++ [ (FAdd Float32, Imp.AtomicFAdd Float32),
             (FAdd Float64, Imp.AtomicFAdd Float64)
           ]

-- | Compile the program with 'callKernelOperations', then rewrite the
-- default memory spaces: host-level definitions default to the
-- \"device\" space, while code inside kernels defaults to the
-- \"global\" space.
compileProg ::
  MonadFreshNames m =>
  HostEnv ->
  Prog GPUMem ->
  m (Warnings, Imp.Program)
compileProg env prog =
  second (fmap setOpSpace . setDefsSpace)
    <$> Futhark.CodeGen.ImpGen.compileProg env callKernelOperations device_space prog
  where
    device_space = Imp.Space "device"
    global_space = Imp.Space "global"
    setDefsSpace = setDefaultSpace device_space
    -- Only kernel bodies use the "global" default space; other host
    -- ops are left alone.
    setOpSpace (Imp.CallKernel kernel) =
      Imp.CallKernel
        kernel
          { Imp.kernelBody =
              setDefaultCodeSpace global_space $ Imp.kernelBody kernel
          }
    setOpSpace op = op

-- | Compile a 'GPUMem' program to low-level parallel code, with
-- either CUDA or OpenCL characteristics.  The two differ only in
-- their host environment: which atomics are available and the target
-- tag; both start with no locks registered.
compileProgOpenCL,
  compileProgCUDA ::
    MonadFreshNames m => Prog GPUMem -> m (Warnings, Imp.Program)
compileProgOpenCL = compileProg $ HostEnv openclAtomics OpenCL mempty
compileProgCUDA = compileProg $ HostEnv cudaAtomics CUDA mempty

-- | Compile a host-level op: allocations, the various size queries,
-- and 'SegOp's (which are dispatched to 'segOpCompiler').  Any other
-- pattern/op combination is a compiler bug.
opCompiler ::
  Pat GPUMem ->
  Op GPUMem ->
  CallKernelGen ()
opCompiler dest (Alloc e space) =
  compileAlloc dest e space
opCompiler (Pat [pe]) (Inner (SizeOp (GetSize key size_class))) = do
  fname <- askFunction
  sOp $
    Imp.GetSize (patElemName pe) (keyWithEntryPoint fname key) $
      sizeClassWithEntryPoint fname size_class
opCompiler (Pat [pe]) (Inner (SizeOp (CmpSizeLe key size_class x))) = do
  fname <- askFunction
  let size_class' = sizeClassWithEntryPoint fname size_class
  sOp . Imp.CmpSizeLe (patElemName pe) (keyWithEntryPoint fname key) size_class'
    =<< toExp x
opCompiler (Pat [pe]) (Inner (SizeOp (GetSizeMax size_class))) =
  sOp $ Imp.GetSizeMax (patElemName pe) size_class
opCompiler (Pat [pe]) (Inner (SizeOp (CalcNumGroups w64 max_num_groups_key group_size))) = do
  fname <- askFunction
  max_num_groups :: TV Int32 <- dPrim "max_num_groups" int32
  sOp $
    Imp.GetSize (tvVar max_num_groups) (keyWithEntryPoint fname max_num_groups_key) $
      sizeClassWithEntryPoint fname SizeNumGroups

  -- If 'w' is small, we launch fewer groups than we normally would.
  -- We don't want any idle groups.
  --
  -- The calculations are done with 64-bit integers to avoid overflow
  -- issues.
  let num_groups_maybe_zero =
        sMin64 (toInt64Exp w64 `divUp` toInt64Exp group_size) $
          sExt64 (tvExp max_num_groups)
  -- We also don't want zero groups.
  let num_groups = sMax64 1 num_groups_maybe_zero
  mkTV (patElemName pe) int32 <-- sExt32 num_groups
opCompiler dest (Inner (SegOp op)) =
  segOpCompiler dest op
opCompiler pat e =
  compilerBugS $
    "opCompiler: Invalid pattern\n  "
      ++ pretty pat
      ++ "\nfor expression\n  "
      ++ pretty e

-- | Qualify the names inside a threshold size class with the entry
-- point name (if any); other size classes are returned unchanged.
sizeClassWithEntryPoint :: Maybe Name -> Imp.SizeClass -> Imp.SizeClass
sizeClassWithEntryPoint fname (Imp.SizeThreshold path def) =
  Imp.SizeThreshold (map f path) def
  where
    f (name, x) = (keyWithEntryPoint fname name, x)
sizeClassWithEntryPoint _ size_class = size_class

-- | Dispatch a 'SegOp' to the corresponding kernel compiler.  Only
-- thread-level 'SegRed'/'SegScan'/'SegHist' are expected here (group-
-- or other-level variants at this point are a compiler bug);
-- 'SegMap' is accepted at any level.
segOpCompiler ::
  Pat GPUMem ->
  SegOp SegLevel GPUMem ->
  CallKernelGen ()
segOpCompiler pat (SegMap lvl space _ kbody) =
  compileSegMap pat lvl space kbody
segOpCompiler pat (SegRed lvl@SegThread {} space reds _ kbody) =
  compileSegRed pat lvl space reds kbody
segOpCompiler pat (SegScan lvl@SegThread {} space scans _ kbody) =
  compileSegScan pat lvl space scans kbody
segOpCompiler pat (SegHist (SegThread num_groups group_size _) space ops _ kbody) =
  compileSegHist pat num_groups group_size space ops kbody
segOpCompiler pat segop =
  compilerBugS $
    "segOpCompiler: unexpected " ++ pretty (segLevel segop)
      ++ " for rhs of pattern "
      ++ pretty pat

-- Create boolean expression that checks whether all kernels in the
-- enclosed code do not use more local memory than we have available.
-- We look at *all* the kernels here, even those that might be
-- otherwise protected by their own multi-versioning branches deeper
-- down.  Currently the compiler will not generate multi-versioning
-- that makes this a problem, but it might in the future.
checkLocalMemoryReqs :: Imp.Code -> CallKernelGen (Maybe (Imp.TExp Bool))
checkLocalMemoryReqs code = do
  scope <- askScope
  -- Per-kernel sum of its (alignment-padded) local allocations.
  let alloc_sizes = map (sum . map alignedSize . localAllocSizes . Imp.kernelBody) $ getGPU code

  -- If any of the sizes involve a variable that is not known at this
  -- point, then we cannot check the requirements.
  if any (`M.notMember` scope) (namesToList $ freeIn alloc_sizes)
    then return Nothing
    else do
      local_memory_capacity :: TV Int32 <- dPrim "local_memory_capacity" int32
      sOp $ Imp.GetSizeMax (tvVar local_memory_capacity) SizeLocalMemory

      let local_memory_capacity_64 =
            sExt64 $ tvExp local_memory_capacity
          fits size =
            unCount size .<=. local_memory_capacity_64
      return $ Just $ foldl' (.&&.) true (map fits alloc_sizes)
  where
    getGPU = foldMap getKernel
    getKernel (Imp.CallKernel k) = [k]
    getKernel _ = []

    localAllocSizes = foldMap localAllocSize
    localAllocSize (Imp.LocalAlloc _ size) = [size]
    localAllocSize _ = []

    -- These allocations will actually be padded to an 8-byte aligned
    -- size, so we should take that into account when checking whether
    -- they fit.
    alignedSize x = x + ((8 - (x `rem` 8)) `rem` 8)

-- | Compile a 'WithAcc' expression.  For every accumulator whose
-- update operator requires locking (per 'atomicUpdateLocking'), a
-- statically zero-initialised lock array is allocated in device space
-- and registered in the host environment before the expression is
-- compiled with the default compiler.
withAcc ::
  Pat GPUMem ->
  [(Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))] ->
  Lambda GPUMem ->
  CallKernelGen ()
withAcc pat inputs lam = do
  atomics <- hostAtomics <$> askEnv
  locksForInputs atomics $ zip accs inputs
  where
    accs = map paramName $ lambdaParams lam
    -- Walk the inputs, extending the environment with a lock array
    -- for each accumulator that needs one; when done, compile the
    -- whole expression in the extended environment.
    locksForInputs _ [] =
      defCompileExp pat $ WithAcc inputs lam
    locksForInputs atomics ((c, (_, _, op)) : inputs')
      | Just (op_lam, _) <- op,
        AtomicLocking _ <- atomicUpdateLocking atomics op_lam = do
        -- Fixed size of the lock array; statically allocated and
        -- zeroed on the device.
        let num_locks = 100151
        locks_arr <-
          sStaticArray "withacc_locks" (Space "device") int32 $
            Imp.ArrayZeros num_locks
        let locks = Locks locks_arr num_locks
            extend env = env {hostLocks = M.insert c locks $ hostLocks env}
        localEnv extend $ locksForInputs atomics inputs'
      | otherwise =
        locksForInputs atomics inputs'

expCompiler :: ExpCompiler GPUMem HostEnv Imp.HostOp
-- We generate a simple kernel for iota and replicate.
expCompiler :: ExpCompiler GPUMem HostEnv HostOp
expCompiler (Pat [PatElemT (LetDec GPUMem)
pe]) (BasicOp (Iota SubExp
n SubExp
x SubExp
s IntType
et)) = do
  Exp
x' <- SubExp -> ImpM GPUMem HostEnv HostOp Exp
forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
x
  Exp
s' <- SubExp -> ImpM GPUMem HostEnv HostOp Exp
forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
s

  VName
-> TExp Int64
-> Exp
-> Exp
-> IntType
-> ImpM GPUMem HostEnv HostOp ()
sIota (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec GPUMem)
PatElemT LetDecMem
pe) (SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
n) Exp
x' Exp
s' IntType
et
expCompiler (Pat [PatElemT (LetDec GPUMem)
pe]) (BasicOp (Replicate Shape
_ SubExp
se)) =
  VName -> SubExp -> ImpM GPUMem HostEnv HostOp ()
sReplicate (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec GPUMem)
PatElemT LetDecMem
pe) SubExp
se
-- Allocation in the "local" space is just a placeholder.
expCompiler Pat GPUMem
_ (Op (Alloc _ (Space "local"))) =
  () -> ImpM GPUMem HostEnv HostOp ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
expCompiler Pat GPUMem
pat (WithAcc [(Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))]
inputs Lambda GPUMem
lam) =
  Pat GPUMem
-> [(Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))]
-> Lambda GPUMem
-> ImpM GPUMem HostEnv HostOp ()
withAcc Pat GPUMem
pat [(Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))]
inputs Lambda GPUMem
lam
-- This is a multi-versioning If created by incremental flattening.
-- We need to augment the conditional with a check that any local
-- memory requirements in tbranch are compatible with the hardware.
-- We do not check anything for fbranch, as we assume that it will
-- always be safe (and what would we do if none of the branches would
-- work?).
expCompiler Pat GPUMem
dest (If SubExp
cond BodyT GPUMem
tbranch BodyT GPUMem
fbranch (IfDec [BranchType GPUMem]
_ IfSort
IfEquiv)) = do
  Code
tcode <- ImpM GPUMem HostEnv HostOp () -> ImpM GPUMem HostEnv HostOp Code
forall rep r op. ImpM rep r op () -> ImpM rep r op (Code op)
collect (ImpM GPUMem HostEnv HostOp () -> ImpM GPUMem HostEnv HostOp Code)
-> ImpM GPUMem HostEnv HostOp () -> ImpM GPUMem HostEnv HostOp Code
forall a b. (a -> b) -> a -> b
$ Pat GPUMem -> BodyT GPUMem -> ImpM GPUMem HostEnv HostOp ()
forall rep r op. Pat rep -> Body rep -> ImpM rep r op ()
compileBody Pat GPUMem
dest BodyT GPUMem
tbranch
  Code
fcode <- ImpM GPUMem HostEnv HostOp () -> ImpM GPUMem HostEnv HostOp Code
forall rep r op. ImpM rep r op () -> ImpM rep r op (Code op)
collect (ImpM GPUMem HostEnv HostOp () -> ImpM GPUMem HostEnv HostOp Code)
-> ImpM GPUMem HostEnv HostOp () -> ImpM GPUMem HostEnv HostOp Code
forall a b. (a -> b) -> a -> b
$ Pat GPUMem -> BodyT GPUMem -> ImpM GPUMem HostEnv HostOp ()
forall rep r op. Pat rep -> Body rep -> ImpM rep r op ()
compileBody Pat GPUMem
dest BodyT GPUMem
fbranch
  Maybe (TExp Bool)
check <- Code -> CallKernelGen (Maybe (TExp Bool))
checkLocalMemoryReqs Code
tcode
  Code -> ImpM GPUMem HostEnv HostOp ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code -> ImpM GPUMem HostEnv HostOp ())
-> Code -> ImpM GPUMem HostEnv HostOp ()
forall a b. (a -> b) -> a -> b
$ case Maybe (TExp Bool)
check of
    Maybe (TExp Bool)
Nothing -> Code
fcode
    Just TExp Bool
ok -> TExp Bool -> Code -> Code -> Code
forall a. TExp Bool -> Code a -> Code a -> Code a
Imp.If (TExp Bool
ok TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. SubExp -> TExp Bool
forall a. ToExp a => a -> TExp Bool
toBoolExp SubExp
cond) Code
tcode Code
fcode
expCompiler Pat GPUMem
dest Exp GPUMem
e =
  ExpCompiler GPUMem HostEnv HostOp
forall rep inner r op.
Mem rep inner =>
Pat rep -> Exp rep -> ImpM rep r op ()
defCompileExp Pat GPUMem
dest Exp GPUMem
e

callKernelCopy :: CopyCompiler GPUMem HostEnv Imp.HostOp
-- | Copy between memory blocks, picking the cheapest strategy that
-- applies: a map-transpose kernel, a flat memcpy-style 'Imp.Copy' when
-- both sides are linear, and otherwise the generic 'sCopy'.
callKernelCopy bt destloc@(MemLoc destmem _ destIxFun) srcloc@(MemLoc srcmem srcshape srcIxFun)
  -- The copy is really a (batched) transposition: call the builtin
  -- transpose function for this element type.
  | Just (destoffset, srcoffset, num_arrays, size_x, size_y) <-
      isMapTransposeCopy bt destloc srcloc = do
      fname <- mapTransposeForType bt
      emit . Imp.Call [] fname $
        [ Imp.MemArg destmem,
          Imp.ExpArg $ untyped destoffset,
          Imp.MemArg srcmem,
          Imp.ExpArg $ untyped srcoffset,
          Imp.ExpArg $ untyped num_arrays,
          Imp.ExpArg $ untyped size_x,
          Imp.ExpArg $ untyped size_y
        ]
  -- Both index functions are linear (possibly with an offset), so a
  -- single flat copy of the whole array suffices.
  | elem_size <- primByteSize bt,
    Just destoffset <- IxFun.linearWithOffset destIxFun elem_size,
    Just srcoffset <- IxFun.linearWithOffset srcIxFun elem_size = do
      let n_elems = Imp.elements $ product $ map toInt64Exp srcshape
      srcspace <- entryMemSpace <$> lookupMemory srcmem
      destspace <- entryMemSpace <$> lookupMemory destmem
      emit $
        Imp.Copy
          destmem
          (bytes $ sExt64 destoffset)
          destspace
          srcmem
          (bytes $ sExt64 srcoffset)
          srcspace
          (n_elems `Imp.withElemType` bt)
  -- General case: fall back to the slow element-by-element copy.
  | otherwise = sCopy bt destloc srcloc

-- | Obtain the name of the builtin map-transpose function for the
-- given element type, emitting its definition first if this is the
-- first time it is requested.
mapTransposeForType :: PrimType -> CallKernelGen Name
mapTransposeForType bt = do
  let fname = nameFromString $ "builtin#" <> mapTransposeName bt
  -- Only emit the function once per program.
  already_there <- hasFunction fname
  unless already_there . emitFunction fname $ mapTransposeFunction bt
  return fname

-- | The name of the map-transpose kernel for a given element type.
mapTransposeName :: PrimType -> String
mapTransposeName bt = "gpu_map_transpose_" ++ pretty bt

-- | Construct the host-level builtin function that performs a
-- (batched) map-transpose of arrays with the given element type.  At
-- run time it dispatches on the array dimensions: degenerate shapes
-- become a flat copy, and otherwise one of several specialised
-- transpose kernels is invoked.
mapTransposeFunction :: PrimType -> Imp.Function
mapTransposeFunction bt =
  Imp.Function Nothing [] params transpose_code [] []
  where
    params =
      [ memparam destmem,
        intparam destoffset,
        memparam srcmem,
        intparam srcoffset,
        intparam num_arrays,
        intparam x,
        intparam y
      ]

    space = Space "device"
    memparam v = Imp.MemParam v space
    intparam v = Imp.ScalarParam v $ IntType Int32

    -- Fixed names (with deterministic tags) for the parameters and the
    -- scratch variables used by the kernels.
    [ destmem,
      destoffset,
      srcmem,
      srcoffset,
      num_arrays,
      x,
      y,
      mulx,
      muly,
      block
      ] =
        zipWith
          (VName . nameFromString)
          [ "destmem",
            "destoffset",
            "srcmem",
            "srcoffset",
            "num_arrays",
            "x_elems",
            "y_elems",
            -- The following is only used for low width/height
            -- transpose kernels.
            "mulx",
            "muly",
            "block"
          ]
          [0 ..]

    block_dim_int = 16

    block_dim :: IntegralExp a => a
    block_dim = 16

    -- When an input array has either width==1 or height==1, performing
    -- a transpose is the same as performing a copy.
    can_use_copy =
      let onearr = Imp.vi32 num_arrays .==. 1
          height_is_one = Imp.vi32 y .==. 1
          width_is_one = Imp.vi32 x .==. 1
       in onearr .&&. (width_is_one .||. height_is_one)

    transpose_code =
      Imp.If input_is_empty mempty . mconcat $
        [ Imp.DeclareScalar muly Imp.Nonvolatile (IntType Int32),
          Imp.SetScalar muly . untyped $ block_dim `quot` Imp.vi32 x,
          Imp.DeclareScalar mulx Imp.Nonvolatile (IntType Int32),
          Imp.SetScalar mulx . untyped $ block_dim `quot` Imp.vi32 y,
          -- Dispatch from cheapest to most general strategy.
          Imp.If can_use_copy copy_code $
            Imp.If should_use_lowwidth (callTransposeKernel TransposeLowWidth) $
              Imp.If should_use_lowheight (callTransposeKernel TransposeLowHeight) $
                Imp.If should_use_small (callTransposeKernel TransposeSmall) $
                  callTransposeKernel TransposeNormal
        ]

    input_is_empty =
      Imp.vi32 num_arrays .==. 0
        .||. Imp.vi32 x .==. 0
        .||. Imp.vi32 y .==. 0

    should_use_small =
      Imp.vi32 x .<=. (block_dim `quot` 2)
        .&&. Imp.vi32 y .<=. (block_dim `quot` 2)

    should_use_lowwidth =
      Imp.vi32 x .<=. (block_dim `quot` 2)
        .&&. block_dim .<. Imp.vi32 y

    should_use_lowheight =
      Imp.vi32 y .<=. (block_dim `quot` 2)
        .&&. block_dim .<. Imp.vi32 x

    -- Degenerate transpose: one flat device-to-device copy.
    copy_code =
      let num_bytes = sExt64 $ Imp.vi32 x * Imp.vi32 y * primByteSize bt
       in Imp.Copy
            destmem
            (Imp.Count . sExt64 $ Imp.vi32 destoffset)
            space
            srcmem
            (Imp.Count . sExt64 $ Imp.vi32 srcoffset)
            space
            (Imp.Count num_bytes)

    callTransposeKernel =
      Imp.Op
        . Imp.CallKernel
        . mapTransposeKernel
          (mapTransposeName bt)
          block_dim_int
          ( destmem,
            Imp.vi32 destoffset,
            srcmem,
            Imp.vi32 srcoffset,
            Imp.vi32 x,
            Imp.vi32 y,
            Imp.vi32 mulx,
            Imp.vi32 muly,
            Imp.vi32 num_arrays,
            block
          )
          bt