{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}
-- | Compile a 'KernelsMem' program to imperative code with kernels.
-- This is mostly (but not entirely) the same process regardless of
-- whether we are targeting OpenCL or CUDA.  The important distinctions
-- (the host-level code) are introduced later.
module Futhark.CodeGen.ImpGen.Kernels
  ( compileProgOpenCL
  , compileProgCUDA
  , Warnings
  )
  where

import Control.Monad.Except
import Data.Bifunctor (second)
import qualified Data.Map as M
import Data.Maybe
import Data.List (foldl')

import Prelude hiding (quot)

import Futhark.Error
import Futhark.MonadFreshNames
import Futhark.IR.KernelsMem
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpCode.Kernels (bytes)
import Futhark.CodeGen.ImpGen hiding (compileProg)
import qualified Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.Base
import Futhark.CodeGen.ImpGen.Kernels.SegMap
import Futhark.CodeGen.ImpGen.Kernels.SegRed
import Futhark.CodeGen.ImpGen.Kernels.SegScan
import Futhark.CodeGen.ImpGen.Kernels.SegHist
import Futhark.CodeGen.ImpGen.Kernels.Transpose
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.CodeGen.SetDefaultSpace
import Futhark.Util.IntegralExp (quot, divUp, IntegralExp)

-- | The ImpGen operations used by this backend.  Expressions, copies
-- and 'Op's go through the kernel-aware compilers defined in this
-- module ('expCompiler', 'callKernelCopy', 'opCompiler').  Statement
-- sequences use the generic 'defCompileStms'.  No space-specific
-- allocation compilers are registered.
callKernelOperations :: Operations KernelsMem HostEnv Imp.HostOp
callKernelOperations =
  Operations
    { opsExpCompiler = expCompiler
    , opsCopyCompiler = callKernelCopy
    , opsOpCompiler = opCompiler
    , opsStmsCompiler = defCompileStms
    , opsAllocCompilers = mempty
    }

-- | Map a 'BinOp' to the atomic operation that implements it on the
-- respective target.  The result is 'Nothing' for operators that have
-- no entry in the target's table.
--
-- The OpenCL table covers 32-bit integer add, signed/unsigned min/max
-- and bitwise and/or/xor.  The CUDA table is the OpenCL table extended
-- with 32-bit floating-point addition.
openclAtomics, cudaAtomics :: AtomicBinOp
(openclAtomics, cudaAtomics) = (flip lookup opencl, flip lookup cuda)
  where
    opencl =
      [ (Add Int32 OverflowUndef, Imp.AtomicAdd Int32)
      , (SMax Int32, Imp.AtomicSMax Int32)
      , (SMin Int32, Imp.AtomicSMin Int32)
      , (UMax Int32, Imp.AtomicUMax Int32)
      , (UMin Int32, Imp.AtomicUMin Int32)
      , (And Int32, Imp.AtomicAnd Int32)
      , (Or Int32, Imp.AtomicOr Int32)
      , (Xor Int32, Imp.AtomicXor Int32)
      ]
    -- Everything OpenCL has, plus single-precision float addition.
    cuda = opencl ++ [(FAdd Float32, Imp.AtomicFAdd Float32)]

-- | Compile a program with the given host environment, using
-- 'callKernelOperations'.  The @"device"@ space is passed as the
-- 'Space' argument of the generic ImpGen compiler.  The same space is
-- then installed with 'setDefaultSpace' on the resulting program.
-- Warnings from the generic compiler are passed through untouched.
compileProg :: MonadFreshNames m => HostEnv -> Prog KernelsMem
            -> m (Warnings, Imp.Program)
compileProg env prog =
  second (setDefaultSpace device) <$>
    Futhark.CodeGen.ImpGen.compileProg env callKernelOperations device prog
  where
    device = Imp.Space "device"

-- | Compile a 'KernelsMem' program to low-level parallel code, with
-- either CUDA or OpenCL characteristics.  Here the two entry points
-- differ only in the table of atomic operations ('openclAtomics' or
-- 'cudaAtomics') stored in the 'HostEnv'.
compileProgOpenCL, compileProgCUDA
  :: MonadFreshNames m => Prog KernelsMem -> m (Warnings, Imp.Program)
compileProgOpenCL = compileProg $ HostEnv openclAtomics
compileProgCUDA = compileProg $ HostEnv cudaAtomics

-- | Compile an 'Op' of the 'KernelsMem' representation.
--
-- * 'Alloc' is delegated to 'compileAlloc'.
-- * Size operations ('GetSize', 'CmpSizeLe', 'GetSizeMax',
--   'CalcNumGroups') become the corresponding 'Imp.HostOp's.  Where a
--   key or size class is involved, it is qualified with the name of
--   the current function (from 'askFunction').
-- * 'SegOp's are handed to 'segOpCompiler'.
--
-- Size operations must bind exactly one pattern element.  Anything
-- that matches none of the equations is reported as a compiler bug.
opCompiler :: Pattern KernelsMem -> Op KernelsMem
           -> CallKernelGen ()

opCompiler dest (Alloc e space) =
  compileAlloc dest e space

opCompiler (Pattern _ [pe]) (Inner (SizeOp (GetSize key size_class))) = do
  fname <- askFunction
  sOp $ Imp.GetSize (patElemName pe) (keyWithEntryPoint fname key) $
    sizeClassWithEntryPoint fname size_class

opCompiler (Pattern _ [pe]) (Inner (SizeOp (CmpSizeLe key size_class x))) = do
  fname <- askFunction
  let size_class' = sizeClassWithEntryPoint fname size_class
  sOp . Imp.CmpSizeLe (patElemName pe) (keyWithEntryPoint fname key) size_class'
    =<< toExp x

-- The maximum of a size class is not qualified with the function name.
opCompiler (Pattern _ [pe]) (Inner (SizeOp (GetSizeMax size_class))) =
  sOp $ Imp.GetSizeMax (patElemName pe) size_class

-- Computes max(1, min(ceil(w / group_size), max_num_groups)), where
-- max_num_groups is read from the (tunable) size 'max_num_groups_key'.
opCompiler (Pattern _ [pe]) (Inner (SizeOp (CalcNumGroups w64 max_num_groups_key group_size))) = do
  fname <- askFunction
  max_num_groups <- dPrim "max_num_groups" int32
  sOp $ Imp.GetSize max_num_groups (keyWithEntryPoint fname max_num_groups_key) $
    sizeClassWithEntryPoint fname SizeNumGroups

  -- If 'w' is small, we launch fewer groups than we normally would.
  -- We don't want any idle groups.
  --
  -- The calculations are done with 64-bit integers to avoid overflow
  -- issues.
  let num_groups_maybe_zero =
        BinOpExp (SMin Int64)
          (toExp' int64 w64 `divUp` sExt Int64 (toExp' int32 group_size))
          (sExt Int64 (Imp.vi32 max_num_groups))
      -- We also don't want zero groups.
      num_groups = BinOpExp (SMax Int64) 1 num_groups_maybe_zero
  -- The result is stored back as a 32-bit value.
  patElemName pe <-- sExt Int32 num_groups

opCompiler dest (Inner (SegOp op)) =
  segOpCompiler dest op

opCompiler pat e =
  compilerBugS $ "opCompiler: Invalid pattern\n  " ++
  pretty pat ++ "\nfor expression\n  " ++ pretty e

-- | Qualify every name in the kernel path of an 'Imp.SizeThreshold'
-- with the given function name (via 'keyWithEntryPoint').  The boolean
-- attached to each path element and the threshold's default value are
-- kept.  All other size classes are returned unchanged.
sizeClassWithEntryPoint :: Maybe Name -> Imp.SizeClass -> Imp.SizeClass
sizeClassWithEntryPoint fname (Imp.SizeThreshold path def) =
  Imp.SizeThreshold (map qualify path) def
  where
    qualify (name, x) = (keyWithEntryPoint fname name, x)
sizeClassWithEntryPoint _ size_class = size_class

-- | Dispatch a 'SegOp' to the compiler for its kind.
--
-- * 'SegMap' is accepted at any level.
-- * 'SegRed' and 'SegScan' are accepted only at the 'SegThread' level.
-- * 'SegHist' is accepted only at 'SegThread'.  The number of groups
--   and the group size are taken from that level.
--
-- Any other combination is reported as a compiler bug.
segOpCompiler :: Pattern KernelsMem -> SegOp SegLevel KernelsMem
              -> CallKernelGen ()
segOpCompiler pat (SegMap lvl space _ kbody) =
  compileSegMap pat lvl space kbody
segOpCompiler pat (SegRed lvl@SegThread{} space reds _ kbody) =
  compileSegRed pat lvl space reds kbody
segOpCompiler pat (SegScan lvl@SegThread{} space scans _ kbody) =
  compileSegScan pat lvl space scans kbody
segOpCompiler pat (SegHist (SegThread num_groups group_size _) space ops _ kbody) =
  compileSegHist pat num_groups group_size space ops kbody
segOpCompiler pat segop =
  compilerBugS $ "segOpCompiler: unexpected " ++ pretty (segLevel segop) ++
  " for rhs of pattern " ++ pretty pat

-- | Create a boolean expression that checks whether all kernels in the
-- given code use no more local memory than we have available.  The
-- capacity is the value reported by 'Imp.GetSizeMax' for
-- 'SizeLocalMemory'.  For each kernel, the sum of the sizes of its
-- 'Imp.LocalAlloc's is compared against that capacity.
--
-- We look at *all* the kernels here, even those that might be
-- otherwise protected by their own multi-versioning branches deeper
-- down.  Currently the compiler will not generate multi-versioning
-- that makes this a problem, but it might in the future.
--
-- Returns 'Nothing', and emits no code, if any allocation size mentions
-- a variable that is not in scope at this point.
checkLocalMemoryReqs :: Imp.Code -> CallKernelGen (Maybe Imp.Exp)
checkLocalMemoryReqs code = do
  scope <- askScope
  let alloc_sizes = map (sum . localAllocSizes . Imp.kernelBody) $ getKernels code

  -- If any of the sizes involve a variable that is not known at this
  -- point, then we cannot check the requirements.
  if any (`M.notMember` scope) (namesToList $ freeIn alloc_sizes)
    then return Nothing
    else do
      local_memory_capacity <- dPrim "local_memory_capacity" int32
      sOp $ Imp.GetSizeMax local_memory_capacity SizeLocalMemory

      -- Sizes are compared as 64-bit values.
      let local_memory_capacity_64 =
            sExt Int64 $ Imp.vi32 local_memory_capacity
          fits size =
            unCount size .<=. local_memory_capacity_64
      return $ Just $ foldl' (.&&.) true (map fits alloc_sizes)

  where
    getKernels = foldMap getKernel
    getKernel (Imp.CallKernel k) = [k]
    getKernel _ = []

    localAllocSizes = foldMap localAllocSize
    localAllocSize (Imp.LocalAlloc _ size) = [size]
    localAllocSize _ = []

-- | Expression compiler for this backend.  It handles iota, replicate,
-- allocations in the @"local"@ space and 'IfEquiv' conditionals
-- specially.  Everything else goes to 'defCompileExp'.
expCompiler :: ExpCompiler KernelsMem HostEnv Imp.HostOp

-- We generate a simple kernel for iota and replicate.
expCompiler (Pattern _ [pe]) (BasicOp (Iota n x s et)) = do
  n' <- toExp n
  x' <- toExp x
  s' <- toExp s

  sIota (patElemName pe) n' x' s' et

expCompiler (Pattern _ [pe]) (BasicOp (Replicate _ se)) =
  sReplicate (patElemName pe) se

-- Allocation in the "local" space is just a placeholder.
expCompiler _ (Op (Alloc _ (Space "local"))) =
  return ()

-- This is a multi-versioning If created by incremental flattening.
-- We need to augment the conditional with a check that any local
-- memory requirements in tbranch are compatible with the hardware.
-- We do not check anything for fbranch, as we assume that it will
-- always be safe (and what would we do if none of the branches would
-- work?).  If the requirements cannot be checked at all, only fbranch
-- is emitted.
expCompiler dest (If cond tbranch fbranch (IfDec _ IfEquiv)) = do
  tcode <- collect $ compileBody dest tbranch
  fcode <- collect $ compileBody dest fbranch
  check <- checkLocalMemoryReqs tcode
  emit $ case check of
    Nothing -> fcode
    Just ok -> Imp.If (ok .&&. toExp' Bool cond) tcode fcode

expCompiler dest e =
  defCompileExp dest e

-- | Copy compiler for this backend.  Three strategies are tried in
-- order:
--
-- 1. If 'isMapTransposeKernel' recognises the copy as a (batched)
--    transposition, emit a call to the builtin transpose function for
--    the element type (see 'mapTransposeForType').
-- 2. If both the sliced destination and the sliced source index
--    functions are linear ('IxFun.linearWithOffset'), emit a single
--    contiguous 'Imp.Copy' between the two memory blocks.
-- 3. Otherwise fall back to the generic 'sCopy'.
callKernelCopy :: CopyCompiler KernelsMem HostEnv Imp.HostOp
callKernelCopy bt
  destloc@(MemLocation destmem _ destIxFun) destslice
  srcloc@(MemLocation srcmem _ srcIxFun) srcslice
  | Just (destoffset, srcoffset, num_arrays, size_x, size_y) <-
      isMapTransposeKernel bt destloc destslice srcloc srcslice = do
      fname <- mapTransposeForType bt
      emit $ Imp.Call [] fname
        [ Imp.MemArg destmem, Imp.ExpArg destoffset
        , Imp.MemArg srcmem, Imp.ExpArg srcoffset
        , Imp.ExpArg num_arrays, Imp.ExpArg size_x, Imp.ExpArg size_y
        ]

  | bt_size <- primByteSize bt,
    Just destoffset <-
      IxFun.linearWithOffset (IxFun.slice destIxFun destslice) bt_size,
    Just srcoffset <-
      IxFun.linearWithOffset (IxFun.slice srcIxFun srcslice) bt_size = do
      -- Both offsets describe the *sliced* arrays, so the element
      -- count must come from the slice as well.  The full shape of the
      -- source memory location would over-count, and the copy would
      -- run past the slice, whenever the slice fixes or shrinks a
      -- dimension.
      let num_elems = Imp.elements $ product $ sliceDims srcslice
      srcspace <- entryMemSpace <$> lookupMemory srcmem
      destspace <- entryMemSpace <$> lookupMemory destmem
      emit $ Imp.Copy
        destmem (bytes destoffset) destspace
        srcmem (bytes srcoffset) srcspace $
        num_elems `Imp.withElemType` bt

  | otherwise = sCopy bt destloc destslice srcloc srcslice

-- | Return the name of the builtin transpose function for the given
-- element type (@builtin#@ followed by 'mapTransposeName').  The
-- function definition is emitted only if no function of that name
-- exists yet.
mapTransposeForType :: PrimType -> CallKernelGen Name
mapTransposeForType bt = do
  let fname = nameFromString $ "builtin#" <> mapTransposeName bt

  exists <- hasFunction fname
  unless exists $ emitFunction fname $ mapTransposeFunction bt

  return fname

-- | Base name of the transpose function/kernel specialised to an element
-- type: the prefix @map_transpose_@ followed by the pretty-printed type.
mapTransposeName :: PrimType -> String
mapTransposeName bt = concat ["map_transpose_", pretty bt]

-- | Build the host-level function that transposes a batch of
-- @num_arrays@ matrices of @x_elems@ by @y_elems@ elements of type @bt@,
-- reading from @srcmem@ and writing to @destmem@ (both in the @device@
-- memory space, with 32-bit offsets and sizes).
--
-- The generated code behaves as follows:
--
-- * It does nothing if any of the three dimensions is zero.
-- * It performs a plain 'Imp.Copy' when there is a single array whose
--   width or height is one, because then the transpose moves no data.
-- * Otherwise it dispatches to one of the specialised transpose kernels,
--   based on how the dimensions compare with the tile size.
mapTransposeFunction :: PrimType -> Imp.Function
mapTransposeFunction bt =
  Imp.Function False [] params transpose_code [] []

  where params = [memparam destmem, intparam destoffset,
                  memparam srcmem, intparam srcoffset,
                  intparam num_arrays, intparam x, intparam y]

        space = Space "device"
        memparam v = Imp.MemParam v space
        intparam v = Imp.ScalarParam v $ IntType Int32

        [destmem, destoffset, srcmem, srcoffset,
         num_arrays, x, y,
         mulx, muly, block] =
           zipWith (VName . nameFromString)
           ["destmem",
            "destoffset",
            "srcmem",
            "srcoffset",
            "num_arrays",
            "x_elems",
            "y_elems",
            -- The following is only used for low width/height
            -- transpose kernels
            "mulx",
            "muly",
            "block"
           ]
           [0..]

        v32 v = Imp.var v int32

        -- Side length of the square tile used by the transpose kernels.
        -- 'block_dim' is derived from this single constant so that the
        -- host-side shape thresholds below and the block size handed to
        -- 'mapTransposeKernel' cannot drift apart.
        block_dim_int :: Integer
        block_dim_int = 16

        block_dim :: IntegralExp a => a
        block_dim = fromInteger block_dim_int

        -- When an input array has either width==1 or height==1, performing a
        -- transpose will be the same as performing a copy.
        can_use_copy =
          let onearr = CmpOpExp (CmpEq $ IntType Int32) (v32 num_arrays) 1
              height_is_one = CmpOpExp (CmpEq $ IntType Int32) (v32 y) 1
              width_is_one = CmpOpExp (CmpEq $ IntType Int32) (v32 x) 1
          in onearr .&&. (width_is_one .||. height_is_one)

        -- Everything below the emptiness check runs only when x, y and
        -- num_arrays are all non-zero, so the divisions computing
        -- muly/mulx cannot divide by zero.
        transpose_code =
          Imp.If input_is_empty mempty $ mconcat
          [ Imp.DeclareScalar muly Imp.Nonvolatile (IntType Int32)
          , Imp.SetScalar muly $ block_dim `quot` v32 x
          , Imp.DeclareScalar mulx Imp.Nonvolatile (IntType Int32)
          , Imp.SetScalar mulx $ block_dim `quot` v32 y
          , Imp.If can_use_copy copy_code $
            Imp.If should_use_lowwidth (callTransposeKernel TransposeLowWidth) $
            Imp.If should_use_lowheight (callTransposeKernel TransposeLowHeight) $
            Imp.If should_use_small (callTransposeKernel TransposeSmall) $
            callTransposeKernel TransposeNormal
          ]

        input_is_empty =
          v32 num_arrays .==. 0 .||. v32 x .==. 0 .||. v32 y .==. 0

        -- Both dimensions fit within half a tile.
        should_use_small = BinOpExp LogAnd
          (CmpOpExp (CmpSle Int32) (v32 x) (block_dim `quot` 2))
          (CmpOpExp (CmpSle Int32) (v32 y) (block_dim `quot` 2))

        -- Narrow in x (at most half a tile) but taller than a tile in y.
        should_use_lowwidth = BinOpExp LogAnd
          (CmpOpExp (CmpSle Int32) (v32 x) (block_dim `quot` 2))
          (CmpOpExp (CmpSlt Int32) block_dim (v32 y))

        -- Narrow in y (at most half a tile) but wider than a tile in x.
        should_use_lowheight = BinOpExp LogAnd
          (CmpOpExp (CmpSle Int32) (v32 y) (block_dim `quot` 2))
          (CmpOpExp (CmpSlt Int32) block_dim (v32 x))

        copy_code =
          let num_bytes =
                v32 x * v32 y * Imp.LeafExp (Imp.SizeOf bt) (IntType Int32)
          in Imp.Copy
               destmem (Imp.Count $ v32 destoffset) space
               srcmem (Imp.Count $ v32 srcoffset) space
               (Imp.Count num_bytes)

        callTransposeKernel =
          Imp.Op . Imp.CallKernel .
          mapTransposeKernel (mapTransposeName bt) block_dim_int
          (destmem, v32 destoffset, srcmem, v32 srcoffset,
           v32 x, v32 y,
           v32 mulx, v32 muly, v32 num_arrays,
           block) bt

-- | Decide whether copying between two (sliced) memory locations can be
-- implemented as a batched 2D transpose.
--
-- Both index functions are first restricted by their slices. The
-- recognised case is the following:
--
-- * One side is a rearrangement, as recognised by
--   'IxFun.rearrangeWithOffset'.
-- * The other side is linear, as recognised by 'IxFun.linearWithOffset'.
-- * The permutation is accepted by 'isMapTranspose'.
--
-- On success the result is
-- @(dest_offset, src_offset, num_arrays, size_x, size_y)@. Each size is
-- the product of a run of dimensions, split at the two positions returned
-- by 'isMapTranspose'. When it is the destination that is rearranged, the
-- two inner runs are swapped.
isMapTransposeKernel :: PrimType
                     -> MemLocation -> Slice Imp.Exp
                     -> MemLocation -> Slice Imp.Exp
                     -> Maybe (Imp.Exp, Imp.Exp,
                               Imp.Exp, Imp.Exp, Imp.Exp)
isMapTransposeKernel bt
  (MemLocation _ _ dest_ixfun) dest_slice
  (MemLocation _ _ src_ixfun) src_slice
  -- Destination is rearranged, source is linear.
  | Just (dest_off, dest_perm_shape) <-
      IxFun.rearrangeWithOffset dest_ixfun' elem_size
  , (perm, dest_dims) <- unzip dest_perm_shape
  , Just src_off <- IxFun.linearWithOffset src_ixfun' elem_size
  , Just (r1, r2, _) <- isMapTranspose perm
  = Just $ sizesFor dest_dims (\(a, b) -> (b, a)) r1 r2 dest_off src_off
  -- Destination is linear, source is rearranged.
  | Just dest_off <- IxFun.linearWithOffset dest_ixfun' elem_size
  , Just (src_off, src_perm_shape) <-
      IxFun.rearrangeWithOffset src_ixfun' elem_size
  , (perm, src_dims) <- unzip src_perm_shape
  , Just (r1, r2, _) <- isMapTranspose perm
  = Just $ sizesFor src_dims id r1 r2 dest_off src_off
  | otherwise = Nothing
  where
    elem_size = primByteSize bt

    dest_ixfun' = IxFun.slice dest_ixfun dest_slice
    src_ixfun' = IxFun.slice src_ixfun src_slice

    -- The first r1 dimensions are the mapped (batch) ones. The remaining
    -- dimensions are split at r2, and 'reorder' decides which half
    -- becomes size_x and which becomes size_y.
    sizesFor dims reorder r1 r2 dest_off src_off =
      let (mapped, rest) = splitAt r1 dims
          (before, after) = reorder $ splitAt r2 rest
      in (dest_off, src_off, product mapped, product before, product after)