{-# LANGUAGE CPP #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Array.Accelerate.LLVM.PTX.Link (
module Data.Array.Accelerate.LLVM.Link,
ExecutableR(..), FunctionTable(..), Kernel(..), ObjectCode,
linkFunctionQ,
) where
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.LLVM.Link
import Data.Array.Accelerate.LLVM.State
import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
import Data.Array.Accelerate.LLVM.PTX.Compile
import Data.Array.Accelerate.LLVM.PTX.Context
import Data.Array.Accelerate.LLVM.PTX.Link.Cache
import Data.Array.Accelerate.LLVM.PTX.Link.Object
import Data.Array.Accelerate.LLVM.PTX.Target
import qualified Data.Array.Accelerate.LLVM.PTX.Debug as Debug
import qualified Foreign.CUDA.Analysis as CUDA
import qualified Foreign.CUDA.Driver as CUDA
import Control.Exception ( onException )
import Control.Monad.State
import Data.ByteString.Short.Char8 ( ShortByteString, unpack )
import Foreign.Ptr
import Language.Haskell.TH
import Text.Printf ( printf )
import qualified Data.ByteString.Unsafe as B
import Prelude as P hiding ( lookup )
-- Linking for the PTX backend. The executable form of a compiled object is
-- the table of kernel functions extracted from it, held inside a 'Lifetime'
-- so that finalisers (see 'link') can be attached to it.
instance Link PTX where
  data ExecutableR PTX = PTXR
    { ptxExecutable :: {-# UNPACK #-} !(Lifetime FunctionTable)
    }

  -- All the work happens in the top-level 'link' function.
  linkForTarget object = link object
-- | Load the generated object code into the current CUDA context and
-- extract the kernel functions named in the object's launch configuration.
--
-- The loading action is handed to 'dlsym' together with the object's 'UID'
-- and the target's kernel table.
-- NOTE(review): presumably 'dlsym' only runs the action when the UID is not
-- already cached — confirm in "Data.Array.Accelerate.LLVM.PTX.Link.Cache".
--
-- A finaliser attached to the module's 'Lifetime' unloads the module from
-- the target's CUDA context. If extracting the kernel functions fails
-- before that finaliser exists, the module is unloaded immediately and the
-- exception is re-thrown, so a failed link does not leak the module.
--
link :: ObjectR PTX -> LLVM PTX (ExecutableR PTX)
link (ObjectR uid cfg obj) = do
  target <- gets llvmTarget
  cache  <- gets ptxKernelTable
  funs   <- liftIO . dlsym uid cache $ do
    -- Load the object code into the current CUDA context
    jit <- B.unsafeUseAsCString obj $ \p -> CUDA.loadDataFromPtrEx (castPtr p) []
    let mdl          = CUDA.jitModule jit
        unloadModule = withContext (ptxContext target) (CUDA.unload mdl)

    -- Extract the kernel functions. Until the finaliser below is attached,
    -- nothing else owns 'mdl'. If a lookup or attribute query throws (for
    -- example an entry point missing from the module), unload it here
    -- rather than leaving it resident in the context.
    nm <- (FunctionTable `fmap` mapM (uncurry (linkFunction mdl)) cfg)
            `onException` unloadModule
    oc <- newLifetime mdl

    -- Finalise the module by unloading it from the CUDA context
    addFinalizer oc $ do
      Debug.traceIO Debug.dump_ld ("ld: unload module: " ++ show nm)
      unloadModule

    return (nm, oc)
  --
  return $! PTXR funs
-- | Extract the named kernel from a loaded module and package it as a
-- 'Kernel'. This is 'linkFunctionQ' with the typed Template Haskell
-- expression (its second result) discarded.
linkFunction
    :: CUDA.Module        -- ^ module already loaded into a CUDA context
    -> ShortByteString    -- ^ name of the kernel entry function
    -> LaunchConfig       -- ^ launch configuration for this kernel
    -> IO Kernel
linkFunction mdl name configure =
  fmap fst (linkFunctionQ mdl name configure)
-- | Look up the named kernel in a loaded module, query its resource usage,
-- and apply the supplied 'LaunchConfig' to obtain its launch parameters.
--
-- Returns the packaged 'Kernel' together with the typed Template Haskell
-- expression that the launch configuration yields as the fifth component
-- of its result. Register, memory and occupancy statistics are reported
-- through 'Debug.traceIO' under the 'Debug.dump_cc' flag.
linkFunctionQ
    :: CUDA.Module        -- ^ module already loaded into a CUDA context
    -> ShortByteString    -- ^ name of the kernel entry function
    -> LaunchConfig       -- ^ launch configuration for this kernel
    -> IO (Kernel, Q (TExp (Int -> Int)))
linkFunctionQ mdl name configure = do
  fun <- CUDA.getFun mdl name

  -- Per-function attribute queries, issued in a fixed order
  let attribute = CUDA.requires fun
  numRegs    <- attribute CUDA.NumRegs
  staticSmem <- attribute CUDA.SharedSizeBytes
  constMem   <- attribute CUDA.ConstSizeBytes
  localMem   <- attribute CUDA.LocalSizeBytes
  maxThreads <- attribute CUDA.MaxKernelThreadsPerBlock

  -- The launch configuration is given the thread limit, register count and
  -- static shared memory. 'dynSmem' is the extra shared memory it requests
  -- (added to the static amount in the report below).
  -- NOTE(review): 'cta' presumably is the threads-per-block count — confirm
  -- against "Data.Array.Accelerate.LLVM.PTX.Analysis.Launch".
  let (occ, cta, grid, dynSmem, gridQ) = configure maxThreads numRegs staticSmem

      resourceMsg :: String
      resourceMsg =
        printf "kernel function '%s' used %d registers, %d bytes smem, %d bytes lmem, %d bytes cmem"
          (unpack name) numRegs (staticSmem + dynSmem) localMem constMem

      occupancyMsg :: String
      occupancyMsg =
        printf "multiprocessor occupancy %.1f %% : %d threads over %d warps in %d blocks"
          (CUDA.occupancy100 occ)
          (CUDA.activeThreads occ)
          (CUDA.activeWarps occ)
          (CUDA.activeThreadBlocks occ)

  Debug.traceIO Debug.dump_cc (printf "cc: %s\n %s" resourceMsg occupancyMsg)
  return (Kernel name fun dynSmem cta grid, gridQ)