{-# LANGUAGE CPP #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Array.Accelerate.LLVM.PTX.Embed (
module Data.Array.Accelerate.LLVM.Embed,
) where
import Control.Exception ( bracket )
import Foreign.Ptr
import GHC.Ptr ( Ptr(..) )
import System.IO.Unsafe

import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import qualified Foreign.CUDA.Driver as CUDA
import Language.Haskell.TH ( Q, TExp )
import qualified Language.Haskell.TH as TH
import qualified Language.Haskell.TH.Syntax as TH

import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.LLVM.Compile
import Data.Array.Accelerate.LLVM.Embed
import Data.Array.Accelerate.LLVM.PTX.Compile
import Data.Array.Accelerate.LLVM.PTX.Context
import Data.Array.Accelerate.LLVM.PTX.Link
import Data.Array.Accelerate.LLVM.PTX.Target
import Data.ByteString.Short.Extra as BS
-- | Embedding support for the PTX backend. All of the work is done by
-- 'embed', which bakes the compiled object code into the generated program.
instance Embed PTX where
  embedForTarget = embed
-- | Embed compiled PTX object code into the program being compiled, so that
-- the kernels can be reloaded at execution time without recompiling.
--
-- At compile time (inside 'TH.runIO', under the CUDA context of @target@)
-- the object code is JIT-loaded once. For every @(name, launch config)@ pair
-- in @cfg@, 'linkFunctionQ' then yields a 'Kernel' record together with a
-- typed splice for that kernel's @Int -> Int@ function. The module is
-- unloaded again before returning, including when linking throws.
--
-- The returned typed expression does the following when evaluated at run
-- time:
--
-- * loads the embedded bytes into a new CUDA module;
-- * looks up every kernel by name with 'CUDA.getFun';
-- * wraps the resulting 'FunctionTable' in a 'Lifetime' inside 'PTXR'.
embed :: PTX -> ObjectR PTX -> Q (TExp (ExecutableR PTX))
embed target (ObjectR _ cfg obj) = do
  -- Compile-time load, used only to gather per-kernel metadata. 'bracket'
  -- guarantees the JIT module is unloaded even if 'linkFunctionQ' fails
  -- partway through, so nothing leaks in long-lived TH hosts such as GHCi.
  kmd <- TH.runIO $ withContext (ptxContext target) $
    bracket
      (B.unsafeUseAsCString obj $ \p -> CUDA.loadDataFromPtrEx (castPtr p) [])
      (CUDA.unload . CUDA.jitModule)
      (\jit -> mapM (uncurry (linkFunctionQ (CUDA.jitModule jit))) cfg)

  -- Run-time code. The object bytes are spliced in as a primitive string
  -- literal ('TH.StringPrimL'), wrapped in 'Ptr' and handed to the loader.
  -- NOTE(review): because of 'unsafePerformIO', the load happens when the
  -- executable is first demanded, in whatever CUDA context is current then.
  -- NOTE(review): no finalizer is attached to the 'Lifetime' here, so this
  -- code never unloads the run-time module. Confirm that this is intended.
  [|| unsafePerformIO $ do
        jit <- CUDA.loadDataFromPtrEx $$( TH.unsafeTExpCoerce [| Ptr $(TH.litE (TH.StringPrimL (B.unpack obj))) |] ) []
        fun <- newLifetime (FunctionTable $$(listE (map (linkQ 'jit) kmd)))
        return $ PTXR fun
   ||]
  where
    -- Build the run-time expression for one kernel. It looks the function up
    -- by name in the module bound to @jit@. It rebuilds the 'Kernel' from the
    -- name, the fresh function handle, the compile-time @dsmem@ and @cta@
    -- values, and the spliced-in grid function.
    linkQ :: TH.Name -> (Kernel, Q (TExp (Int -> Int))) -> Q (TExp Kernel)
    linkQ jit (Kernel name _ dsmem cta _, grid) =
      [|| unsafePerformIO $ do
            f <- CUDA.getFun (CUDA.jitModule $$(TH.unsafeTExpCoerce (TH.varE jit))) $$(liftSBS name)
            return $ Kernel $$(liftSBS name) f dsmem cta $$grid
       ||]

    -- Combine a list of typed splices into one typed splice of a list.
    -- The untyped round-trip is safe because every element already has type @a@.
    listE :: [Q (TExp a)] -> Q (TExp [a])
    listE xs = TH.unsafeTExpCoerce (TH.listE (map TH.unTypeQ xs))