{-# LANGUAGE CPP             #-}
{-# LANGUAGE QuasiQuotes     #-}
{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.Embed
-- Copyright   : [2017..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.Embed (

  module Data.Array.Accelerate.LLVM.Embed,

) where

import Data.ByteString.Short.Extra                                  as BS

import Data.Array.Accelerate.Lifetime

import Data.Array.Accelerate.LLVM.Compile
import Data.Array.Accelerate.LLVM.Embed

import Data.Array.Accelerate.LLVM.PTX.Compile
import Data.Array.Accelerate.LLVM.PTX.Link
import Data.Array.Accelerate.LLVM.PTX.Target
import Data.Array.Accelerate.LLVM.PTX.Context

import qualified Foreign.CUDA.Driver                                as CUDA

import Control.Exception                                            ( bracket )
import Foreign.Ptr
import GHC.Ptr                                                      ( Ptr(..) )
import Language.Haskell.TH                                          ( Q, TExp )
import System.IO.Unsafe
import qualified Data.ByteString                                    as B
import qualified Data.ByteString.Unsafe                             as B
import qualified Language.Haskell.TH                                as TH
import qualified Language.Haskell.TH.Syntax                         as TH


-- | Object-code embedding for the PTX backend; all the work is done by 'embed'.
instance Embed PTX where
  embedForTarget = embed

-- | Embed the given object code into the program being compiled, and set it up
-- to be reloaded at execution time.
--
-- This works in two phases:
--
--   1. At compile time (inside 'TH.runIO'), the object code @obj@ is loaded
--      into the CUDA context of @target@. Then 'linkFunctionQ' is run for every
--      @(name, launch config)@ entry of @cfg@ to recover per-kernel
--      information. The temporary module is always unloaded afterwards, even
--      if linking one of the kernels throws.
--
--   2. The returned typed splice contains @obj@ as a primitive string literal.
--      When that expression is evaluated at run time, it loads the code into
--      the /current/ context, looks up each kernel by name, and wraps the
--      resulting 'FunctionTable' in a 'Lifetime' inside 'PTXR'.
--
embed :: PTX -> ObjectR PTX -> Q (TExp (ExecutableR PTX))
embed target (ObjectR _ cfg obj) = do
  -- Load the module to recover information such as number of registers and
  -- bytes of shared memory. It may be possible to do this without requiring an
  -- active CUDA context.
  --
  -- 'bracket' ensures the temporary module is unloaded even if one of the
  -- 'linkFunctionQ' calls fails, so a failing splice does not leak it.
  kmd <- TH.runIO $ withContext (ptxContext target) $
    bracket
      (B.unsafeUseAsCString obj $ \p -> CUDA.loadDataFromPtrEx (castPtr p) [])
      (CUDA.unload . CUDA.jitModule)
      (\jit -> mapM (uncurry (linkFunctionQ (CUDA.jitModule jit))) cfg)

  -- Generate the embedded kernel executable. This will load the embedded object
  -- code into the current (at execution time) context.
  --
  -- 'unsafePerformIO' defers the load until the executable is first demanded.
  -- NOTE(review): the module loaded here is not explicitly unloaded by this
  -- code (no finalizer is attached to the 'Lifetime'); confirm this is intended.
  [|| unsafePerformIO $ do
        jit <- CUDA.loadDataFromPtrEx $$( TH.unsafeTExpCoerce [| Ptr $(TH.litE (TH.StringPrimL (B.unpack obj))) |] ) []
        fun <- newLifetime (FunctionTable $$(listE (map (linkQ 'jit) kmd)))
        return $ PTXR fun
   ||]
  where
    -- Build a typed expression that, at run time, looks up the kernel called
    -- @name@ in the module bound to the variable @jit@. It then rebuilds the
    -- 'Kernel' from that function handle, the compile-time values @dsmem@ and
    -- @cta@, and the spliced-in function @grid@. The compile-time function
    -- handle and the record's original last field are discarded.
    --
    -- NOTE(review): the lookup is lazy ('unsafePerformIO'); presumably it is
    -- forced while the context that loaded @jit@ is current — confirm.
    linkQ :: TH.Name -> (Kernel, Q (TExp (Int -> Int))) -> Q (TExp Kernel)
    linkQ jit (Kernel name _ dsmem cta _, grid) =
      [|| unsafePerformIO $ do
            f <- CUDA.getFun (CUDA.jitModule $$(TH.unsafeTExpCoerce (TH.varE jit))) $$(liftSBS name)
            return $ Kernel $$(liftSBS name) f dsmem cta $$grid
       ||]

    -- Combine a list of typed expressions into one typed list expression.
    listE :: [Q (TExp a)] -> Q (TExp [a])
    listE xs = TH.unsafeTExpCoerce (TH.listE (map TH.unTypeQ xs))