{-# LANGUAGE QuasiQuotes #-}
module Futhark.CodeGen.ImpGen.GPU.ToOpenCL
( kernelsToOpenCL,
kernelsToCUDA,
)
where
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericC.Fun qualified as GC
import Futhark.CodeGen.Backends.GenericC.Pretty
import Futhark.CodeGen.Backends.SimpleRep
import Futhark.CodeGen.ImpCode.GPU hiding (Program)
import Futhark.CodeGen.ImpCode.GPU qualified as ImpGPU
import Futhark.CodeGen.ImpCode.OpenCL hiding (Program)
import Futhark.CodeGen.ImpCode.OpenCL qualified as ImpOpenCL
import Futhark.CodeGen.RTS.C (atomicsH, halfH)
import Futhark.Error (compilerLimitationS)
import Futhark.MonadFreshNames
import Futhark.Util (zEncodeString)
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C
import NeatInterpolation (untrimming)
-- | Translate an ImpGPU program, generating kernel and device-function
-- source for the CUDA target.
kernelsToCUDA :: ImpGPU.Program -> ImpOpenCL.Program
kernelsToCUDA prog = translateGPU TargetCUDA prog
-- | Translate an ImpGPU program, generating kernel and device-function
-- source for the OpenCL target.
kernelsToOpenCL :: ImpGPU.Program -> ImpOpenCL.Program
kernelsToOpenCL prog = translateGPU TargetOpenCL prog
-- | Translate every host operation in the constants and functions of an
-- ImpGPU program with 'onHostOp'.  Kernel and device-function source is
-- generated along the way.
--
-- The resulting program carries:
--
-- * the concatenated kernel sources;
-- * a prelude made of the target-specific header, the device-function
--   declarations and the device-function bodies;
-- * the safety level of each kernel;
-- * the primitive types used;
-- * the tunable sizes (cleaned by 'cleanSizes');
-- * the failure messages collected during translation.
translateGPU ::
  KernelTarget ->
  ImpGPU.Program ->
  ImpOpenCL.Program
translateGPU target prog =
  ImpOpenCL.Program
    opencl_code
    opencl_prelude
    (M.map fst kernels)
    (S.toList used_types)
    (cleanSizes sizes)
    failures
    prog'
  where
    ImpGPU.Definitions types (ImpGPU.Constants ps consts) (ImpGPU.Functions funs) = prog

    -- Rewrite host operations in constants and in every function body.
    translate = do
      consts' <- traverse (onHostOp target) consts
      funs' <- forM funs $ \(fname, fun) ->
        (fname,) <$> traverse (onHostOp target) fun
      pure $
        ImpOpenCL.Definitions
          types
          (ImpOpenCL.Constants ps consts')
          (ImpOpenCL.Functions funs')

    (prog', ToOpenCL kernels device_funs used_types sizes failures) =
      runState (runReaderT translate (envFromProg prog)) initialOpenCL

    (device_prototypes, device_defs) = unzip $ M.elems device_funs

    opencl_code = T.unlines [src | (_, src) <- M.elems kernels]

    opencl_prelude =
      T.unlines
        [ genPrelude target used_types,
          definitionsText device_prototypes,
          T.unlines device_defs
        ]

    -- The CUDA prelude does not depend on which primitive types are used.
    genPrelude TargetOpenCL = genOpenClPrelude
    genPrelude TargetCUDA = const genCUDAPrelude
-- | Prune the kernel path of every 'SizeThreshold'.  Entries whose name is
-- not itself a key of the given map are dropped; all other size classes
-- are returned unchanged.
cleanSizes :: M.Map Name SizeClass -> M.Map Name SizeClass
cleanSizes m = M.map clean m
  where
    -- Look names up in the map itself (O(log n)) rather than scanning
    -- the list of its keys (O(n)) for every path entry.
    known :: Name -> Bool
    known = (`M.member` m)

    clean :: SizeClass -> SizeClass
    clean (SizeThreshold path def) =
      SizeThreshold (filter (known . fst) path) def
    clean s = s
-- | Map an OpenCL address-space (or access) name to the corresponding C
-- type-qualifier list.  Any other name is a programming error and
-- triggers 'error'.
pointerQuals :: (Monad m) => String -> m [C.TypeQual]
pointerQuals s = case s of
  "global" -> pure [C.ctyquals|__global|]
  "local" -> pure [C.ctyquals|__local|]
  "private" -> pure [C.ctyquals|__private|]
  "constant" -> pure [C.ctyquals|__constant|]
  "write_only" -> pure [C.ctyquals|__write_only|]
  "read_only" -> pure [C.ctyquals|__read_only|]
  "kernel" -> pure [C.ctyquals|__kernel|]
  _ -> error $ "'" ++ s ++ "' is not an OpenCL kernel address space."
-- | A block of local memory used by a kernel: the name of the memory
-- block together with its size in bytes.
type LocalMemoryUse = (VName, Count Bytes Exp)
-- | User state threaded through the C code generator while compiling one
-- kernel or device function.
data KernelState = KernelState
  { -- | Local-memory blocks used by the code being compiled.
    kernelLocalMemory :: [LocalMemoryUse],
    -- | Failure messages known so far, seeded by 'newKernelState'.
    kernelFailures :: [FailureMsg],
    -- | Counter used to build the error label name (see 'errorLabel').
    kernelNextSync :: Int,
    -- | Not read in this part of the file; presumably records whether a
    -- failure check is still pending -- TODO confirm.
    kernelSyncPending :: Bool,
    -- | Whether the compiled code contains barriers; this selects the
    -- failure-checking prologue of a kernel.
    kernelHasBarriers :: Bool
  }
-- | A fresh kernel state that starts from the given failure messages, with
-- no local memory, sync counter 0 and both flags cleared.
newKernelState :: [FailureMsg] -> KernelState
newKernelState failures =
  KernelState
    { kernelLocalMemory = mempty,
      kernelFailures = failures,
      kernelNextSync = 0,
      kernelSyncPending = False,
      kernelHasBarriers = False
    }
-- | The C label name @error_N@, where @N@ is the state's 'kernelNextSync'.
errorLabel :: KernelState -> String
errorLabel kstate = "error_" ++ show (kernelNextSync kstate)
-- | State accumulated while translating a whole program.
data ToOpenCL = ToOpenCL
  { -- | Generated kernels: safety level and source text, by kernel name.
    clGPU :: M.Map KernelName (KernelSafety, T.Text),
    -- | Generated device functions: a C definition (placed in the prelude
    -- ahead of the function bodies) and the function's source text.
    clDevFuns :: M.Map Name (C.Definition, T.Text),
    -- | Primitive types used by the generated device code.
    clUsedTypes :: S.Set PrimType,
    -- | Tunable sizes referenced by host operations.
    clSizes :: M.Map Name SizeClass,
    -- | Failure messages collected so far.
    clFailures :: [FailureMsg]
  }
-- | The empty translation state: no kernels, device functions, types,
-- sizes or failures.
initialOpenCL :: ToOpenCL
initialOpenCL =
  ToOpenCL
    { clGPU = M.empty,
      clDevFuns = M.empty,
      clUsedTypes = S.empty,
      clSizes = M.empty,
      clFailures = []
    }
-- | Read-only environment available during translation.
data Env = Env
  { -- | Every function of the program being translated.
    envFuns :: ImpGPU.Functions ImpGPU.HostOp,
    -- | Names of the functions that may fail (see 'funsMayFail').
    envFunsMayFail :: S.Set Name
  }
-- | Whether a piece of code may fail.  This holds if it contains an
-- 'Assert', or an 'Op' for which the given predicate holds.  The search
-- descends through sequencing, loops, branches and comments; every other
-- statement is considered unable to fail.
codeMayFail :: (a -> Bool) -> ImpGPU.Code a -> Bool
codeMayFail opMayFail = go
  where
    go (Assert {}) = True
    go (Op op) = opMayFail op
    go (first_code :>>: rest) = go first_code || go rest
    go (For _ _ loop_body) = go loop_body
    go (While _ loop_body) = go loop_body
    go (If _ then_branch else_branch) = go then_branch || go else_branch
    go (Comment _ commented) = go commented
    go _ = False
-- | A kernel launch may fail exactly when its body may fail; other host
-- operations never fail.
hostOpMayFail :: ImpGPU.HostOp -> Bool
hostOpMayFail op = case op of
  CallKernel k -> codeMayFail kernelOpMayFail (kernelBody k)
  _ -> False
-- | No kernel operation is itself treated as a failure point.
kernelOpMayFail :: ImpGPU.KernelOp -> Bool
kernelOpMayFail _ = False
-- | The names of the functions that may fail.  A function may fail when
-- its own body may fail (per 'codeMayFail' with 'hostOpMayFail'), or when
-- any of the functions the call graph @cg@ lists for it does.
funsMayFail :: M.Map Name (S.Set Name) -> ImpGPU.Functions ImpGPU.HostOp -> S.Set Name
funsMayFail cg (Functions funs) =
  S.fromList $ map fst $ filter mayFail funs
  where
    -- Functions whose own body may fail.  Kept as a set so that each
    -- membership test is O(log n) instead of a linear list scan.
    base_mayfail :: S.Set Name
    base_mayfail =
      S.fromList . map fst $
        filter (codeMayFail hostOpMayFail . ImpGPU.functionBody . snd) funs

    mayFail :: (Name, b) -> Bool
    mayFail (fname, _) =
      any (`S.member` base_mayfail) $
        fname : S.toList (M.findWithDefault mempty fname cg)
-- | Build the translation environment: all functions of the program, plus
-- those that may fail according to the program's call graph.
envFromProg :: ImpGPU.Program -> Env
envFromProg prog =
  let all_funs = defFuns prog
      call_graph = ImpGPU.callGraph calledInHostOp all_funs
   in Env
        { envFuns = all_funs,
          envFunsMayFail = funsMayFail call_graph all_funs
        }
-- | Find a function of the program by name.
lookupFunction :: Name -> Env -> Maybe (ImpGPU.Function HostOp)
lookupFunction fname env = lookup fname (unFunctions (envFuns env))
-- | Whether the named function is among those that may fail.
functionMayFail :: Name -> Env -> Bool
functionMayFail fname env = fname `S.member` envFunsMayFail env
-- | The translation monad: a read-only 'Env' over mutable 'ToOpenCL' state.
type OnKernelM = ReaderT Env (State ToOpenCL)
-- | Record a tunable size.  An existing entry with the same key is replaced.
addSize :: Name -> SizeClass -> OnKernelM ()
addSize key sclass = modify insertSize
  where
    insertSize st = st {clSizes = M.insert key sclass (clSizes st)}
-- | Translate one ImpGPU host operation into its ImpOpenCL counterpart.
-- Kernel launches generate kernel code; size queries register the size
-- they mention.
onHostOp :: KernelTarget -> HostOp -> OnKernelM OpenCL
onHostOp target op = case op of
  CallKernel k ->
    onKernel target k
  ImpGPU.GetSize v key size_class ->
    ImpOpenCL.GetSize v key <$ addSize key size_class
  ImpGPU.CmpSizeLe v key size_class x ->
    ImpOpenCL.CmpSizeLe v key x <$ addSize key size_class
  ImpGPU.GetSizeMax v size_class ->
    pure $ ImpOpenCL.GetSizeMax v size_class
-- | Run a C code-generator action for device code.  The action uses the
-- in-kernel operations built from the environment, mode and code, a blank
-- name source, and a fresh 'KernelState' seeded with the given failures.
genGPUCode ::
  Env ->
  OpsMode ->
  KernelCode ->
  [FailureMsg] ->
  GC.CompilerM KernelOp KernelState a ->
  (a, GC.CompilerState KernelState)
genGPUCode env mode body failures =
  GC.runCompilerM ops blankNameSource initial_state
  where
    ops = inKernelOperations env mode body
    initial_state = newKernelState failures
-- | Generate C code for a device function.  The generated code is recorded
-- in 'clDevFuns', together with the primitive types used by its body.
-- Any functions it calls are then generated as well.
--
-- A function that may fail ('functionMayFail') is compiled with
-- 'GC.compileFun' and two extra parameters, @global_failure@ and
-- @global_failure_args@.  Any other function is compiled with
-- 'GC.compileVoidFun'.  A function that takes a memory parameter is
-- rejected with a compiler-limitation error.
generateDeviceFun :: Name -> ImpGPU.Function ImpGPU.KernelOp -> OnKernelM ()
generateDeviceFun fname device_func = do
  when (any memParam $ functionInput device_func) bad

  env <- ask
  failures <- gets clFailures

  -- Both variants share the same code-generation pipeline; only the
  -- compile action differs, so select that and run the pipeline once.
  let compile
        | functionMayFail fname env =
            GC.compileFun mempty failure_params (fname, device_func)
        | otherwise =
            GC.compileVoidFun mempty (fname, device_func)
      (func, cstate) =
        genGPUCode env FunMode (functionBody device_func) failures compile
      kstate = GC.compUserState cstate

  modify $ \s ->
    s
      { clUsedTypes = typesInCode (functionBody device_func) <> clUsedTypes s,
        clDevFuns = M.insert fname (second funcText func) $ clDevFuns s,
        clFailures = kernelFailures kstate
      }

  -- Must come after the 'modify': functions generated from here read
  -- 'clFailures', which now includes this function's failures.
  void $ ensureDeviceFuns $ functionBody device_func
  where
    failure_params =
      [ [C.cparam|__global int *global_failure|],
        [C.cparam|__global typename int64_t *global_failure_args|]
      ]

    memParam MemParam {} = True
    memParam ScalarParam {} = False

    bad = compilerLimitationS "Cannot generate GPU functions that use arrays."
-- | Generate the device function unless 'clDevFuns' already has an entry
-- for this name.
ensureDeviceFun :: Name -> ImpGPU.Function ImpGPU.KernelOp -> OnKernelM ()
ensureDeviceFun fname device_func =
  gets (M.member fname . clDevFuns)
    >>= (`unless` generateDeviceFun fname device_func)
-- | Functions called from inside a launched kernel.  Other host operations
-- call no functions.
calledInHostOp :: HostOp -> S.Set Name
calledInHostOp op = case op of
  CallKernel k -> calledFuncs calledInKernelOp (kernelBody k)
  _ -> S.empty
-- | No kernel operation is treated as calling a function.
calledInKernelOp :: KernelOp -> S.Set Name
calledInKernelOp _ = S.empty
-- | Make sure a device version exists for every program function called
-- from the given kernel code, and return the names of those functions.
-- Called names that are not functions of the program are skipped.
--
-- Program functions are host-level.  They are converted by mapping every
-- host operation to a compiler-limitation error, so a function that
-- contains parallelism cannot be used on the device.
ensureDeviceFuns :: ImpGPU.KernelCode -> OnKernelM [Name]
ensureDeviceFuns code =
  catMaybes <$> mapM ensureCalled (S.toList $ calledFuncs calledInKernelOp code)
  where
    ensureCalled fname = do
      found <- asks $ lookupFunction fname
      case found of
        Nothing -> pure Nothing
        Just host_func -> do
          ensureDeviceFun fname $ fmap toDevice host_func
          pure $ Just fname

    bad = compilerLimitationS "Cannot generate GPU functions that contain parallelism."

    toDevice :: HostOp -> KernelOp
    toDevice _ = bad
-- | Render a group dimension as constant text, if it is statically known.
-- Three forms are recognised: an integer literal, a named size constant
-- (z-encoded), and the maximum of a size class (prefixed with @max_@).
isConst :: GroupDim -> Maybe T.Text
isConst dim = case dim of
  Left (ValueExp (IntValue x)) ->
    Just . prettyText $ intToInt64 x
  Right (SizeConst v) ->
    Just . T.pack . zEncodeString $ nameToString v
  Right (SizeMaxConst size_class) ->
    Just . T.pack $ "max_" <> prettyString size_class
  _ ->
    Nothing
-- | Generate the source of one kernel and record it in 'clGPU', together
-- with the primitive types it uses.  Returns the host-side 'LaunchKernel'
-- operation that invokes it.
--
-- Functions called from the kernel body are generated as device functions
-- first.  The kernel's safety level decides how many failure parameters it
-- takes and which failure-checking prologue it gets.  The level depends on
-- whether the kernel calls functions, whether compiling its body added
-- failure messages, whether it is failure tolerant, and whether it
-- contains barriers.
onKernel :: KernelTarget -> Kernel -> OnKernelM OpenCL
onKernel target kernel = do
  called <- ensureDeviceFuns $ kernelBody kernel

  -- Failure messages registered before this kernel.  Comparing their
  -- number with the kernel's own list below shows whether compiling the
  -- body added any.
  failures <- gets clFailures
  env <- ask

  let (kernel_body, cstate) =
        genGPUCode env KernelMode (kernelBody kernel) failures $
          GC.collect $ do
            -- Compile the body first, then ask for the allocated-memory
            -- declarations, emitting those ahead of the body.
            body <- GC.collect $ GC.compileCode $ kernelBody kernel
            mapM_ GC.item =<< GC.declAllocatedMem
            mapM_ GC.item body
      kstate = GC.compUserState cstate

      (local_memory_args, local_memory_params, local_memory_init) =
        unzip3 . flip evalState (blankNameSource :: VNameSource) $
          mapM (prepareLocalMemory target) (kernelLocalMemory kstate)

      -- CUDA kernels launched over three grid dimensions take the block
      -- dimensions as parameters; otherwise they are fixed local constants.
      (perm_params, block_dim_init) =
        case (target, num_groups) of
          (TargetCUDA, [_, _, _]) ->
            ( [ [C.cparam|const int block_dim0|],
                [C.cparam|const int block_dim1|],
                [C.cparam|const int block_dim2|]
              ],
              mempty
            )
          _ ->
            ( mempty,
              [ [C.citem|const int block_dim0 = 0;|],
                [C.citem|const int block_dim1 = 1;|],
                [C.citem|const int block_dim2 = 2;|]
              ]
            )

      (const_defs, const_undefs) = unzip $ mapMaybe constDef $ kernelUses kernel

      (use_params, unpack_params) = unzip $ mapMaybe useAsParam $ kernelUses kernel

      (safety, error_init)
        -- A kernel that calls any function gets full safety, whether or
        -- not the called functions can actually fail.
        | not $ null called =
            ( SafetyFull,
              [C.citems|volatile __local bool local_failure;
                        // Harmless for all threads to write this.
                        local_failure = false;|]
            )
        -- Compiling the body added no failure messages.
        | length (kernelFailures kstate) == length failures =
            if kernelFailureTolerant kernel
              then (SafetyNone, [])
              else
                ( SafetyCheap,
                  [C.citems|if (*global_failure >= 0) { return; }|]
                )
        -- The body added failure messages.
        | otherwise =
            if not (kernelHasBarriers kstate)
              then
                ( SafetyFull,
                  [C.citems|if (*global_failure >= 0) { return; }|]
                )
              else
                ( SafetyFull,
                  [C.citems|
                     volatile __local bool local_failure;
                     if (failure_is_an_option) {
                       int failed = *global_failure >= 0;
                       if (failed) {
                         return;
                       }
                     }
                     // All threads write this value - it looks like CUDA has a compiler bug otherwise.
                     local_failure = false;
                     barrier(CLK_LOCAL_MEM_FENCE);
                  |]
                )

      failure_params =
        [ [C.cparam|__global int *global_failure|],
          [C.cparam|int failure_is_an_option|],
          [C.cparam|__global typename int64_t *global_failure_args|]
        ]

      -- Only as many failure parameters as the safety level requires.
      params =
        perm_params
          ++ take (numFailureParams safety) failure_params
          ++ catMaybes local_memory_params
          ++ use_params

      reqdWorkGroupSize :: T.Text -> T.Text
      reqdWorkGroupSize dims =
        "__attribute__((reqd_work_group_size" <> dims <> "))\n"

      -- Only OpenCL kernels with a statically known group size get the
      -- attribute; missing dimensions are filled with 1.
      attribute :: T.Text
      attribute =
        case (target, mapM isConst group_size) of
          (TargetOpenCL, Just [x, y, z]) ->
            reqdWorkGroupSize $ prettyText (x, y, z)
          (TargetOpenCL, Just [x, y]) ->
            reqdWorkGroupSize $ prettyText (x, y, 1 :: Int)
          (TargetOpenCL, Just [x]) ->
            reqdWorkGroupSize $ prettyText (x, 1 :: Int, 1 :: Int)
          _ -> ""

      kernel_fun =
        attribute
          <> funcText
            [C.cfun|__kernel void $id:name ($params:params) {
                    $items:(mconcat unpack_params)
                    $items:const_defs
                    $items:block_dim_init
                    $items:local_memory_init
                    $items:error_init
                    $items:kernel_body

                    $id:(errorLabel kstate): return;

                    $items:const_undefs
                }|]

  modify $ \s ->
    s
      { clGPU = M.insert name (safety, kernel_fun) $ clGPU s,
        clUsedTypes = typesInKernel kernel <> clUsedTypes s,
        clFailures = kernelFailures kstate
      }

  let args = catMaybes local_memory_args ++ kernelArgs kernel

  pure $ LaunchKernel safety name args num_groups group_size
  where
    name = kernelName kernel
    num_groups = kernelNumGroups kernel
    group_size = kernelGroupSize kernel

    -- Turn one local-memory use into a launch argument, a kernel parameter
    -- and an initialising statement that binds the memory block's name.
    prepareLocalMemory TargetOpenCL (mem, size) = do
      mem_aligned <- newVName $ baseString mem ++ "_aligned"
      pure
        ( Just $ SharedMemoryKArg size,
          Just [C.cparam|__local volatile typename int64_t* $id:mem_aligned|],
          [C.citem|__local volatile unsigned char* restrict $id:mem = (__local volatile unsigned char*) $id:mem_aligned;|]
        )
    prepareLocalMemory TargetCUDA (mem, size) = do
      param <- newVName $ baseString mem ++ "_offset"
      pure
        ( Just $ SharedMemoryKArg size,
          Just [C.cparam|uint $id:param|],
          [C.citem|volatile $ty:defaultMemBlockType $id:mem = &shared_mem[$id:param];|]
        )
-- | The kernel parameter for a kernel use, plus any statements needed to
-- unpack it at the start of the kernel.  Constant uses get no parameter.
--
-- A scalar whose parameter type differs from its C type (always the case
-- for @Bool@ and @Unit@, which are passed as @unsigned char@) is passed
-- under a @_bits@ name and converted back with 'fromStorage'.
useAsParam :: KernelUse -> Maybe (C.Param, [C.BlockItem])
useAsParam (ScalarUse name pt)
  | ctp == primTypeToCType pt =
      Just ([C.cparam|$ty:ctp $id:name|], [])
  | otherwise =
      Just
        ( [C.cparam|$ty:ctp $id:name_bits|],
          [[C.citem|$ty:(primTypeToCType pt) $id:name = $exp:(fromStorage pt name_bits_e);|]]
        )
  where
    name_bits = zEncodeString (prettyString name) <> "_bits"
    name_bits_e = [C.cexp|$id:name_bits|]
    -- OpenCL C does not allow bool as a kernel argument type.
    ctp = case pt of
      Bool -> [C.cty|unsigned char|]
      Unit -> [C.cty|unsigned char|]
      _ -> primStorageType pt
useAsParam (MemoryUse name) =
  Just ([C.cparam|__global $ty:defaultMemBlockType $id:name|], [])
useAsParam ConstUse {} =
  Nothing
-- | For a 'ConstUse', build a pair of preprocessor block items.
--
-- The first @#define@s the constant's identifier as its compiled expression,
-- wrapped in parentheses.  The second @#undef@s the identifier again.  Every
-- other kind of use yields 'Nothing'.
constDef :: KernelUse -> Maybe (C.BlockItem, C.BlockItem)
constDef use = case use of
  ConstUse v e ->
    let ident = idText (C.toIdent v mempty)
        define_line = "#define " <> ident <> " (" <> expText (compilePrimExp e) <> ")"
        undef_line = "#undef " <> ident
     in Just
          ( [C.citem|$escstm:(T.unpack define_line)|],
            [C.citem|$escstm:(T.unpack undef_line)|]
          )
  _ -> Nothing
-- | Source text prepended to every generated OpenCL program.
--
-- The OpenCL-specific part:
--
-- * enables the Clang storage-class, byte-addressable-store and 64-bit
--   atomics extensions;
-- * declares a dummy kernel;
-- * provides fixed-width integer typedefs;
-- * defines @mem_fence_global@ and @mem_fence_local@.  On NVIDIA, detected
--   via @cl_nv_pragma_unroll@, inline PTX is used for the global fence.
--
-- The shared half-precision ('halfH'), scalar ('cScalarDefs') and atomics
-- ('atomicsH') code is appended after it.  @cl_khr_fp64@ is enabled, and
-- @FUTHARK_F64_ENABLED@ defined, only when 'Float64' is among the given
-- primitive types.
genOpenClPrelude :: S.Set PrimType -> T.Text
genOpenClPrelude ts = mconcat [openclDefinitions, halfH, cScalarDefs, atomicsH]
  where
    -- Only request double precision when the program actually uses f64.
    enable_f64
      | FloatType Float64 `S.member` ts =
          [untrimming|
#pragma OPENCL EXTENSION cl_khr_fp64 : enable
#define FUTHARK_F64_ENABLED
|]
      | otherwise = mempty
    openclDefinitions =
      [untrimming|
// Clang-based OpenCL implementations need this for 'static' to work.
#ifdef cl_clang_storage_class_specifiers
#pragma OPENCL EXTENSION cl_clang_storage_class_specifiers : enable
#endif
#pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable
$enable_f64
// Some OpenCL programs dislike empty progams, or programs with no kernels.
// Declare a dummy kernel to ensure they remain our friends.
__kernel void dummy_kernel(__global unsigned char *dummy, int n)
{
const int thread_gid = get_global_id(0);
if (thread_gid >= n) return;
}
#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable
#pragma OPENCL EXTENSION cl_khr_int64_extended_atomics : enable
typedef char int8_t;
typedef short int16_t;
typedef int int32_t;
typedef long int64_t;
typedef uchar uint8_t;
typedef ushort uint16_t;
typedef uint uint32_t;
typedef ulong uint64_t;
// NVIDIAs OpenCL does not create device-wide memory fences (see #734), so we
// use inline assembly if we detect we are on an NVIDIA GPU.
#ifdef cl_nv_pragma_unroll
static inline void mem_fence_global() {
asm("membar.gl;");
}
#else
static inline void mem_fence_global() {
mem_fence(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE);
}
#endif
static inline void mem_fence_local() {
mem_fence(CLK_LOCAL_MEM_FENCE);
}
|]
-- | Source text prepended to every generated CUDA program, so that
-- OpenCL-style kernel code compiles as CUDA.
--
-- The CUDA-specific part:
--
-- * defines @FUTHARK_CUDA@ and @FUTHARK_F64_ENABLED@;
-- * provides fixed-width integer typedefs and the OpenCL short names
--   (@uchar@, @ushort@, @uint@, @ulong@);
-- * turns @__kernel@ into an @extern "C" __global__@ entry point bounded by
--   @MAX_THREADS_PER_BLOCK@;
-- * defines the OpenCL address-space qualifiers away;
-- * emulates the OpenCL work-item functions with CUDA builtins.
--   @get_group_id@ and @get_num_groups@ first remap the dimension through
--   @block_dim0@..@block_dim2@, which must be in scope where the macros are
--   used.
--
-- @barrier@ ignores its argument and always calls @__syncthreads@.  The
-- extern @__shared__@ byte array @shared_mem@ is also declared.  The shared
-- half-precision, scalar and atomics code is appended afterwards.
genCUDAPrelude :: T.Text
genCUDAPrelude = mconcat [cudaDefinitions, halfH, cScalarDefs, atomicsH]
  where
    cudaDefinitions =
      [untrimming|
#define FUTHARK_CUDA
#define FUTHARK_F64_ENABLED
typedef char int8_t;
typedef short int16_t;
typedef int int32_t;
typedef long long int64_t;
typedef unsigned char uint8_t;
typedef unsigned short uint16_t;
typedef unsigned int uint32_t;
typedef unsigned long long uint64_t;
typedef uint8_t uchar;
typedef uint16_t ushort;
typedef uint32_t uint;
typedef uint64_t ulong;
#define __kernel extern "C" __global__ __launch_bounds__(MAX_THREADS_PER_BLOCK)
#define __global
#define __local
#define __private
#define __constant
#define __write_only
#define __read_only
static inline int get_group_id_fn(int block_dim0, int block_dim1, int block_dim2, int d) {
switch (d) {
case 0: d = block_dim0; break;
case 1: d = block_dim1; break;
case 2: d = block_dim2; break;
}
switch (d) {
case 0: return blockIdx.x;
case 1: return blockIdx.y;
case 2: return blockIdx.z;
default: return 0;
}
}
#define get_group_id(d) get_group_id_fn(block_dim0, block_dim1, block_dim2, d)
static inline int get_num_groups_fn(int block_dim0, int block_dim1, int block_dim2, int d) {
switch (d) {
case 0: d = block_dim0; break;
case 1: d = block_dim1; break;
case 2: d = block_dim2; break;
}
switch(d) {
case 0: return gridDim.x;
case 1: return gridDim.y;
case 2: return gridDim.z;
default: return 0;
}
}
#define get_num_groups(d) get_num_groups_fn(block_dim0, block_dim1, block_dim2, d)
static inline int get_local_id(int d) {
switch (d) {
case 0: return threadIdx.x;
case 1: return threadIdx.y;
case 2: return threadIdx.z;
default: return 0;
}
}
static inline int get_local_size(int d) {
switch (d) {
case 0: return blockDim.x;
case 1: return blockDim.y;
case 2: return blockDim.z;
default: return 0;
}
}
#define CLK_LOCAL_MEM_FENCE 1
#define CLK_GLOBAL_MEM_FENCE 2
static inline void barrier(int x) {
__syncthreads();
}
static inline void mem_fence_local() {
__threadfence_block();
}
static inline void mem_fence_global() {
__threadfence();
}
#define NAN (0.0/0.0)
#define INFINITY (1.0/0.0)
extern volatile __shared__ unsigned char shared_mem[];
|]
-- | Compile a primitive expression over kernel constants to a C expression.
--
-- A 'SizeConst' becomes a reference to the z-encoded name of its size key.
-- A 'SizeMaxConst' becomes a reference to @max_<size class>@.
compilePrimExp :: PrimExp KernelConst -> C.Exp
compilePrimExp = runIdentity . GC.compilePrimExp (pure . constToExp)
  where
    constToExp (SizeConst key) =
      [C.cexp|$id:(zEncodeString (prettyString key))|]
    constToExp (SizeMaxConst size_class) =
      [C.cexp|$id:("max_" <> prettyString size_class)|]
-- | The arguments passed when launching a kernel, in the order of the
-- kernel's uses.
--
-- Each memory use becomes a 'MemKArg' and each scalar use a 'ValueKArg'.
-- Constant uses contribute no argument.
kernelArgs :: Kernel -> [KernelArg]
kernelArgs kernel = [arg | Just arg <- map toArg (kernelUses kernel)]
  where
    toArg use = case use of
      MemoryUse mem -> Just (MemKArg mem)
      ScalarUse v pt -> Just (ValueKArg (LeafExp v pt) pt)
      ConstUse {} -> Nothing
-- | The current error-synchronisation label, obtained by applying
-- 'errorLabel' to the kernel state.
nextErrorLabel :: GC.CompilerM KernelOp KernelState String
nextErrorLabel = do
  st <- GC.getUserState
  pure (errorLabel st)
-- | Increment the 'kernelNextSync' counter in the kernel state by one.
incErrorLabel :: GC.CompilerM KernelOp KernelState ()
incErrorLabel = GC.modifyUserState bump
  where
    bump st = st {kernelNextSync = 1 + kernelNextSync st}
-- | Overwrite the 'kernelSyncPending' flag in the kernel state with the
-- given value.
pendingError :: Bool -> GC.CompilerM KernelOp KernelState ()
pendingError flag =
  GC.modifyUserState $ \st -> st {kernelSyncPending = flag}
-- | Does the kernel code contain at least one 'ErrorSync' or 'Barrier'
-- operation?
hasCommunication :: ImpGPU.KernelCode -> Bool
hasCommunication = any isSyncOp
  where
    isSyncOp op = case op of
      ErrorSync {} -> True
      Barrier {} -> True
      _ -> False
-- | What kind of in-kernel code is being generated.
--
-- This controls how an error branch exits when no communication is present:
-- in 'FunMode' it emits @return 1;@, in 'KernelMode' a plain @return;@.
data OpsMode
  = -- | Code for a kernel body.
    KernelMode
  | -- | Code for a function (presumably one called from kernels).
    FunMode
  deriving (Eq)
-- | C code-generation operations for code that executes inside a GPU kernel.
--
-- @mode@ selects how an error branch exits when no communication is present:
-- 'FunMode' emits @return 1;@ and 'KernelMode' emits @return;@.  When @body@
-- contains a 'Barrier' or 'ErrorSync' ('hasCommunication'), an error instead
-- sets @local_failure@ and jumps to the label of the next error
-- synchronisation point.
--
-- Errors are appended to 'kernelFailures' and reported through
-- @global_failure@ / @global_failure_args@.  Memory allocation, deallocation,
-- bulk copies and static arrays are rejected with 'error'.
inKernelOperations ::
  Env ->
  OpsMode ->
  ImpGPU.KernelCode ->
  GC.Operations KernelOp KernelState
inKernelOperations env mode body =
  GC.Operations
    { GC.opsCompiler = kernelOps,
      GC.opsMemoryType = kernelMemoryType,
      GC.opsWriteScalar = kernelWriteScalar,
      GC.opsReadScalar = kernelReadScalar,
      GC.opsAllocate = cannotAllocate,
      GC.opsDeallocate = cannotDeallocate,
      GC.opsCopy = copyInKernel,
      GC.opsStaticArray = noStaticArrays,
      GC.opsFatMemory = False,
      GC.opsError = errorInKernel,
      GC.opsCall = callInKernel,
      GC.opsCritical = mempty
    }
  where
    has_communication = hasCommunication body

    -- The OpenCL fence flags corresponding to a 'Fence'.
    fence FenceLocal = [C.cexp|CLK_LOCAL_MEM_FENCE|]
    fence FenceGlobal = [C.cexp|CLK_GLOBAL_MEM_FENCE | CLK_LOCAL_MEM_FENCE|]

    -- Record in the kernel state that a barrier has been emitted.
    noteBarrier = GC.modifyUserState $ \s -> s {kernelHasBarriers = True}

    kernelOps :: GC.OpCompiler KernelOp KernelState
    kernelOps op = case op of
      GetGroupId v i -> GC.stm [C.cstm|$id:v = get_group_id($int:i);|]
      GetLocalId v i -> GC.stm [C.cstm|$id:v = get_local_id($int:i);|]
      GetLocalSize v i -> GC.stm [C.cstm|$id:v = get_local_size($int:i);|]
      GetLockstepWidth v -> GC.stm [C.cstm|$id:v = LOCKSTEP_WIDTH;|]
      Barrier f -> do
        GC.stm [C.cstm|barrier($exp:(fence f));|]
        noteBarrier
      MemFence FenceLocal -> GC.stm [C.cstm|mem_fence_local();|]
      MemFence FenceGlobal -> GC.stm [C.cstm|mem_fence_global();|]
      LocalAlloc name size -> do
        -- Register a fresh backing buffer in the kernel state and point the
        -- requested name at it.
        backing <- newVName $ prettyString name ++ "_backing"
        GC.modifyUserState $ \s ->
          s {kernelLocalMemory = (backing, fmap untyped size) : kernelLocalMemory s}
        GC.stm [C.cstm|$id:name = (__local unsigned char*) $id:backing;|]
      ErrorSync f -> do
        label <- nextErrorLabel
        pending <- kernelSyncPending <$> GC.getUserState
        -- Only emit the labelled barrier (the target of a preceding
        -- error branch's goto) when such a branch is pending.
        when pending $ do
          pendingError False
          GC.stm [C.cstm|$id:label: barrier($exp:(fence f));|]
          GC.stm [C.cstm|if (local_failure) { return; }|]
        GC.stm [C.cstm|barrier($exp:(fence f));|]
        noteBarrier
        incErrorLabel
      Atomic space aop -> atomicOps space aop

    -- The address-space suffix used in atomic function names; anything other
    -- than a named 'Space' counts as global.
    atomicSpace (Space sid) = sid
    atomicSpace _ = "global"

    -- Volatile pointer target type, qualified for the atomic's address space.
    atomicCast s t = do
      quals <- pointerQuals (atomicSpace s)
      let volatile = [C.ctyquals|volatile|]
      pure [C.cty|$tyquals:(volatile++quals) $ty:t|]

    -- Name of a runtime atomic function: <base>_<type>_<space>.
    atomicFunName base t s = base ++ "_" ++ prettyString t ++ "_" ++ atomicSpace s

    doAtomic s t old arr ind val base ty = do
      ind' <- GC.compileExp $ untyped $ unCount ind
      val' <- GC.compileExp val
      cast <- atomicCast s ty
      let fun = atomicFunName base t s
      GC.stm [C.cstm|$id:old = $id:fun(&(($ty:cast *)$id:arr)[$exp:ind'], ($ty:ty) $exp:val');|]

    doAtomicCmpXchg s t old arr ind cmp val ty = do
      ind' <- GC.compileExp $ untyped $ unCount ind
      cmp' <- GC.compileExp cmp
      val' <- GC.compileExp val
      cast <- atomicCast s ty
      let fun = atomicFunName "atomic_cmpxchg" t s
      GC.stm [C.cstm|$id:old = $id:fun(&(($ty:cast *)$id:arr)[$exp:ind'], $exp:cmp', $exp:val');|]

    doAtomicXchg s t old arr ind val ty = do
      cast <- atomicCast s ty
      ind' <- GC.compileExp $ untyped $ unCount ind
      val' <- GC.compileExp val
      let fun = atomicFunName "atomic_chg" t s
      GC.stm [C.cstm|$id:old = $id:fun(&(($ty:cast *)$id:arr)[$exp:ind'], $exp:val');|]

    -- C operand types for atomics: 64-bit operands get a 64-bit C type,
    -- everything else the 32-bit one.
    signedCTy Int64 = [C.cty|typename int64_t|]
    signedCTy _ = [C.cty|int|]
    -- NOTE(review): `unsigned int64_t` combines `unsigned` with a typedef
    -- name, which standard C rejects; `typename uint64_t` may be what was
    -- intended -- confirm against the atomics runtime header.
    unsignedCTy Int64 = [C.cty|unsigned int64_t|]
    unsignedCTy _ = [C.cty|unsigned int|]
    floatCTy Float64 = [C.cty|double|]
    floatCTy _ = [C.cty|float|]
    exchangeCTy (IntType Int64) = [C.cty|typename int64_t|]
    exchangeCTy _ = [C.cty|int|]

    atomicOps s aop = case aop of
      AtomicAdd t old arr ind val ->
        doAtomic s t old arr ind val "atomic_add" (signedCTy t)
      AtomicFAdd t old arr ind val ->
        doAtomic s t old arr ind val "atomic_fadd" (floatCTy t)
      AtomicSMax t old arr ind val ->
        doAtomic s t old arr ind val "atomic_smax" (signedCTy t)
      AtomicSMin t old arr ind val ->
        doAtomic s t old arr ind val "atomic_smin" (signedCTy t)
      AtomicUMax t old arr ind val ->
        doAtomic s t old arr ind val "atomic_umax" (unsignedCTy t)
      AtomicUMin t old arr ind val ->
        doAtomic s t old arr ind val "atomic_umin" (unsignedCTy t)
      AtomicAnd t old arr ind val ->
        doAtomic s t old arr ind val "atomic_and" (signedCTy t)
      AtomicOr t old arr ind val ->
        doAtomic s t old arr ind val "atomic_or" (signedCTy t)
      AtomicXor t old arr ind val ->
        doAtomic s t old arr ind val "atomic_xor" (signedCTy t)
      AtomicCmpXchg t old arr ind cmp val ->
        doAtomicCmpXchg s t old arr ind cmp val (exchangeCTy t)
      AtomicXchg t old arr ind val ->
        doAtomicXchg s t old arr ind val (exchangeCTy t)

    cannotAllocate :: GC.Allocate KernelOp KernelState
    cannotAllocate _ =
      error "Cannot allocate memory in kernel"

    cannotDeallocate :: GC.Deallocate KernelOp KernelState
    cannotDeallocate _ _ =
      error "Cannot deallocate memory in kernel"

    copyInKernel :: GC.Copy KernelOp KernelState
    copyInKernel _ _ _ _ _ _ _ _ =
      error "Cannot bulk copy in kernel."

    noStaticArrays :: GC.StaticArray KernelOp KernelState
    noStaticArrays _ _ _ _ =
      error "Cannot create static array in kernel."

    kernelMemoryType space =
      (\quals -> [C.cty|$tyquals:quals $ty:defaultMemBlockType|]) <$> pointerQuals space

    kernelWriteScalar = GC.writeScalarPointerWithQuals pointerQuals

    kernelReadScalar = GC.readScalarPointerWithQuals pointerQuals

    -- Block items that leave the current code after an error has been
    -- signalled.  Also marks an error-sync label as pending.
    whatNext = do
      label <- nextErrorLabel
      pendingError True
      pure $ exitItems label

    exitItems label
      | has_communication = [C.citems|local_failure = true; goto $id:label;|]
      | mode == FunMode = [C.citems|return 1;|]
      | otherwise = [C.citems|return;|]

    -- Calls pass destinations by address.  Functions that may fail also get
    -- the failure globals and have their nonzero result checked.
    callInKernel dests fname args = do
      let out_args = [[C.cexp|&$id:d|] | d <- dests]
      if functionMayFail fname env
        then do
          let failing_args =
                [C.cexp|global_failure|]
                  : [C.cexp|global_failure_args|]
                  : out_args
                  ++ args
          what_next <- whatNext
          GC.item [C.citem|if ($id:(funName fname)($args:failing_args) != 0) { $items:what_next; }|]
        else do
          let plain_args = out_args ++ args
          GC.item [C.citem|$id:(funName fname)($args:plain_args);|]

    -- Statements storing each 'ErrorVal' part in consecutive
    -- @global_failure_args@ slots, starting at index 0.
    failureArgStms msg_parts = storeArgs (0 :: Int) msg_parts
      where
        storeArgs _ [] = pure []
        storeArgs i (ErrorString {} : rest) = storeArgs i rest
        storeArgs i (ErrorVal _ x : rest) = do
          x' <- GC.compileExp x
          (:) [C.cstm|global_failure_args[$int:i] = (typename int64_t)$exp:x';|]
            <$> storeArgs (i + 1) rest

    errorInKernel msg@(ErrorMsg parts) backtrace = do
      failure_idx <- length . kernelFailures <$> GC.getUserState
      GC.modifyUserState $ \s ->
        s {kernelFailures = kernelFailures s ++ [FailureMsg msg backtrace]}
      arg_stms <- failureArgStms parts
      what_next <- whatNext
      -- Only the first failing thread (the one whose cmpxchg sees -1) records
      -- its failure index and arguments.
      GC.stm
        [C.cstm|{ if (atomic_cmpxchg_i32_global(global_failure, -1, $int:failure_idx) == -1)
                  { $stms:arg_stms; }
                  $items:what_next
                }|]
-- | The set of primitive types occurring anywhere in a kernel's body code.
typesInKernel :: Kernel -> S.Set PrimType
typesInKernel = typesInCode . kernelBody
-- | Collect the primitive types mentioned by a piece of kernel code.
--
-- The result includes:
--
-- * the element types of declared scalars and arrays;
-- * the element types of 'Read' and 'Write';
-- * every type found in the expressions the code contains, including the
--   expression arguments of calls.
--
-- Memory declarations, 'SetMem', 'Free', kernel operations ('Op') and memory
-- arguments contribute nothing.  The 'PrimType' field of 'Copy' is not
-- included either.
typesInCode :: ImpGPU.KernelCode -> S.Set PrimType
typesInCode code = case code of
  Skip -> S.empty
  c1 :>>: c2 -> typesInCode c1 <> typesInCode c2
  For _ e c -> typesInExp e <> typesInCode c
  While (TPrimExp e) c -> typesInExp e <> typesInCode c
  DeclareMem {} -> S.empty
  DeclareScalar _ _ t -> S.singleton t
  DeclareArray _ _ t _ -> S.singleton t
  Allocate _ (Count (TPrimExp e)) _ -> typesInExp e
  Free {} -> S.empty
  Copy _ _ (Count (TPrimExp e1)) _ _ (Count (TPrimExp e2)) _ (Count (TPrimExp e3)) ->
    typesInExp e1 <> typesInExp e2 <> typesInExp e3
  Write _ (Count (TPrimExp e1)) t _ _ e2 ->
    typesInExp e1 <> S.singleton t <> typesInExp e2
  Read _ _ (Count (TPrimExp e1)) t _ _ ->
    typesInExp e1 <> S.singleton t
  SetScalar _ e -> typesInExp e
  SetMem {} -> S.empty
  Call _ _ es -> foldMap argTypes es
  If (TPrimExp e) c1 c2 ->
    typesInExp e <> typesInCode c1 <> typesInCode c2
  Assert e _ _ -> typesInExp e
  Comment _ c -> typesInCode c
  DebugPrint _ v -> maybe S.empty typesInExp v
  TracePrint msg -> foldMap typesInExp msg
  Op {} -> S.empty
  where
    -- Only expression arguments carry primitive types.
    argTypes (ExpArg e) = typesInExp e
    argTypes MemArg {} = S.empty
-- | Collect the primitive types occurring in an expression.
--
-- The result includes:
--
-- * the type of every literal value;
-- * both the source and target types of every conversion;
-- * the result type of every function call;
-- * everything found recursively in subexpressions.
--
-- Leaf expressions contribute nothing.
typesInExp :: Exp -> S.Set PrimType
typesInExp expr = case expr of
  ValueExp v -> S.singleton (primValueType v)
  BinOpExp _ x y -> typesInExp x <> typesInExp y
  CmpOpExp _ x y -> typesInExp x <> typesInExp y
  ConvOpExp op x ->
    let (from_t, to_t) = convOpType op
     in S.insert from_t (S.insert to_t (typesInExp x))
  UnOpExp _ x -> typesInExp x
  FunExp _ xs t -> S.insert t (foldMap typesInExp xs)
  LeafExp {} -> S.empty