{-# LANGUAGE QuasiQuotes #-}
module Futhark.CodeGen.ImpGen.GPU.ToOpenCL
( kernelsToOpenCL,
kernelsToCUDA,
kernelsToHIP,
)
where
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Foldable (toList)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericC.Fun qualified as GC
import Futhark.CodeGen.Backends.GenericC.Pretty
import Futhark.CodeGen.Backends.SimpleRep
import Futhark.CodeGen.ImpCode.GPU hiding (Program)
import Futhark.CodeGen.ImpCode.GPU qualified as ImpGPU
import Futhark.CodeGen.ImpCode.OpenCL hiding (Program)
import Futhark.CodeGen.ImpCode.OpenCL qualified as ImpOpenCL
import Futhark.CodeGen.RTS.C (atomicsH, halfH)
import Futhark.CodeGen.RTS.CUDA (preludeCU)
import Futhark.CodeGen.RTS.OpenCL (copyCL, preludeCL, transposeCL)
import Futhark.Error (compilerLimitationS)
import Futhark.MonadFreshNames
import Futhark.Util (mapAccumLM, zEncodeText)
import Futhark.Util.IntegralExp (rem)
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C
import NeatInterpolation (untrimming)
import Prelude hiding (rem)
-- | Translate a GPU-level ImpCode program into the OpenCL-style program
-- representation, generating device code for the HIP target.
kernelsToHIP :: ImpGPU.Program -> ImpOpenCL.Program
kernelsToHIP = translateGPU TargetHIP
-- | Translate a GPU-level ImpCode program into the OpenCL-style program
-- representation, generating device code for the CUDA target.
kernelsToCUDA :: ImpGPU.Program -> ImpOpenCL.Program
kernelsToCUDA = translateGPU TargetCUDA
-- | Translate a GPU-level ImpCode program into the OpenCL-style program
-- representation, generating device code for the OpenCL target.
kernelsToOpenCL :: ImpGPU.Program -> ImpOpenCL.Program
kernelsToOpenCL = translateGPU TargetOpenCL
-- | Shared driver behind 'kernelsToOpenCL', 'kernelsToCUDA' and
-- 'kernelsToHIP'.
--
-- Every host operation in the constants and in each function of the
-- input program is rewritten with 'onHostOp'.  The 'ToOpenCL' state
-- accumulated during that traversal (kernels, device functions, used
-- primitive types, sizes, failure messages and constants) is then
-- packaged, together with the rewritten host definitions, into the
-- resulting 'ImpOpenCL.Program'.
translateGPU ::
  KernelTarget ->
  ImpGPU.Program ->
  ImpOpenCL.Program
translateGPU target prog =
  let env = envFromProg prog
      ( prog',
        ToOpenCL kernels device_funs used_types sizes failures constants
        ) =
          (`runState` initialOpenCL) . (`runReaderT` env) $ do
            let ImpGPU.Definitions
                  types
                  (ImpGPU.Constants ps consts)
                  (ImpGPU.Functions funs) = prog
            consts' <- traverse (onHostOp target) consts
            funs' <- forM funs $ \(fname, fun) ->
              (fname,) <$> traverse (onHostOp target) fun

            pure $
              ImpOpenCL.Definitions
                types
                (ImpOpenCL.Constants ps consts')
                (ImpOpenCL.Functions funs')

      -- Each device function is stored as (prototype, definition text).
      (device_prototypes, device_defs) = unzip $ M.elems device_funs
      -- Each kernel is stored as (safety level, source text).
      kernels' = M.map fst kernels
      opencl_code = T.unlines $ map snd $ M.elems kernels

      -- The prelude text is the target-specific prelude, followed by all
      -- device function prototypes, followed by their definitions.
      opencl_prelude =
        T.unlines
          [ genPrelude target used_types,
            definitionsText device_prototypes,
            T.unlines device_defs
          ]
   in ImpOpenCL.Program
        { openClProgram = opencl_code,
          openClPrelude = opencl_prelude,
          openClMacroDefs = constants,
          openClKernelNames = kernels',
          openClUsedTypes = S.toList used_types,
          openClParams = findParamUsers env prog' (cleanSizes sizes),
          openClFailures = failures,
          hostDefinitions = prog'
        }
  where
    -- Only the OpenCL prelude generator takes the set of used types; the
    -- CUDA and HIP preludes are fixed texts.
    genPrelude TargetOpenCL = genOpenClPrelude
    genPrelude TargetCUDA = const genCUDAPrelude
    genPrelude TargetHIP = const genHIPPrelude
-- | Prune the kernel paths of threshold sizes.  For every
-- 'SizeThreshold' in the map, path entries whose name is not itself a
-- key of the map are dropped.  All other size classes are returned
-- unchanged.
cleanSizes :: M.Map Name SizeClass -> M.Map Name SizeClass
cleanSizes m = M.map clean m
  where
    -- Membership is tested directly against the map rather than against
    -- a list of its keys, avoiding a linear scan per path entry.
    isKnown = (`M.member` m) . fst

    clean (SizeThreshold path def) =
      SizeThreshold (filter isKnown path) def
    clean s = s
-- | Pair every size parameter with the set of functions that use it.
--
-- A function uses a parameter directly if its body contains a
-- 'ImpOpenCL.GetSize' or 'ImpOpenCL.CmpSizeLe' operation on that
-- parameter.  It uses a parameter indirectly if any function recorded
-- for it in the call graph ('envCallGraph') uses the parameter directly.
-- The third (eta-reduced) argument is the map of sizes to annotate.
findParamUsers ::
  Env ->
  Definitions ImpOpenCL.OpenCL ->
  M.Map Name SizeClass ->
  ParamMap
findParamUsers env defs = M.mapWithKey onParam
  where
    cg = envCallGraph env

    -- The size key mentioned by a host operation, if any.
    getSize (ImpOpenCL.GetSize _ v) = Just v
    getSize (ImpOpenCL.CmpSizeLe _ v _) = Just v
    getSize (ImpOpenCL.GetSizeMax {}) = Nothing
    getSize (ImpOpenCL.LaunchKernel {}) = Nothing

    directUseInFun fun = mapMaybe getSize $ toList $ functionBody fun
    direct_uses = map (second directUseInFun) $ unFunctions $ defFuns defs

    calledBy fname = M.findWithDefault mempty fname cg

    -- All sizes directly used by the functions that the call graph
    -- associates with @fname@.
    indirectUseInFun fname =
      ( fname,
        foldMap snd $ filter ((`S.member` calledBy fname) . fst) direct_uses
      )

    -- A function name may occur twice in this list (once with its direct
    -- uses, once with its indirect ones); 'onParam' collects into a set,
    -- so that is harmless.
    indirect_uses = direct_uses <> map (indirectUseInFun . fst) direct_uses

    onParam k c =
      (c, S.fromList $ map fst $ filter ((k `elem`) . snd) indirect_uses)
-- | Map the name of an address space (or related qualifier) to the
-- corresponding OpenCL C type qualifiers.  The name @"device"@ is
-- treated as a synonym for @"global"@.
--
-- Calls 'error' for any name not listed here.
pointerQuals :: String -> [C.TypeQual]
pointerQuals "global" = [C.ctyquals|__global|]
pointerQuals "local" = [C.ctyquals|__local|]
pointerQuals "private" = [C.ctyquals|__private|]
pointerQuals "constant" = [C.ctyquals|__constant|]
pointerQuals "write_only" = [C.ctyquals|__write_only|]
pointerQuals "read_only" = [C.ctyquals|__read_only|]
pointerQuals "kernel" = [C.ctyquals|__kernel|]
pointerQuals "device" = pointerQuals "global"
pointerQuals s = error $ "'" ++ s ++ "' is not an OpenCL kernel address space."
-- | A use of local memory recorded while compiling a kernel: a name
-- paired with a size expression counted in bytes.
-- NOTE(review): presumably the name identifies the memory block being
-- allocated; confirm against where 'kernelLocalMemory' is populated.
type LocalMemoryUse = (VName, Count Bytes (TExp Int64))
-- | User state threaded through the C code generator while compiling
-- the body of a single kernel or device function (see 'genGPUCode').
data KernelState = KernelState
  { -- | Local memory uses collected so far.
    kernelLocalMemory :: [LocalMemoryUse],
    -- | Failure messages known so far; seeded by 'newKernelState'.
    kernelFailures :: [FailureMsg],
    -- | Counter used by 'errorLabel' to name the next error label.
    kernelNextSync :: Int,
    -- | NOTE(review): presumably records whether a possible failure has
    -- not yet been followed by a synchronisation point; confirm against
    -- the code that sets it.
    kernelSyncPending :: Bool,
    -- | NOTE(review): presumably set when the generated code contains
    -- barriers; confirm against the code that sets it.
    kernelHasBarriers :: Bool
  }
-- | The initial 'KernelState': no local memory uses, the given failure
-- messages, sync counter at zero, and both flags cleared.
newKernelState :: [FailureMsg] -> KernelState
newKernelState failures =
  -- Record syntax keeps this correct if fields are ever reordered.
  KernelState
    { kernelLocalMemory = mempty,
      kernelFailures = failures,
      kernelNextSync = 0,
      kernelSyncPending = False,
      kernelHasBarriers = False
    }
-- | The name of the current error label: @error_N@, where @N@ is the
-- value of 'kernelNextSync'.
errorLabel :: KernelState -> String
errorLabel = ("error_" ++) . show . kernelNextSync
-- | State accumulated while translating a whole program; see
-- 'translateGPU', which unpacks it into the final program record.
data ToOpenCL = ToOpenCL
  { -- | Generated kernels by name: safety level and source text.
    clGPU :: M.Map KernelName (KernelSafety, T.Text),
    -- | Generated device functions by name: prototype and definition
    -- text.
    clDevFuns :: M.Map Name (C.Definition, T.Text),
    -- | Primitive types used by the generated device code.
    clUsedTypes :: S.Set PrimType,
    -- | Size parameters seen so far, recorded by 'addSize'.
    clSizes :: M.Map Name SizeClass,
    -- | Failure messages collected so far.
    clFailures :: [FailureMsg],
    -- | Named constants, emitted as the program's macro definitions.
    clConstants :: [(Name, KernelConstExp)]
  }
-- | The empty translation state: every component starts out as
-- 'mempty'.
initialOpenCL :: ToOpenCL
initialOpenCL =
  -- Named fields instead of six positional 'mempty's, so the
  -- initialiser cannot silently go out of step with the record.
  ToOpenCL
    { clGPU = mempty,
      clDevFuns = mempty,
      clUsedTypes = mempty,
      clSizes = mempty,
      clFailures = mempty,
      clConstants = mempty
    }
-- | Read-only environment for the translation, built once per program
-- by 'envFromProg'.
data Env = Env
  { -- | All host-level functions of the program being translated.
    envFuns :: ImpGPU.Functions ImpGPU.HostOp,
    -- | Names of the functions that may fail, as computed by
    -- 'funsMayFail'.
    envFunsMayFail :: S.Set Name,
    -- | The call graph of the program, as computed by
    -- 'ImpGPU.callGraph'.
    envCallGraph :: M.Map Name (S.Set Name)
  }
-- | Does this code contain an 'Assert', or an operation for which the
-- given predicate holds?  The search descends through sequencing,
-- loops, branches and comments; every other statement form is treated
-- as unable to fail.
codeMayFail :: (a -> Bool) -> ImpGPU.Code a -> Bool
codeMayFail _ (Assert {}) = True
codeMayFail f (Op x) = f x
codeMayFail f (x :>>: y) = codeMayFail f x || codeMayFail f y
codeMayFail f (For _ _ x) = codeMayFail f x
codeMayFail f (While _ x) = codeMayFail f x
codeMayFail f (If _ x y) = codeMayFail f x || codeMayFail f y
codeMayFail f (Comment _ x) = codeMayFail f x
codeMayFail _ _ = False
-- | A host operation may fail only if it is a kernel launch whose body
-- may fail according to 'codeMayFail'.
hostOpMayFail :: ImpGPU.HostOp -> Bool
hostOpMayFail (CallKernel k) = codeMayFail kernelOpMayFail $ kernelBody k
hostOpMayFail _ = False
-- | No kernel-level operation is considered able to fail.
kernelOpMayFail :: ImpGPU.KernelOp -> Bool
kernelOpMayFail = const False
-- | The names of all functions that may fail.  A function may fail if
-- its own body may fail (per 'codeMayFail' with 'hostOpMayFail'), or if
-- any function listed for it in the given call graph has such a body.
funsMayFail :: M.Map Name (S.Set Name) -> ImpGPU.Functions ImpGPU.HostOp -> S.Set Name
funsMayFail cg (Functions funs) =
  S.fromList $ map fst $ filter mayFail funs
  where
    -- Functions whose own body may fail.  Kept as a set so that the
    -- membership tests in 'mayFail' are logarithmic rather than linear.
    base_mayfail =
      S.fromList . map fst $
        filter (codeMayFail hostOpMayFail . ImpGPU.functionBody . snd) funs

    mayFail (fname, _) =
      any (`S.member` base_mayfail) $
        fname : S.toList (M.findWithDefault mempty fname cg)
-- | Build the translation environment for a program: its functions,
-- the set of functions that may fail, and its call graph.
envFromProg :: ImpGPU.Program -> Env
envFromProg prog =
  Env
    { envFuns = funs,
      envFunsMayFail = funsMayFail cg funs,
      envCallGraph = cg
    }
  where
    funs = defFuns prog
    cg = ImpGPU.callGraph calledInHostOp funs
-- | Look up a host-level function by name in the environment.
lookupFunction :: Name -> Env -> Maybe (ImpGPU.Function HostOp)
lookupFunction fname = lookup fname . unFunctions . envFuns
-- | Is the named function recorded in the environment as one that may
-- fail?
functionMayFail :: Name -> Env -> Bool
functionMayFail fname = S.member fname . envFunsMayFail
-- | The monad in which host operations are translated: read-only
-- access to the 'Env', plus the accumulated 'ToOpenCL' state.
type OnKernelM = ReaderT Env (State ToOpenCL)
-- | Record a size parameter and its class in the translation state.
-- An existing entry for the same key is replaced.
addSize :: Name -> SizeClass -> OnKernelM ()
addSize key sclass =
  modify $ \s -> s {clSizes = M.insert key sclass $ clSizes s}
-- | Translate one GPU-level host operation to its OpenCL-level
-- counterpart.  Kernel launches are delegated to 'onKernel'.  Size
-- queries that carry a key ('ImpGPU.GetSize', 'ImpGPU.CmpSizeLe') also
-- register that key and its size class via 'addSize'.
onHostOp :: KernelTarget -> HostOp -> OnKernelM OpenCL
onHostOp target (CallKernel k) = onKernel target k
onHostOp _ (ImpGPU.GetSize v key size_class) = do
  addSize key size_class
  pure $ ImpOpenCL.GetSize v key
onHostOp _ (ImpGPU.CmpSizeLe v key size_class x) = do
  addSize key size_class
  pure $ ImpOpenCL.CmpSizeLe v key x
onHostOp _ (ImpGPU.GetSizeMax v size_class) =
  pure $ ImpOpenCL.GetSizeMax v size_class
-- | Run a C code generation action for device code.  The action runs
-- with the operations from 'inKernelOperations' (built from the
-- environment, mode and body), a blank name source, and a fresh
-- 'KernelState' seeded with the given failure messages.  Returns the
-- action's result together with the final compiler state.
genGPUCode ::
  Env ->
  OpsMode ->
  KernelCode ->
  [FailureMsg] ->
  GC.CompilerM KernelOp KernelState a ->
  (a, GC.CompilerState KernelState)
genGPUCode env mode body failures =
  GC.runCompilerM
    (inKernelOperations env mode body)
    blankNameSource
    (newKernelState failures)
-- | Compile a function to device code and register it in the
-- translation state under its name.
--
-- Functions with memory parameters are rejected as a compiler
-- limitation.  A function that may fail (per 'functionMayFail') is
-- compiled with two extra parameters for reporting failures; any other
-- function is compiled as a void function.  Afterwards, device code is
-- also ensured for every function called from the body.
generateDeviceFun :: Name -> ImpGPU.Function ImpGPU.KernelOp -> OnKernelM ()
generateDeviceFun fname device_func = do
  when (any memParam $ functionInput device_func) bad

  env <- ask
  failures <- gets clFailures

  let body = functionBody device_func
      -- Both variants produce the same result type and differ only in
      -- whether the failure-reporting parameters are added.
      compile
        | functionMayFail fname env =
            GC.compileFun mempty failure_params (fname, device_func)
        | otherwise =
            GC.compileVoidFun mempty (fname, device_func)
      (func, cstate) =
        genGPUCode env FunMode (declsFirst body) failures compile
      kstate = GC.compUserState cstate

  modify $ \s ->
    s
      { clUsedTypes = typesInCode body <> clUsedTypes s,
        clDevFuns = M.insert fname (second funcText func) $ clDevFuns s,
        clFailures = kernelFailures kstate
      }

  -- This must come after the 'modify' above: the recursive generation
  -- reads 'clFailures' from the state, and 'ensureDeviceFun' consults
  -- 'clDevFuns' to decide whether a function has already been generated.
  void $ ensureDeviceFuns body
  where
    failure_params =
      [ [C.cparam|__global int *global_failure|],
        [C.cparam|__global typename int64_t *global_failure_args|]
      ]

    memParam MemParam {} = True
    memParam ScalarParam {} = False

    bad = compilerLimitationS "Cannot generate GPU functions that use arrays."
-- | Generate device code for the named function unless an entry for it
-- already exists in 'clDevFuns'.
ensureDeviceFun :: Name -> ImpGPU.Function ImpGPU.KernelOp -> OnKernelM ()
ensureDeviceFun fname host_func = do
  exists <- gets $ M.member fname . clDevFuns
  unless exists $ generateDeviceFun fname host_func
-- | The functions called from a host operation: those called from the
-- body of a kernel launch, and none for any other operation.
calledInHostOp :: HostOp -> S.Set Name
calledInHostOp (CallKernel k) = calledFuncs calledInKernelOp $ kernelBody k
calledInHostOp _ = mempty
-- | Kernel-level operations are treated as calling no functions.
calledInKernelOp :: KernelOp -> S.Set Name
calledInKernelOp = const mempty
-- | Ensure that device code exists for every function called from the
-- given kernel code.  Returns the names of those called functions that
-- were found in the environment; names with no definition there are
-- skipped.
ensureDeviceFuns :: ImpGPU.KernelCode -> OnKernelM [Name]
ensureDeviceFuns code = do
  let called = calledFuncs calledInKernelOp code
  fmap catMaybes . forM (S.toList called) $ \fname -> do
    def <- asks $ lookupFunction fname
    case def of
      Just host_func -> do
        -- The environment stores functions over host operations, so the
        -- function is converted to one over kernel operations first.
        let device_func = fmap toDevice host_func
        ensureDeviceFun fname device_func
        pure $ Just fname
      Nothing -> pure Nothing
  where
    bad = compilerLimitationS "Cannot generate GPU functions that contain parallelism."

    -- Any host operation in the function is mapped to the limitation
    -- error above; 'fmap' is lazy, so it is raised only if the converted
    -- operation is actually demanded.
    toDevice :: HostOp -> KernelOp
    toDevice _ = bad
-- | If a group dimension is statically known, return the text used for
-- it in the generated code: the number itself for an integer literal,
-- the z-encoded name for a 'SizeConst', and @max_@ followed by the
-- pretty-printed size class for a 'SizeMaxConst'.  Any other dimension
-- gives 'Nothing'.
isConst :: GroupDim -> Maybe T.Text
isConst (Left (ValueExp (IntValue x))) =
  Just $ prettyText $ intToInt64 x
isConst (Right (SizeConst v _)) =
  Just $ zEncodeText $ nameToText v
isConst (Right (SizeMaxConst size_class)) =
  Just $ "max_" <> prettyText size_class
isConst _ = Nothing
onKernel :: KernelTarget -> Kernel -> OnKernelM OpenCL
onKernel :: KernelTarget -> Kernel -> OnKernelM OpenCL
onKernel KernelTarget
target Kernel
kernel = do
[Name]
called <- Code KernelOp -> OnKernelM [Name]
ensureDeviceFuns forall a b. (a -> b) -> a -> b
$ Kernel -> Code KernelOp
kernelBody Kernel
kernel
[FailureMsg]
failures <- forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ToOpenCL -> [FailureMsg]
clFailures
Env
env <- forall r (m :: * -> *). MonadReader r m => m r
ask
let ([BlockItem]
kernel_body, CompilerState KernelState
cstate) =
forall a.
Env
-> OpsMode
-> Code KernelOp
-> [FailureMsg]
-> CompilerM KernelOp KernelState a
-> (a, CompilerState KernelState)
genGPUCode Env
env OpsMode
KernelMode (Kernel -> Code KernelOp
kernelBody Kernel
kernel) [FailureMsg]
failures forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
GC.collect forall a b. (a -> b) -> a -> b
$ do
[BlockItem]
body <- forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
GC.collect forall a b. (a -> b) -> a -> b
$ forall op s. Code op -> CompilerM op s ()
GC.compileCode forall a b. (a -> b) -> a -> b
$ forall a. Code a -> Code a
declsFirst forall a b. (a -> b) -> a -> b
$ Kernel -> Code KernelOp
kernelBody Kernel
kernel
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall op s. BlockItem -> CompilerM op s ()
GC.item forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall op s. CompilerM op s [BlockItem]
GC.declAllocatedMem
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall op s. BlockItem -> CompilerM op s ()
GC.item [BlockItem]
body
kstate :: KernelState
kstate = forall s. CompilerState s -> s
GC.compUserState CompilerState KernelState
cstate
([(Name, KernelConstExp)]
kernel_consts, ([BlockItem]
const_defs, [BlockItem]
const_undefs)) =
forall (p :: * -> * -> *) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (Name
-> KernelUse
-> Maybe ((Name, KernelConstExp), (BlockItem, BlockItem))
constDef (Kernel -> Name
kernelName Kernel
kernel)) forall a b. (a -> b) -> a -> b
$ Kernel -> [KernelUse]
kernelUses Kernel
kernel
let (Count Bytes (TPrimExp Int64 VName)
local_memory_bytes, ([Param]
local_memory_params, [KernelArg]
local_memory_args, [BlockItem]
local_memory_init)) =
forall (p :: * -> * -> *) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 forall a b. (a -> b) -> a -> b
$
forall s a. State s a -> s -> a
evalState
(forall (m :: * -> *) (t :: * -> *) acc x y.
(Monad m, Traversable t) =>
(acc -> x -> m (acc, y)) -> acc -> t x -> m (acc, t y)
mapAccumLM forall {k} {k} {k} {m :: * -> *} {t :: k} {u :: k} {u :: k}.
(MonadFreshNames m, IntExp t) =>
Count u (TPrimExp t VName)
-> (VName, Count u (TPrimExp t VName))
-> m (Count Bytes (TPrimExp t VName),
(Param, KernelArg, BlockItem))
prepareLocalMemory Count Bytes (TPrimExp Int64 VName)
0 (KernelState -> [LocalMemoryUse]
kernelLocalMemory KernelState
kstate))
VNameSource
blankNameSource
let ([Param]
use_params, [[BlockItem]]
unpack_params) =
forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe KernelUse -> Maybe (Param, [BlockItem])
useAsParam forall a b. (a -> b) -> a -> b
$ Kernel -> [KernelUse]
kernelUses Kernel
kernel
let (KernelSafety
safety, [BlockItem]
error_init)
| Bool -> Bool
not forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Name]
called =
( KernelSafety
SafetyFull,
[C.citems|volatile __local int local_failure;
// Harmless for all threads to write this.
local_failure = 0;|]
)
| forall (t :: * -> *) a. Foldable t => t a -> Int
length (KernelState -> [FailureMsg]
kernelFailures KernelState
kstate) forall a. Eq a => a -> a -> Bool
== forall (t :: * -> *) a. Foldable t => t a -> Int
length [FailureMsg]
failures =
if Kernel -> Bool
kernelFailureTolerant Kernel
kernel
then (KernelSafety
SafetyNone, [])
else
( KernelSafety
SafetyCheap,
[C.citems|if (*global_failure >= 0) { return; }|]
)
| Bool
otherwise =
if Bool -> Bool
not (KernelState -> Bool
kernelHasBarriers KernelState
kstate)
then
( KernelSafety
SafetyFull,
[C.citems|if (*global_failure >= 0) { return; }|]
)
else
( KernelSafety
SafetyFull,
[C.citems|
volatile __local int local_failure;
if (failure_is_an_option) {
int failed = *global_failure >= 0;
if (failed) {
return;
}
}
// All threads write this value - it looks like CUDA has a compiler bug otherwise.
local_failure = 0;
barrier(CLK_LOCAL_MEM_FENCE);
|]
)
failure_params :: [Param]
failure_params =
[ [C.cparam|__global int *global_failure|],
[C.cparam|int failure_is_an_option|],
[C.cparam|__global typename int64_t *global_failure_args|]
]
([Param]
local_memory_param, [BlockItem]
prepare_local_memory) =
case KernelTarget
target of
KernelTarget
TargetOpenCL ->
( [[C.cparam|__local typename uint64_t* local_mem_aligned|]],
[C.citems|__local unsigned char* local_mem = local_mem_aligned;|]
)
KernelTarget
TargetCUDA -> (forall a. Monoid a => a
mempty, forall a. Monoid a => a
mempty)
KernelTarget
TargetHIP -> (forall a. Monoid a => a
mempty, forall a. Monoid a => a
mempty)
params :: [Param]
params =
[Param]
local_memory_param
forall a. [a] -> [a] -> [a]
++ forall a. Int -> [a] -> [a]
take (KernelSafety -> Int
numFailureParams KernelSafety
safety) [Param]
failure_params
forall a. [a] -> [a] -> [a]
++ [Param]
local_memory_params
forall a. [a] -> [a] -> [a]
++ [Param]
use_params
attribute :: Text
attribute =
case forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM GroupDim -> Maybe Text
isConst forall a b. (a -> b) -> a -> b
$ Kernel -> [GroupDim]
kernelGroupSize Kernel
kernel of
Just [Text
x, Text
y, Text
z] ->
Text
"FUTHARK_KERNEL_SIZED" forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText (Text
x, Text
y, Text
z) forall a. Semigroup a => a -> a -> a
<> Text
"\n"
Just [Text
x, Text
y] ->
Text
"FUTHARK_KERNEL_SIZED" forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText (Text
x, Text
y, Int
1 :: Int) forall a. Semigroup a => a -> a -> a
<> Text
"\n"
Just [Text
x] ->
Text
"FUTHARK_KERNEL_SIZED" forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText (Text
x, Int
1 :: Int, Int
1 :: Int) forall a. Semigroup a => a -> a -> a
<> Text
"\n"
Maybe [Text]
_ -> Text
"FUTHARK_KERNEL\n"
kernel_fun :: Text
kernel_fun =
Text
attribute
forall a. Semigroup a => a -> a -> a
<> Func -> Text
funcText
[C.cfun|void $id:name ($params:params) {
$items:(mconcat unpack_params)
$items:const_defs
$items:prepare_local_memory
$items:local_memory_init
$items:error_init
$items:kernel_body
$id:(errorLabel kstate): return;
$items:const_undefs
}|]
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify forall a b. (a -> b) -> a -> b
$ \ToOpenCL
s ->
ToOpenCL
s
{ clGPU :: Map Name (KernelSafety, Text)
clGPU = forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert Name
name (KernelSafety
safety, Text
kernel_fun) forall a b. (a -> b) -> a -> b
$ ToOpenCL -> Map Name (KernelSafety, Text)
clGPU ToOpenCL
s,
clUsedTypes :: Set PrimType
clUsedTypes = Kernel -> Set PrimType
typesInKernel Kernel
kernel forall a. Semigroup a => a -> a -> a
<> ToOpenCL -> Set PrimType
clUsedTypes ToOpenCL
s,
clFailures :: [FailureMsg]
clFailures = KernelState -> [FailureMsg]
kernelFailures KernelState
kstate,
clConstants :: [(Name, KernelConstExp)]
clConstants = [(Name, KernelConstExp)]
kernel_consts forall a. Semigroup a => a -> a -> a
<> ToOpenCL -> [(Name, KernelConstExp)]
clConstants ToOpenCL
s
}
let args :: [KernelArg]
args = [KernelArg]
local_memory_args forall a. [a] -> [a] -> [a]
++ Kernel -> [KernelArg]
kernelArgs Kernel
kernel
forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ KernelSafety
-> Name
-> Count Bytes (TPrimExp Int64 VName)
-> [KernelArg]
-> [Exp]
-> [GroupDim]
-> OpenCL
LaunchKernel KernelSafety
safety Name
name Count Bytes (TPrimExp Int64 VName)
local_memory_bytes [KernelArg]
args [Exp]
num_groups [GroupDim]
group_size
where
name :: Name
name = Kernel -> Name
kernelName Kernel
kernel
num_groups :: [Exp]
num_groups = Kernel -> [Exp]
kernelNumGroups Kernel
kernel
group_size :: [GroupDim]
group_size = Kernel -> [GroupDim]
kernelGroupSize Kernel
kernel
-- Round a size up to the next multiple of 8 (a size that is already a
-- multiple of 8 is returned unchanged).
padTo8 e = e + padding
  where
    -- Distance to the next multiple of 8; this is 8 (not 0) when 'e' is
    -- already aligned, hence the second 'rem'.
    shortfall = 8 - (e `rem` 8)
    padding = shortfall `rem` 8
-- Accumulator step (used with 'mapAccumLM') that assigns one local
-- memory block its place inside the shared @local_mem@ buffer.
--
-- Given the running byte offset and a @(memory name, size)@ pair, it
-- returns the offset advanced by the size padded with 'padTo8',
-- together with:
--
--   * a fresh @int64_t@ kernel parameter named after the block with an
--     @_offset@ suffix,
--
--   * the kernel argument carrying the offset at which this block
--     starts (the offset /before/ advancing), and
--
--   * the C item that declares the block as a pointer into
--     @local_mem@ at that parameter.
prepareLocalMemory (Count offset) (mem, Count size) = do
  param <- newVName (baseString mem ++ "_offset")
  let next_offset = offset + padTo8 size
      offset_param = [C.cparam|typename int64_t $id:param|]
      offset_arg = ValueKArg (untyped offset) (IntType Int64)
      declare_mem =
        [C.citem|volatile __local $ty:defaultMemBlockType $id:mem = &local_mem[$id:param];|]
  pure (bytes next_offset, (offset_param, offset_arg, declare_mem))
-- | Turn a kernel use into a C kernel parameter, paired with the
-- statements that must run at the top of the kernel to unpack it.
--
-- Memory uses become @__global@ pointer parameters.  Scalar uses are
-- passed in their storage type ('Bool' and 'Unit' as @unsigned char@);
-- when that differs from the type used for computation, the parameter
-- is named with a @_bits@ suffix and an unpacking declaration converts
-- it with 'fromStorage'.  Constant uses are not parameters at all and
-- give 'Nothing'.
useAsParam :: KernelUse -> Maybe (C.Param, [C.BlockItem])
useAsParam ConstUse {} = Nothing
useAsParam (MemoryUse name) =
  Just ([C.cparam|__global $ty:defaultMemBlockType $id:name|], [])
useAsParam (ScalarUse name pt)
  | storage_t == value_t =
      -- Storage and computation types agree: pass the scalar directly.
      Just ([C.cparam|$ty:storage_t $id:name|], [])
  | otherwise =
      Just
        ( [C.cparam|$ty:storage_t $id:bits_name|],
          [[C.citem|$ty:value_t $id:name = $exp:(fromStorage pt bits_e);|]]
        )
  where
    value_t = primTypeToCType pt
    storage_t = case pt of
      Bool -> [C.cty|unsigned char|]
      Unit -> [C.cty|unsigned char|]
      _ -> primStorageType pt
    bits_name = zEncodeText (prettyText name) <> "_bits"
    bits_e = [C.cexp|$id:bits_name|]
-- | For a constant use, produce the kernel-qualified constant (its
-- z-encoded @kernel.variable@ name and defining expression) along with
-- a pair of preprocessor items: a @#define@ mapping the variable's C
-- identifier to the qualified name, and the matching @#undef@.  Any
-- other kind of use gives 'Nothing'.
constDef :: Name -> KernelUse -> Maybe ((Name, KernelConstExp), (C.BlockItem, C.BlockItem))
constDef kernel_name use =
  case use of
    ConstUse v e ->
      let ident = idText (C.toIdent v mempty)
          qualified = zEncodeText (nameToText kernel_name <> "." <> prettyText v)
          define_line = "#define " <> ident <> " (" <> qualified <> ")"
          undef_line = "#undef " <> ident
       in Just
            ( (nameFromText qualified, e),
              ( [C.citem|$escstm:(T.unpack define_line)|],
                [C.citem|$escstm:(T.unpack undef_line)|]
              )
            )
    _ -> Nothing
-- | The prelude fragments shared by every GPU target, concatenated in
-- the order in which they must appear.
commonPrelude :: T.Text
commonPrelude =
  mconcat
    [ halfH,
      cScalarDefs,
      atomicsH,
      transposeCL,
      copyCL
    ]
-- | Build the OpenCL program prelude: the @FUTHARK_OPENCL@ marker, an
-- optional @FUTHARK_F64_ENABLED@ definition (emitted only when the set
-- of used types contains 'Float64'), the OpenCL-specific prelude, and
-- finally the prelude shared by all targets.
genOpenClPrelude :: S.Set PrimType -> T.Text
genOpenClPrelude ts =
  "#define FUTHARK_OPENCL\n"
    <> enable_f64
    <> preludeCL
    <> commonPrelude
  where
    enable_f64
      | FloatType Float64 `S.member` ts =
          -- A single-line quasi-quote carries no trailing newline, so
          -- terminate the directive explicitly; otherwise it would be
          -- glued onto whatever 'preludeCL' happens to start with.
          [untrimming|#define FUTHARK_F64_ENABLED|] <> "\n"
      | otherwise = mempty
-- | Build the CUDA program prelude: the @FUTHARK_CUDA@ marker followed
-- by the CUDA prelude and the prelude shared by all targets.
genCUDAPrelude :: T.Text
genCUDAPrelude =
  mconcat
    [ "#define FUTHARK_CUDA\n",
      preludeCU,
      commonPrelude
    ]
-- | Build the HIP program prelude: the @FUTHARK_HIP@ marker followed by
-- the same prelude text that the CUDA target uses ('preludeCU') and the
-- prelude shared by all targets.
genHIPPrelude :: T.Text
genHIPPrelude =
  mconcat
    [ "#define FUTHARK_HIP\n",
      preludeCU,
      commonPrelude
    ]
-- | The launch-time arguments corresponding to a kernel's uses, in the
-- order of 'kernelUses'.  Memory uses become 'MemKArg's and scalar uses
-- become 'ValueKArg's reading the variable; constant uses contribute no
-- argument.
kernelArgs :: Kernel -> [KernelArg]
kernelArgs kernel = concatMap argsOf (kernelUses kernel)
  where
    argsOf :: KernelUse -> [KernelArg]
    argsOf (MemoryUse mem) = [MemKArg mem]
    argsOf (ScalarUse v pt) = [ValueKArg (LeafExp v pt) pt]
    argsOf ConstUse {} = []
-- | The error label derived (via 'errorLabel') from the current kernel
-- state.
nextErrorLabel :: GC.CompilerM KernelOp KernelState String
nextErrorLabel = fmap errorLabel GC.getUserState
-- | Advance the kernel state's synchronisation counter
-- ('kernelNextSync') by one.
incErrorLabel :: GC.CompilerM KernelOp KernelState ()
incErrorLabel = GC.modifyUserState bump
  where
    bump st = st {kernelNextSync = kernelNextSync st + 1}
-- | Set the 'kernelSyncPending' flag of the kernel state to the given
-- value.
pendingError :: Bool -> GC.CompilerM KernelOp KernelState ()
pendingError flag = GC.modifyUserState setPending
  where
    setPending st = st {kernelSyncPending = flag}
-- | Does the kernel code contain any operation that synchronises the
-- threads, i.e. an 'ErrorSync' or a 'Barrier'?
hasCommunication :: ImpGPU.KernelCode -> Bool
hasCommunication code = any isSync code
  where
    isSync :: KernelOp -> Bool
    isSync op = case op of
      ErrorSync {} -> True
      Barrier {} -> True
      _ -> False
-- | Selects how in-kernel operations are compiled.  Within this module
-- the mode determines how a failure without thread communication bails
-- out: in 'FunMode' the generated code returns @1@, otherwise it
-- returns without a value.
data OpsMode
  = KernelMode
  | FunMode
  deriving (Eq)
inKernelOperations ::
Env ->
OpsMode ->
ImpGPU.KernelCode ->
GC.Operations KernelOp KernelState
inKernelOperations :: Env -> OpsMode -> Code KernelOp -> Operations KernelOp KernelState
inKernelOperations Env
env OpsMode
mode Code KernelOp
body =
GC.Operations
{ opsCompiler :: OpCompiler KernelOp KernelState
GC.opsCompiler = OpCompiler KernelOp KernelState
kernelOps,
opsMemoryType :: MemoryType KernelOp KernelState
GC.opsMemoryType = forall {f :: * -> *}. Applicative f => [Char] -> f Type
kernelMemoryType,
opsWriteScalar :: WriteScalar KernelOp KernelState
GC.opsWriteScalar = forall {op} {s}. WriteScalar op s
kernelWriteScalar,
opsReadScalar :: ReadScalar KernelOp KernelState
GC.opsReadScalar = forall {op} {s}. ReadScalar op s
kernelReadScalar,
opsAllocate :: Allocate KernelOp KernelState
GC.opsAllocate = Allocate KernelOp KernelState
cannotAllocate,
opsDeallocate :: Allocate KernelOp KernelState
GC.opsDeallocate = Allocate KernelOp KernelState
cannotDeallocate,
opsCopy :: Copy KernelOp KernelState
GC.opsCopy = Copy KernelOp KernelState
copyInKernel,
opsCopies :: Map (Space, Space) (DoLMADCopy KernelOp KernelState)
GC.opsCopies = forall a. Monoid a => a
mempty,
opsFatMemory :: Bool
GC.opsFatMemory = Bool
False,
opsError :: ErrorCompiler KernelOp KernelState
GC.opsError = ErrorCompiler KernelOp KernelState
errorInKernel,
opsCall :: CallCompiler KernelOp KernelState
GC.opsCall = forall {a}.
ToIdent a =>
[a] -> Name -> [Exp] -> CompilerM KernelOp KernelState ()
callInKernel,
opsCritical :: ([BlockItem], [BlockItem])
GC.opsCritical = forall a. Monoid a => a
mempty
}
where
has_communication :: Bool
has_communication = Code KernelOp -> Bool
hasCommunication Code KernelOp
body
-- The barrier flag expression for a fence kind.  A global fence also
-- includes the local-memory flag.
fence :: Fence -> C.Exp
fence kind = case kind of
  FenceLocal -> [C.cexp|CLK_LOCAL_MEM_FENCE|]
  FenceGlobal -> [C.cexp|CLK_GLOBAL_MEM_FENCE | CLK_LOCAL_MEM_FENCE|]
-- Compile a single kernel operation to C statements, updating the
-- kernel state where the operation requires it (barrier tracking,
-- local memory bookkeeping, error synchronisation labels).
kernelOps :: GC.OpCompiler KernelOp KernelState
kernelOps op = case op of
  GetGroupId v i ->
    GC.stm [C.cstm|$id:v = get_group_id($int:i);|]
  GetLocalId v i ->
    GC.stm [C.cstm|$id:v = get_local_id($int:i);|]
  GetLocalSize v i ->
    GC.stm [C.cstm|$id:v = get_local_size($int:i);|]
  GetLockstepWidth v ->
    GC.stm [C.cstm|$id:v = LOCKSTEP_WIDTH;|]
  Barrier f -> do
    GC.stm [C.cstm|barrier($exp:(fence f));|]
    noteBarrier
  MemFence FenceLocal ->
    GC.stm [C.cstm|mem_fence_local();|]
  MemFence FenceGlobal ->
    GC.stm [C.cstm|mem_fence_global();|]
  LocalAlloc name size -> do
    -- Record a fresh backing block for later allocation, and point the
    -- requested name at it.
    backing <- newVName (prettyString name ++ "_backing")
    GC.modifyUserState $ \st ->
      st {kernelLocalMemory = (backing, size) : kernelLocalMemory st}
    GC.stm [C.cstm|$id:name = (__local unsigned char*) $id:backing;|]
  ErrorSync f -> do
    label <- nextErrorLabel
    pending <- kernelSyncPending <$> GC.getUserState
    -- Only emit the labelled barrier and failure check when an earlier
    -- failure site has registered a pending jump to this label.
    when pending $ do
      pendingError False
      GC.stm [C.cstm|$id:label: barrier($exp:(fence f));|]
      GC.stm [C.cstm|if (local_failure) { return; }|]
    GC.stm [C.cstm|barrier($exp:(fence f));|]
    noteBarrier
    incErrorLabel
  Atomic space aop ->
    atomicOps space aop
  where
    -- Remember that this kernel contains at least one barrier.
    noteBarrier = GC.modifyUserState $ \st -> st {kernelHasBarriers = True}
-- The element type used when casting an array pointer for an atomic
-- operation: the given type qualified as @volatile@ plus the pointer
-- qualifiers of the memory space (those of "global" for any space that
-- is not a named 'Space').
atomicCast space t =
  pure [C.cty|$tyquals:(volatile_quals ++ space_quals) $ty:t|]
  where
    volatile_quals = [C.ctyquals|volatile|]
    space_quals = pointerQuals (atomicSpace space)
-- The name of the memory space as used in atomic function names: the
-- identifier of a named 'Space', and "global" for anything else.
atomicSpace :: Space -> String
atomicSpace space = case space of
  Space sid -> sid
  _ -> "global"
-- Emit a call to a two-operand atomic function and store the result in
-- @old@.  The function name is @<op>_<type>_<space>@; the array is cast
-- to a pointer to the 'atomicCast' type before indexing, and the
-- operand is cast to @ty@.
doAtomic space t old arr ind val op ty = do
  let fname = op ++ "_" ++ prettyString t ++ "_" ++ atomicSpace space
  ind' <- GC.compileExp (untyped (unCount ind))
  val' <- GC.compileExp val
  cast <- atomicCast space ty
  GC.stm [C.cstm|$id:old = $id:fname(&(($ty:cast *)$id:arr)[$exp:ind'], ($ty:ty) $exp:val');|]
-- Emit a call to @atomic_cmpxchg_<type>_<space>@ with the compiled
-- comparison and replacement values, storing the result in @old@.  The
-- array is cast to a pointer to the 'atomicCast' type before indexing.
doAtomicCmpXchg space t old arr ind cmp val ty = do
  let fname = "atomic_cmpxchg_" ++ prettyString t ++ "_" ++ atomicSpace space
  ind' <- GC.compileExp (untyped (unCount ind))
  cmp' <- GC.compileExp cmp
  val' <- GC.compileExp val
  cast <- atomicCast space ty
  GC.stm [C.cstm|$id:old = $id:fname(&(($ty:cast *)$id:arr)[$exp:ind'], $exp:cmp', $exp:val');|]
-- Emit a call to @atomic_chg_<type>_<space>@ with the compiled
-- replacement value, storing the result in @old@.  The array is cast to
-- a pointer to the 'atomicCast' type before indexing.
doAtomicXchg space t old arr ind val ty = do
  let fname = "atomic_chg_" ++ prettyString t ++ "_" ++ atomicSpace space
  cast <- atomicCast space ty
  ind' <- GC.compileExp (untyped (unCount ind))
  val' <- GC.compileExp val
  GC.stm [C.cstm|$id:old = $id:fname(&(($ty:cast *)$id:arr)[$exp:ind'], $exp:val');|]
-- Compile an atomic operation in the given memory space by dispatching
-- to 'doAtomic', 'doAtomicCmpXchg' or 'doAtomicXchg' with the atomic
-- function's base name and the C type used for the casts.
--
-- The equations for 'Int64' / 'Float64' come first and use 64-bit C
-- types; every other element type falls through to the later equations,
-- which use @int@, @unsigned int@ or @float@.  Equation order therefore
-- matters.
--
-- The unsigned 64-bit cases use the @uint64_t@ typedef: writing
-- @unsigned int64_t@ would combine a sign specifier with a typedef
-- name, which is not a valid C type.
atomicOps s (AtomicAdd Int64 old arr ind val) =
  doAtomic s Int64 old arr ind val "atomic_add" [C.cty|typename int64_t|]
atomicOps s (AtomicFAdd Float64 old arr ind val) =
  doAtomic s Float64 old arr ind val "atomic_fadd" [C.cty|double|]
atomicOps s (AtomicSMax Int64 old arr ind val) =
  doAtomic s Int64 old arr ind val "atomic_smax" [C.cty|typename int64_t|]
atomicOps s (AtomicSMin Int64 old arr ind val) =
  doAtomic s Int64 old arr ind val "atomic_smin" [C.cty|typename int64_t|]
atomicOps s (AtomicUMax Int64 old arr ind val) =
  doAtomic s Int64 old arr ind val "atomic_umax" [C.cty|typename uint64_t|]
atomicOps s (AtomicUMin Int64 old arr ind val) =
  doAtomic s Int64 old arr ind val "atomic_umin" [C.cty|typename uint64_t|]
atomicOps s (AtomicAnd Int64 old arr ind val) =
  doAtomic s Int64 old arr ind val "atomic_and" [C.cty|typename int64_t|]
atomicOps s (AtomicOr Int64 old arr ind val) =
  doAtomic s Int64 old arr ind val "atomic_or" [C.cty|typename int64_t|]
atomicOps s (AtomicXor Int64 old arr ind val) =
  doAtomic s Int64 old arr ind val "atomic_xor" [C.cty|typename int64_t|]
atomicOps s (AtomicCmpXchg (IntType Int64) old arr ind cmp val) =
  doAtomicCmpXchg s (IntType Int64) old arr ind cmp val [C.cty|typename int64_t|]
atomicOps s (AtomicXchg (IntType Int64) old arr ind val) =
  doAtomicXchg s (IntType Int64) old arr ind val [C.cty|typename int64_t|]
-- Fallbacks for all remaining element types.
atomicOps s (AtomicAdd t old arr ind val) =
  doAtomic s t old arr ind val "atomic_add" [C.cty|int|]
atomicOps s (AtomicFAdd t old arr ind val) =
  doAtomic s t old arr ind val "atomic_fadd" [C.cty|float|]
atomicOps s (AtomicSMax t old arr ind val) =
  doAtomic s t old arr ind val "atomic_smax" [C.cty|int|]
atomicOps s (AtomicSMin t old arr ind val) =
  doAtomic s t old arr ind val "atomic_smin" [C.cty|int|]
atomicOps s (AtomicUMax t old arr ind val) =
  doAtomic s t old arr ind val "atomic_umax" [C.cty|unsigned int|]
atomicOps s (AtomicUMin t old arr ind val) =
  doAtomic s t old arr ind val "atomic_umin" [C.cty|unsigned int|]
atomicOps s (AtomicAnd t old arr ind val) =
  doAtomic s t old arr ind val "atomic_and" [C.cty|int|]
atomicOps s (AtomicOr t old arr ind val) =
  doAtomic s t old arr ind val "atomic_or" [C.cty|int|]
atomicOps s (AtomicXor t old arr ind val) =
  doAtomic s t old arr ind val "atomic_xor" [C.cty|int|]
atomicOps s (AtomicCmpXchg t old arr ind cmp val) =
  doAtomicCmpXchg s t old arr ind cmp val [C.cty|int|]
atomicOps s (AtomicXchg t old arr ind val) =
  doAtomicXchg s t old arr ind val [C.cty|int|]
-- Allocation handler for in-kernel code: always an error, whatever the
-- arguments.
cannotAllocate :: GC.Allocate KernelOp KernelState
cannotAllocate = \_ -> error "Cannot allocate memory in kernel"
-- Deallocation handler for in-kernel code: always an error, whatever
-- the arguments.
cannotDeallocate :: GC.Deallocate KernelOp KernelState
cannotDeallocate = \_ _ -> error "Cannot deallocate memory in kernel"
-- Bulk-copy handler for in-kernel code: always an error, whatever the
-- arguments.
copyInKernel :: GC.Copy KernelOp KernelState
copyInKernel = \_ _ _ _ _ _ _ _ -> error "Cannot bulk copy in kernel."
-- The C type of a memory block living in the named space: the default
-- memory block type with that space's pointer qualifiers.
kernelMemoryType space =
  let quals = pointerQuals space
   in pure [C.cty|$tyquals:quals $ty:defaultMemBlockType|]
-- Scalar writes inside kernels: the generic pointer-based writer,
-- parameterised with this module's memory-space pointer qualifiers.
kernelWriteScalar = GC.writeScalarPointerWithQuals pointerQuals
-- Scalar reads inside kernels: the generic pointer-based reader,
-- parameterised with this module's memory-space pointer qualifiers.
kernelReadScalar = GC.readScalarPointerWithQuals pointerQuals
-- Produce the C items to run right after a failure has been detected,
-- and mark an error synchronisation as pending.
--
-- If the body contains thread communication, the failure is recorded in
-- @local_failure@ and control jumps to the current error label.
-- Otherwise the code simply returns: with value @1@ in 'FunMode', and
-- without a value in any other mode.
whatNext :: GC.CompilerM KernelOp KernelState [C.BlockItem]
whatNext = do
  label <- nextErrorLabel
  pendingError True
  pure (bailOut label)
  where
    bailOut label
      | has_communication = [C.citems|local_failure = 1; goto $id:label;|]
      | mode == FunMode = [C.citems|return 1;|]
      | otherwise = [C.citems|return;|]
-- Compile a function call made from kernel code.  Destinations are
-- passed by address ahead of the ordinary arguments.
--
-- When 'functionMayFail' says the callee can fail, the failure state
-- (@global_failure@ and @global_failure_args@) is passed first and a
-- non-zero result triggers the items produced by 'whatNext'.  Otherwise
-- the call is emitted as a plain statement.
callInKernel dests fname args =
  if functionMayFail fname env
    then do
      on_failure <- whatNext
      let failing_args =
            [C.cexp|global_failure|]
              : [C.cexp|global_failure_args|]
              : plain_args
      GC.item [C.citem|if ($id:(funName fname)($args:failing_args) != 0) { $items:on_failure; }|]
    else GC.item [C.citem|$id:(funName fname)($args:plain_args);|]
  where
    -- Addresses of the destinations, then the call's own arguments.
    plain_args = map (\d -> [C.cexp|&$id:d|]) dests ++ args
-- | Emit kernel code that reports a failure.
--
-- The message and its backtrace are appended to the 'kernelFailures'
-- of the user state; the index @n@ of the new entry is the number of
-- failures recorded before it.  The generated statement tries to swap
-- @global_failure@ from @-1@ to @n@ with @atomic_cmpxchg_i32_global@,
-- and only when that swap observed @-1@ are the message's value
-- arguments stored into @global_failure_args@.  In either case the
-- items from 'whatNext' follow.
errorInKernel msg@(ErrorMsg parts) backtrace = do
  n <- length . kernelFailures <$> GC.getUserState
  GC.modifyUserState $ \s ->
    s {kernelFailures = kernelFailures s ++ [FailureMsg msg backtrace]}

  -- Only the value parts of the message occupy argument slots; string
  -- parts are skipped and do not advance the slot index.
  let vals = [x | ErrorVal _ x <- parts]
  argstms <- forM (zip [(0 :: Int) ..] vals) $ \(i, x) -> do
    x' <- GC.compileExp x
    -- NOTE(review): every value is cast to int64_t and its declared
    -- primitive type is ignored - presumably lossy for non-integer
    -- values; confirm.
    pure [C.cstm|global_failure_args[$int:i] = (typename int64_t)$exp:x';|]

  what_next <- whatNext

  GC.stm
    [C.cstm|{ if (atomic_cmpxchg_i32_global(global_failure, -1, $int:n) == -1)
                 { $stms:argstms; }
               $items:what_next
             }|]
-- | The primitive types occurring in the body of a kernel, as
-- computed by 'typesInCode'.
typesInKernel :: Kernel -> S.Set PrimType
typesInKernel = typesInCode . kernelBody
-- | Collect the primitive types occurring in a piece of kernel code:
-- the types of declared scalars and arrays, the element types of
-- reads and writes, and every type found by 'typesInExp' in the
-- expressions the code contains.  Memory declarations, frees, memory
-- assignments and kernel operations ('Op') contribute nothing.
typesInCode :: ImpGPU.KernelCode -> S.Set PrimType
typesInCode Skip = mempty
typesInCode (c1 :>>: c2) = typesInCode c1 <> typesInCode c2
typesInCode (For _ e c) = typesInExp e <> typesInCode c
typesInCode (While (TPrimExp e) c) = typesInExp e <> typesInCode c
typesInCode DeclareMem {} = mempty
typesInCode (DeclareScalar _ _ t) = S.singleton t
typesInCode (DeclareArray _ t _) = S.singleton t
typesInCode (Allocate _ (Count (TPrimExp e)) _) = typesInExp e
typesInCode Free {} = mempty
typesInCode
  ( LMADCopy
      _
      shape
      _
      (Count (TPrimExp dstoffset), dststrides)
      _
      (Count (TPrimExp srcoffset), srcstrides)
    ) =
    typesInCounts shape
      <> typesInExp dstoffset
      <> typesInCounts dststrides
      <> typesInExp srcoffset
      <> typesInCounts srcstrides
    where
      -- Types in a list of element counts (shape or strides).
      typesInCounts = foldMap (typesInExp . untyped . unCount)
typesInCode (Write _ (Count (TPrimExp e1)) t _ _ e2) =
  typesInExp e1 <> S.singleton t <> typesInExp e2
typesInCode (Read _ _ (Count (TPrimExp e1)) t _ _) =
  typesInExp e1 <> S.singleton t
typesInCode (SetScalar _ e) = typesInExp e
typesInCode SetMem {} = mempty
typesInCode (Call _ _ es) = foldMap typesInArg es
  where
    typesInArg MemArg {} = mempty
    typesInArg (ExpArg e) = typesInExp e
typesInCode (If (TPrimExp e) c1 c2) =
  typesInExp e <> typesInCode c1 <> typesInCode c2
-- The value expressions embedded in the error message are compiled
-- into the kernel when the assertion fails, so their types count as
-- used too, just as for 'TracePrint' below.
typesInCode (Assert e msg _) = typesInExp e <> foldMap typesInExp msg
typesInCode (Comment _ c) = typesInCode c
typesInCode (DebugPrint _ v) = maybe mempty typesInExp v
typesInCode (TracePrint msg) = foldMap typesInExp msg
typesInCode Op {} = mempty
-- | Collect the primitive types occurring in an expression: the types
-- of constants, both the source and target type of every conversion,
-- and the return type of every function application, gathered
-- recursively over all subexpressions.  Leaves contribute nothing.
typesInExp :: Exp -> S.Set PrimType
typesInExp expr = case expr of
  ValueExp v -> S.singleton (primValueType v)
  BinOpExp _ x y -> typesInExp x <> typesInExp y
  CmpOpExp _ x y -> typesInExp x <> typesInExp y
  ConvOpExp op x ->
    let (from, to) = convOpType op
     in S.fromList [from, to] <> typesInExp x
  UnOpExp _ x -> typesInExp x
  FunExp _ args t -> S.singleton t <> foldMap typesInExp args
  LeafExp {} -> mempty