{-# LANGUAGE CPP             #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies    #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.Link
-- Copyright   : [2017..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.Link (

  module Data.Array.Accelerate.LLVM.Link,
  ExecutableR(..), FunctionTable(..), Kernel(..), ObjectCode,
  linkFunctionQ,

) where

import Data.Array.Accelerate.Lifetime

import Data.Array.Accelerate.LLVM.Link
import Data.Array.Accelerate.LLVM.State

import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
import Data.Array.Accelerate.LLVM.PTX.Compile
import Data.Array.Accelerate.LLVM.PTX.Context
import Data.Array.Accelerate.LLVM.PTX.Link.Cache
import Data.Array.Accelerate.LLVM.PTX.Link.Object
import Data.Array.Accelerate.LLVM.PTX.Target
import qualified Data.Array.Accelerate.LLVM.PTX.Debug               as Debug

-- cuda
import qualified Foreign.CUDA.Analysis                              as CUDA
import qualified Foreign.CUDA.Driver                                as CUDA

-- standard library
import Control.Monad.State
import Data.ByteString.Short.Char8                                  ( ShortByteString, unpack )
import Foreign.Ptr
import Language.Haskell.TH
import Text.Printf                                                  ( printf )
import qualified Data.ByteString.Unsafe                             as B
import Prelude                                                      as P hiding ( lookup )


-- The executable form of a PTX object is a lifetime-managed table of the
-- kernel functions it contains. Linking is delegated to 'link'.
--
instance Link PTX where
  data ExecutableR PTX = PTXR { ptxExecutable :: {-# UNPACK #-} !(Lifetime FunctionTable)
                              }
  linkForTarget = link


-- | Load the generated object code into the current CUDA context.
--
-- The object code is loaded as a CUDA module, each kernel named in the
-- object's launch configuration list is looked up with 'linkFunction', and the
-- resulting 'FunctionTable' is registered in the target's kernel table via
-- 'dlsym' under the object's unique identifier. A finaliser attached to the
-- module's lifetime unloads it from the CUDA context again.
--
link :: ObjectR PTX -> LLVM PTX (ExecutableR PTX)
link (ObjectR uid cfg obj) = do
  target <- gets llvmTarget
  cache  <- gets ptxKernelTable
  -- NOTE(review): 'dlsym' presumably only runs the supplied action when 'uid'
  -- is not already present in the cache; confirm against Link.Cache.
  funs   <- liftIO $ dlsym uid cache $ do
    -- Load the SASS object code into the current CUDA context. The pointer
    -- handed to the callback is only used for the duration of the load.
    jit <- B.unsafeUseAsCString obj $ \p -> CUDA.loadDataFromPtrEx (castPtr p) []
    let mdl = CUDA.jitModule jit

    -- Extract the kernel functions
    nm  <- FunctionTable `fmap` mapM (uncurry (linkFunction mdl)) cfg
    oc  <- newLifetime mdl

    -- Finalise the module by unloading it from the CUDA context
    addFinalizer oc $ do
      Debug.traceIO Debug.dump_ld ("ld: unload module: " ++ show nm)
      withContext (ptxContext target) (CUDA.unload mdl)

    return (nm, oc)
  --
  return $! PTXR funs


-- | Extract the named function from the module and package into a Kernel
-- object, which includes meta-information on resource usage.
--
-- If we are in debug mode, print statistics on kernel resource usage, etc.
--
-- This is 'linkFunctionQ' with the Template Haskell component of the result
-- discarded.
--
linkFunction
    :: CUDA.Module                      -- the compiled module
    -> ShortByteString                  -- __global__ entry function name
    -> LaunchConfig                     -- launch configuration for this global function
    -> IO Kernel
linkFunction mdl name configure =
  fst `fmap` linkFunctionQ mdl name configure

-- | Look up the named function in the module and package it into a 'Kernel',
-- together with the Template Haskell expression returned by the launch
-- configuration.
--
-- The function's register count, static shared memory size and maximum
-- threads per block are queried from the CUDA driver and passed to the launch
-- configuration. Resource usage and occupancy statistics are emitted through
-- 'Debug.traceIO' under the 'Debug.dump_cc' flag.
--
linkFunctionQ
    :: CUDA.Module
    -> ShortByteString
    -> LaunchConfig
    -> IO (Kernel, Q (TExp (Int -> Int)))
linkFunctionQ mdl name configure = do
  f     <- CUDA.getFun mdl name
  regs  <- CUDA.requires f CUDA.NumRegs
  ssmem <- CUDA.requires f CUDA.SharedSizeBytes
  cmem  <- CUDA.requires f CUDA.ConstSizeBytes
  lmem  <- CUDA.requires f CUDA.LocalSizeBytes
  maxt  <- CUDA.requires f CUDA.MaxKernelThreadsPerBlock

  let
      -- NOTE(review): 'cta' and 'dsmem' are presumably the thread block size
      -- and the dynamic shared memory size respectively; confirm against
      -- Analysis.Launch.
      (occ, cta, grid, dsmem, gridQ) = configure maxt regs ssmem

      msg1, msg2 :: String
      msg1 = printf "kernel function '%s' used %d registers, %d bytes smem, %d bytes lmem, %d bytes cmem"
                      (unpack name) regs (ssmem + dsmem) lmem cmem

      msg2 = printf "multiprocessor occupancy %.1f %% : %d threads over %d warps in %d blocks"
                      (CUDA.occupancy100 occ)
                      (CUDA.activeThreads occ)
                      (CUDA.activeWarps occ)
                      (CUDA.activeThreadBlocks occ)

  Debug.traceIO Debug.dump_cc (printf "cc: %s\n               %s" msg1 msg2)
  return (Kernel name f dsmem cta grid, gridQ)


{--
-- | Extract the names of the function definitions from the module.
--
-- Note: [Extracting global function names]
--
-- It is important to run this on the module given to us by code generation.
-- After combining modules with 'libdevice', extra function definitions,
-- corresponding to basic maths operations, will be added to the module. These
-- functions will not be callable as __global__ functions.
--
-- The list of names will be exported in the order that they appear in the
-- module.
--
globalFunctions :: [Definition] -> [String]
globalFunctions defs =
  [ n | GlobalDefinition Function{..} <- defs
      , not (null basicBlocks)
      , let Name n = name
      ]
--}