{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.CodeGen.ImpGen.Kernels
  ( compileProgOpenCL
  , compileProgCUDA
  )
  where

import Control.Monad.Except
import Data.Maybe
import Data.List ()

import Prelude hiding (quot)

import Futhark.Error
import Futhark.MonadFreshNames
import Futhark.Representation.ExplicitMemory
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpCode.Kernels (bytes)
import Futhark.CodeGen.ImpGen hiding (compileProg)
import qualified Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.Base
import Futhark.CodeGen.ImpGen.Kernels.SegMap
import Futhark.CodeGen.ImpGen.Kernels.SegRed
import Futhark.CodeGen.ImpGen.Kernels.SegScan
import Futhark.CodeGen.ImpGen.Kernels.SegHist
import Futhark.CodeGen.ImpGen.Kernels.Transpose
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.CodeGen.SetDefaultSpace
import Futhark.Util.IntegralExp (quot, quotRoundingUp, IntegralExp)

-- | The set of compiler callbacks used when generating host-level
-- code.  Expressions, copies and operations are handled by the
-- compilers defined in this module ('expCompiler', 'callKernelCopy'
-- and 'opCompiler'); statements use the default 'defCompileStms', and
-- no space-specific allocation compilers are registered.
callKernelOperations :: Operations ExplicitMemory HostEnv Imp.HostOp
callKernelOperations =
  Operations { opsExpCompiler = expCompiler
             , opsCopyCompiler = callKernelCopy
             , opsOpCompiler = opCompiler
             , opsStmsCompiler = defCompileStms
             , opsAllocCompilers = mempty
             }

-- | Lookup functions mapping a binary operator to the atomic
-- operation constructor paired with it, if any.  Both are built by
-- 'lookup' over an association list; the CUDA table is the OpenCL
-- table with an additional entry for @FAdd Float32@.
openclAtomics, cudaAtomics :: AtomicBinOp
(openclAtomics, cudaAtomics) = (flip lookup opencl, flip lookup cuda)
  where opencl = [ (Add Int32, Imp.AtomicAdd Int32)
                 , (SMax Int32, Imp.AtomicSMax Int32)
                 , (SMin Int32, Imp.AtomicSMin Int32)
                 , (UMax Int32, Imp.AtomicUMax Int32)
                 , (UMin Int32, Imp.AtomicUMin Int32)
                 , (And Int32, Imp.AtomicAnd Int32)
                 , (Or Int32, Imp.AtomicOr Int32)
                 , (Xor Int32, Imp.AtomicXor Int32)
                 ]
        cuda = opencl ++ [(FAdd Float32, Imp.AtomicFAdd Float32)]

-- | Compile a program with the given host environment.  The generic
-- 'Futhark.CodeGen.ImpGen.compileProg' is run with
-- 'callKernelOperations' and the @"device"@ space, and
-- 'setDefaultSpace' is then applied to the result with that same
-- space.
compileProg :: MonadFreshNames m => HostEnv -> Prog ExplicitMemory -> m Imp.Program
compileProg env prog =
  setDefaultSpace (Imp.Space "device") <$>
  Futhark.CodeGen.ImpGen.compileProg env callKernelOperations (Imp.Space "device") prog

-- | Entry points of this module.  They differ only in the table of
-- atomic operations placed in the 'HostEnv': 'openclAtomics' for
-- OpenCL and 'cudaAtomics' for CUDA.
compileProgOpenCL, compileProgCUDA
  :: MonadFreshNames m => Prog ExplicitMemory -> m Imp.Program
compileProgOpenCL = compileProg $ HostEnv openclAtomics
compileProgCUDA = compileProg $ HostEnv cudaAtomics

-- | Compile a host-level operation: allocations, size queries,
-- group-count calculations and segmented operations.  Any other
-- pattern/operation combination is reported as a compiler bug.
opCompiler :: Pattern ExplicitMemory -> Op ExplicitMemory
           -> CallKernelGen ()

-- Allocations are delegated to 'compileAlloc'.
opCompiler dest (Alloc e space) =
  compileAlloc dest e space

-- Size queries.  Both the key and the size class are passed through
-- 'keyWithEntryPoint' / 'sizeClassWithEntryPoint' together with the
-- result of 'askFunction'.
opCompiler (Pattern _ [pe]) (Inner (SizeOp (GetSize key size_class))) = do
  fname <- askFunction
  sOp $ Imp.GetSize (patElemName pe) (keyWithEntryPoint fname key) $
    sizeClassWithEntryPoint fname size_class

opCompiler (Pattern _ [pe]) (Inner (SizeOp (CmpSizeLe key size_class x))) = do
  fname <- askFunction
  let size_class' = sizeClassWithEntryPoint fname size_class
  sOp . Imp.CmpSizeLe (patElemName pe) (keyWithEntryPoint fname key) size_class'
    =<< toExp x

opCompiler (Pattern _ [pe]) (Inner (SizeOp (GetSizeMax size_class))) =
  sOp $ Imp.GetSizeMax (patElemName pe) size_class

opCompiler (Pattern _ [pe]) (Inner (SizeOp (CalcNumGroups w64 max_num_groups_key group_size))) = do
  fname <- askFunction
  max_num_groups <- dPrim "max_num_groups" int32
  sOp $ Imp.GetSize max_num_groups (keyWithEntryPoint fname max_num_groups_key) $
    sizeClassWithEntryPoint fname SizeNumGroups

  -- If 'w' is small, we launch fewer groups than we normally would.
  -- We don't want any idle groups.
  --
  -- The calculations are done with 64-bit integers to avoid overflow
  -- issues.
  let num_groups_maybe_zero = BinOpExp (SMin Int64)
                              (toExp' int64 w64 `quotRoundingUp`
                               i64 (toExp' int32 group_size)) $
                              i64 (Imp.vi32 max_num_groups)
  -- We also don't want zero groups.
  let num_groups = BinOpExp (SMax Int64) 1 num_groups_maybe_zero
  patElemName pe <-- i32 num_groups

  where -- Conversions between the 32-bit operands/result and the
        -- 64-bit intermediate computation.
        i64 = ConvOpExp (SExt Int32 Int64)
        i32 = ConvOpExp (SExt Int64 Int32)

-- Segmented operations are handled by 'segOpCompiler'.
opCompiler dest (Inner (SegOp op)) =
  segOpCompiler dest op

-- Anything else is not expected to reach code generation.
opCompiler pat e =
  compilerBugS $ "opCompiler: Invalid pattern\n  " ++
  pretty pat ++ "\nfor expression\n  " ++ pretty e

-- | Rewrite the names in a 'Imp.SizeThreshold' path with
-- 'keyWithEntryPoint', using the given (optional) function name.  The
-- boolean paired with each name is kept unchanged, and every other
-- size class is returned as is.
sizeClassWithEntryPoint :: Maybe Name -> Imp.SizeClass -> Imp.SizeClass
sizeClassWithEntryPoint fname (Imp.SizeThreshold path) =
  Imp.SizeThreshold $ map f path
  where f (name, x) = (keyWithEntryPoint fname name, x)
sizeClassWithEntryPoint _ size_class = size_class

-- | Dispatch a segmented operation to the matching compiler.
-- 'SegMap' is accepted at any level; 'SegRed', 'SegScan' and
-- 'SegHist' are only accepted at the 'SegThread' level.  Every other
-- combination is reported as a compiler bug, naming the level and the
-- pattern.
segOpCompiler :: Pattern ExplicitMemory -> SegOp ExplicitMemory -> CallKernelGen ()
segOpCompiler pat (SegMap lvl space _ kbody) =
  compileSegMap pat lvl space kbody
segOpCompiler pat (SegRed lvl@SegThread{} space reds _ kbody) =
  compileSegRed pat lvl space reds kbody
segOpCompiler pat (SegScan lvl@SegThread{} space scan_op nes _ kbody) =
  compileSegScan pat lvl space scan_op nes kbody
segOpCompiler pat (SegHist (SegThread num_groups group_size _) space ops _ kbody) =
  compileSegHist pat num_groups group_size space ops kbody
segOpCompiler pat segop =
  compilerBugS $ "segOpCompiler: unexpected " ++ pretty (segLevel segop) ++ " for rhs of pattern " ++ pretty pat

-- | Expression compiler for host-level code.  Iota and replicate
-- with a single-element pattern are handled by 'sIota' and
-- 'sReplicate', allocations in the @"local"@ space produce no code,
-- and everything else falls through to 'defCompileExp'.
expCompiler :: ExpCompiler ExplicitMemory HostEnv Imp.HostOp

-- We generate a simple kernel for iota and replicate.
expCompiler (Pattern _ [pe]) (BasicOp (Iota n x s et)) = do
  n' <- toExp n
  x' <- toExp x
  s' <- toExp s

  sIota (patElemName pe) n' x' s' et

expCompiler (Pattern _ [pe]) (BasicOp (Replicate _ se)) =
  sReplicate (patElemName pe) se

-- Allocation in the "local" space is just a placeholder.
expCompiler _ (Op (Alloc _ (Space "local"))) =
  return ()

expCompiler dest e =
  defCompileExp dest e

-- | Copy compiler for host-level code.  Three strategies are tried in
-- order:
--
-- 1. If 'isMapTransposeKernel' recognises the copy, call the function
--    named by 'mapTransposeForType' with the memory blocks, offsets
--    and sizes it returned.
--
-- 2. If 'IxFun.linearWithOffset' yields an offset for both the
--    destination and the source index function, emit a single
--    'Imp.Copy' between the two memory blocks.
--
-- 3. Otherwise fall back to 'sCopy'.
callKernelCopy :: CopyCompiler ExplicitMemory HostEnv Imp.HostOp
callKernelCopy bt
  destloc@(MemLocation destmem _ destIxFun)
  srcloc@(MemLocation srcmem srcshape srcIxFun)
  | Just (destoffset, srcoffset,
          num_arrays, size_x, size_y,
          src_elems, dest_elems) <- isMapTransposeKernel bt destloc srcloc = do

      fname <- mapTransposeForType bt
      emit $ Imp.Call [] fname
        [Imp.MemArg destmem, Imp.ExpArg destoffset,
         Imp.MemArg srcmem, Imp.ExpArg srcoffset,
         Imp.ExpArg num_arrays, Imp.ExpArg size_x, Imp.ExpArg size_y,
         Imp.ExpArg src_elems, Imp.ExpArg dest_elems]

  | bt_size <- primByteSize bt,
    Just destoffset <-
      IxFun.linearWithOffset destIxFun bt_size,
    Just srcoffset  <-
      IxFun.linearWithOffset srcIxFun bt_size = do
        -- The element count is the product of the source shape; it is
        -- converted to a byte count with the element type below.
        let num_elems = Imp.elements $ product $ map (toExp' int32) srcshape
        srcspace <- entryMemSpace <$> lookupMemory srcmem
        destspace <- entryMemSpace <$> lookupMemory destmem
        emit $ Imp.Copy
          destmem (bytes destoffset) destspace
          srcmem (bytes srcoffset) srcspace $
          num_elems `Imp.withElemType` bt

  | otherwise = sCopy bt destloc srcloc

-- | Return the name of the map-transpose function for the given
-- element type.  The name is @"builtin#"@ followed by
-- 'mapTransposeName'; unless 'hasFunction' reports that a function of
-- that name already exists, 'mapTransposeFunction' is emitted under
-- it first.
mapTransposeForType :: PrimType -> CallKernelGen Name
mapTransposeForType bt = do
  let fname = nameFromString $ "builtin#" <> mapTransposeName bt

  exists <- hasFunction fname
  unless exists $ emitFunction fname $ mapTransposeFunction bt

  return fname

-- | The base name of the map-transpose function for an element type:
-- @"map_transpose_"@ followed by the pretty-printed type.
mapTransposeName :: PrimType -> String
mapTransposeName bt = "map_transpose_" ++ pretty bt

mapTransposeFunction :: PrimType -> Imp.Function
mapTransposeFunction :: PrimType -> Function HostOp
mapTransposeFunction PrimType
bt =
  Bool
-> [Param]
-> [Param]
-> Code HostOp
-> [ExternalValue]
-> [ExternalValue]
-> Function HostOp
forall a.
Bool
-> [Param]
-> [Param]
-> Code a
-> [ExternalValue]
-> [ExternalValue]
-> FunctionT a
Imp.Function Bool
False [] [Param]
params Code HostOp
transpose_code [] []

  where params :: [Param]
params = [VName -> Param
memparam VName
destmem, VName -> Param
intparam VName
destoffset,
                  VName -> Param
memparam VName
srcmem, VName -> Param
intparam VName
srcoffset,
                  VName -> Param
intparam VName
num_arrays, VName -> Param
intparam VName
x, VName -> Param
intparam VName
y,
                  VName -> Param
intparam VName
in_elems, VName -> Param
intparam VName
out_elems]

        space :: Space
space = SpaceId -> Space
Space SpaceId
"device"
        memparam :: VName -> Param
memparam VName
v = VName -> Space -> Param
Imp.MemParam VName
v Space
space
        intparam :: VName -> Param
intparam VName
v = VName -> PrimType -> Param
Imp.ScalarParam VName
v (PrimType -> Param) -> PrimType -> Param
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType IntType
Int32

        [VName
destmem, VName
destoffset, VName
srcmem, VName
srcoffset,
         VName
num_arrays, VName
x, VName
y, VName
in_elems, VName
out_elems,
         VName
mulx, VName
muly, VName
block] =
           (SpaceId -> Int -> VName) -> [SpaceId] -> [Int] -> [VName]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (Name -> Int -> VName
VName (Name -> Int -> VName)
-> (SpaceId -> Name) -> SpaceId -> Int -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SpaceId -> Name
nameFromString)
           [SpaceId
"destmem",
             SpaceId
"destoffset",
             SpaceId
"srcmem",
             SpaceId
"srcoffset",
             SpaceId
"num_arrays",
             SpaceId
"x_elems",
             SpaceId
"y_elems",
             SpaceId
"in_elems",
             SpaceId
"out_elems",
             -- The following is only used for low width/height
             -- transpose kernels
             SpaceId
"mulx",
             SpaceId
"muly",
             SpaceId
"block"
            ]
           [Int
0..]

        v32 :: VName -> Exp
v32 VName
v = VName -> PrimType -> Exp
Imp.var VName
v PrimType
int32

        block_dim_int :: Integer
block_dim_int = Integer
16

        block_dim :: IntegralExp a => a
        block_dim :: a
block_dim = a
16

        -- When an input array has either width==1 or height==1, performing a
        -- transpose will be the same as performing a copy.  If 'input_size' or
        -- 'output_size' is not equal to width*height, then this trick will not
        -- work when there are more than one array to process, as it is a per
        -- array limit. We could copy each array individually, but currently we
        -- do not.
        can_use_copy :: Exp
can_use_copy =
          let in_out_eq :: Exp
in_out_eq = CmpOp -> Exp -> Exp -> Exp
forall v. CmpOp -> PrimExp v -> PrimExp v -> PrimExp v
CmpOpExp (PrimType -> CmpOp
CmpEq (PrimType -> CmpOp) -> PrimType -> CmpOp
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType IntType
Int32) (VName -> Exp
v32 VName
in_elems) (VName -> Exp
v32 VName
out_elems)
              onearr :: Exp
onearr = CmpOp -> Exp -> Exp -> Exp
forall v. CmpOp -> PrimExp v -> PrimExp v -> PrimExp v
CmpOpExp (PrimType -> CmpOp
CmpEq (PrimType -> CmpOp) -> PrimType -> CmpOp
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType IntType
Int32) (VName -> Exp
v32 VName
num_arrays) Exp
1
              noprob_widthheight :: Exp
noprob_widthheight = CmpOp -> Exp -> Exp -> Exp
forall v. CmpOp -> PrimExp v -> PrimExp v -> PrimExp v
CmpOpExp (PrimType -> CmpOp
CmpEq (PrimType -> CmpOp) -> PrimType -> CmpOp
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType IntType
Int32)
                                     (VName -> Exp
v32 VName
x Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* VName -> Exp
v32 VName
y)
                                     (VName -> Exp
v32 VName
in_elems)
              height_is_one :: Exp
height_is_one = CmpOp -> Exp -> Exp -> Exp
forall v. CmpOp -> PrimExp v -> PrimExp v -> PrimExp v
CmpOpExp (PrimType -> CmpOp
CmpEq (PrimType -> CmpOp) -> PrimType -> CmpOp
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType IntType
Int32) (VName -> Exp
v32 VName
y) Exp
1
              width_is_one :: Exp
width_is_one = CmpOp -> Exp -> Exp -> Exp
forall v. CmpOp -> PrimExp v -> PrimExp v -> PrimExp v
CmpOpExp (PrimType -> CmpOp
CmpEq (PrimType -> CmpOp) -> PrimType -> CmpOp
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType IntType
Int32) (VName -> Exp
v32 VName
x) Exp
1
          in BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp BinOp
LogAnd
               Exp
in_out_eq
               (BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp BinOp
LogAnd
                 (BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp BinOp
LogOr Exp
onearr Exp
noprob_widthheight)
                 (BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp BinOp
LogOr Exp
width_is_one Exp
height_is_one))

        -- The code emitted for the transposition.  Nothing is emitted
        -- when 'input_is_empty' holds.  Otherwise the Int32 scalars
        -- 'muly' (block_dim `quot` x) and 'mulx' (block_dim `quot` y) are
        -- declared and set, followed by a chain of conditionals that
        -- selects between a plain copy and the transpose kernel variants.
        transpose_code =
          let declare_int32 name =
                Imp.DeclareScalar name Imp.Nonvolatile (IntType Int32)

              -- (condition, code) pairs, nested as if/else-if in list
              -- order; the normal kernel is the final else branch.
              strategies =
                [ (can_use_copy, copy_code)
                , (should_use_lowwidth, callTransposeKernel TransposeLowWidth)
                , (should_use_lowheight, callTransposeKernel TransposeLowHeight)
                , (should_use_small, callTransposeKernel TransposeSmall)
                ]

              dispatch =
                foldr (uncurry Imp.If)
                      (callTransposeKernel TransposeNormal)
                      strategies
          in Imp.If input_is_empty mempty $ mconcat
             [ declare_int32 muly
             , Imp.SetScalar muly $ block_dim `quot` v32 x
             , declare_int32 mulx
             , Imp.SetScalar mulx $ block_dim `quot` v32 y
             , dispatch
             ]

        -- True when any of num_arrays, x or y is zero.
        input_is_empty =
          let is_zero dim = v32 dim .==. 0
          in is_zero num_arrays .||. is_zero x .||. is_zero y

        -- Holds when both x and y are at most half of block_dim
        -- (CmpSle on Int32).
        should_use_small =
          let half_block = block_dim `quot` 2
              at_most_half dim = CmpOpExp (CmpSle Int32) (v32 dim) half_block
          in BinOpExp LogAnd (at_most_half x) (at_most_half y)

        -- Holds when x is at most half of block_dim (CmpSle on Int32)
        -- while block_dim is strictly less than y (CmpSlt on Int32).
        should_use_lowwidth =
          let x_is_low = CmpOpExp (CmpSle Int32) (v32 x) (block_dim `quot` 2)
              y_is_big = CmpOpExp (CmpSlt Int32) block_dim (v32 y)
          in BinOpExp LogAnd x_is_low y_is_big

        -- Holds when y is at most half of block_dim (CmpSle on Int32)
        -- while block_dim is strictly less than x (CmpSlt on Int32).
        should_use_lowheight =
          let y_is_low = CmpOpExp (CmpSle Int32) (v32 y) (block_dim `quot` 2)
              x_is_big = CmpOpExp (CmpSlt Int32) block_dim (v32 x)
          in BinOpExp LogAnd y_is_low x_is_big

        -- A plain memory copy from srcmem (at srcoffset) to destmem (at
        -- destoffset), both in 'space'.  The amount copied is in_elems
        -- multiplied by the size of the element type 'bt'.
        copy_code =
          Imp.Copy
            destmem (Imp.Count (v32 destoffset)) space
            srcmem (Imp.Count (v32 srcoffset)) space
            (Imp.Count (v32 in_elems * elem_bytes))
          where elem_bytes = Imp.LeafExp (Imp.SizeOf bt) (IntType Int32)

        -- Build the host operation that calls the map-transpose kernel of
        -- the requested variant for element type 'bt'.  Every variant is
        -- given the same argument tuple.
        callTransposeKernel variant =
          Imp.Op (Imp.CallKernel kernel)
          where kernel =
                  mapTransposeKernel (mapTransposeName bt) block_dim_int
                  kernel_args bt variant
                kernel_args =
                  ( destmem, v32 destoffset
                  , srcmem, v32 srcoffset
                  , v32 x, v32 y
                  , v32 in_elems, v32 out_elems
                  , v32 mulx, v32 muly
                  , v32 num_arrays
                  , block
                  )

-- | Decide whether a copy between two memory locations can be
-- expressed as a map-transpose.
--
-- Two layouts are accepted, tried in this order:
--
--   1. 'IxFun.rearrangeWithOffset' succeeds on the destination index
--      function and 'IxFun.linearWithOffset' succeeds on the source;
--
--   2. 'IxFun.linearWithOffset' succeeds on the destination and
--      'IxFun.rearrangeWithOffset' succeeds on the source.
--
-- In both cases the permutation obtained from the rearrangement must
-- also be accepted by 'isMapTranspose'.  On success the result is
--
-- > (dest_offset, src_offset, num_arrays, size_x, size_y, elems, elems)
--
-- where @elems@ is the product of all dimensions of the rearranged
-- shape (the same value fills both of the last two slots), and the
-- three sizes are the dimension-group products computed by @describe@.
-- 'Nothing' is returned when neither layout matches.
isMapTransposeKernel :: PrimType -> MemLocation -> MemLocation
                     -> Maybe (Imp.Exp, Imp.Exp,
                               Imp.Exp, Imp.Exp, Imp.Exp,
                               Imp.Exp, Imp.Exp)
isMapTransposeKernel bt
  (MemLocation _ _ destIxFun)
  (MemLocation _ _ srcIxFun)
  | Just (dest_offset, dest_layout) <- IxFun.rearrangeWithOffset destIxFun elem_size,
    (perm, dest_dims) <- unzip dest_layout,
    Just src_offset <- IxFun.linearWithOffset srcIxFun elem_size,
    Just (r1, r2, _) <- isMapTranspose perm =
      Just $ describe dest_dims flipHalves r1 r2 dest_offset src_offset
  | Just dest_offset <- IxFun.linearWithOffset destIxFun elem_size,
    Just (src_offset, src_layout) <- IxFun.rearrangeWithOffset srcIxFun elem_size,
    (perm, src_dims) <- unzip src_layout,
    Just (r1, r2, _) <- isMapTranspose perm =
      Just $ describe src_dims id r1 r2 dest_offset src_offset
  | otherwise =
      Nothing
  where -- Size in bytes of one element of the primitive type.
        elem_size = primByteSize bt

        -- Exchange the two halves of a pair; used when it is the
        -- destination, not the source, that is rearranged.
        flipHalves (before, after) = (after, before)

        -- Split the shape into the first r1 dimensions and the rest,
        -- split the rest again at r2, pass that pair through
        -- 'arrange', and take the product of each group.
        describe dims arrange r1 r2 dest_offset src_offset =
          let (mapped, transposed) = splitAt r1 dims
              (pretrans, posttrans) = arrange (splitAt r2 transposed)
              elems = product dims
          in ( dest_offset, src_offset
             , product mapped, product pretrans, product posttrans
             , elems, elems
             )