-- |
-- Loading, querying and unloading of modules through the @cuModule*@ foreign
-- calls declared at the bottom of this file.
--
module Foreign.CUDA.Driver.Module (

  -- * Module Management
  -- 'JITFallback' is exported so that the 'FallbackStrategy' constructor of
  -- 'JITOption' can actually be applied by callers.
  Module, JITOption(..), JITTarget(..), JITResult(..), JITFallback(..),

  -- ** Querying module inhabitants
  getFun, getPtr, getTex,

  -- ** Loading and unloading modules
  loadFile,
  loadData,   loadDataFromPtr,
  loadDataEx, loadDataFromPtrEx,
  unload

) where
import Foreign.CUDA.Analysis.Device
import Foreign.CUDA.Ptr
import Foreign.CUDA.Driver.Error
import Foreign.CUDA.Driver.Exec
import Foreign.CUDA.Driver.Marshal (peekDeviceHandle)
import Foreign.CUDA.Driver.Texture
import Foreign.CUDA.Internal.C2HS
import Foreign
import Foreign.C
import Unsafe.Coerce
import Control.Monad (liftM)
import Control.Exception (throwIO)
import Data.Maybe (mapMaybe)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Internal as B
-- |
-- Handle to a loaded module: a thin wrapper around the raw pointer that the
-- @cuModule*@ foreign calls in this file produce and consume.
--
newtype Module = Module
  { useModule :: Ptr ()   -- ^ the wrapped raw handle
  }
  deriving (Eq, Show)
-- |
-- Options accepted by 'loadDataEx' and 'loadDataFromPtrEx'. Each constructor
-- is translated by 'loadDataFromPtrEx' into one internal option code plus an
-- integer value.
--
data JITOption
  = MaxRegisters      !Int          -- ^ sent as @JIT_MAX_REGISTERS@ with the given value
  | ThreadsPerBlock   !Int          -- ^ sent as @JIT_THREADS_PER_BLOCK@ with the given value
  | OptimisationLevel !Int          -- ^ sent as @JIT_OPTIMIZATION_LEVEL@ with the given value
  | Target            !Compute      -- ^ sent as @JIT_TARGET@, after mapping to a 'JITTarget'
  | FallbackStrategy  !JITFallback  -- ^ sent as @JIT_FALLBACK_STRATEGY@
  | GenerateDebugInfo               -- ^ sent as @JIT_GENERATE_DEBUG_INFO@ with value 1
  | GenerateLineInfo                -- ^ sent as @JIT_GENERATE_LINE_INFO@ with value 1
  | Verbose                         -- ^ sent as @JIT_LOG_VERBOSE@ with value 1
  deriving (Show)
-- |
-- What 'loadDataFromPtrEx' hands back after a successful load.
--
data JITResult = JITResult
  { jitTime    :: !Float        -- ^ value read back from the @JIT_WALL_TIME@ option slot
  , jitInfoLog :: !ByteString   -- ^ info-log buffer contents, up to the first NUL byte
  , jitModule  :: !Module       -- ^ handle of the module that was loaded
  }
  deriving (Show)
-- |
-- Option codes placed in the options array given to @cuModuleLoadDataEx@.
-- The 'Enum' instance numbers the constructors consecutively from 0
-- (@JIT_MAX_REGISTERS@) to 15 (@JIT_NUM_OPTIONS@).
--
data JITOptionInternal
  = JIT_MAX_REGISTERS
  | JIT_THREADS_PER_BLOCK
  | JIT_WALL_TIME
  | JIT_INFO_LOG_BUFFER
  | JIT_INFO_LOG_BUFFER_SIZE_BYTES
  | JIT_ERROR_LOG_BUFFER
  | JIT_ERROR_LOG_BUFFER_SIZE_BYTES
  | JIT_OPTIMIZATION_LEVEL
  | JIT_TARGET_FROM_CUCONTEXT
  | JIT_TARGET
  | JIT_FALLBACK_STRATEGY
  | JIT_GENERATE_DEBUG_INFO
  | JIT_LOG_VERBOSE
  | JIT_GENERATE_LINE_INFO
  | JIT_CACHE_MODE
  | JIT_NUM_OPTIONS
  deriving (Eq, Show)

-- Only 'fromEnum' and 'toEnum' are given; every other 'Enum' method keeps its
-- class default.
instance Enum JITOptionInternal where
  fromEnum option = case option of
    JIT_MAX_REGISTERS               -> 0
    JIT_THREADS_PER_BLOCK           -> 1
    JIT_WALL_TIME                   -> 2
    JIT_INFO_LOG_BUFFER             -> 3
    JIT_INFO_LOG_BUFFER_SIZE_BYTES  -> 4
    JIT_ERROR_LOG_BUFFER            -> 5
    JIT_ERROR_LOG_BUFFER_SIZE_BYTES -> 6
    JIT_OPTIMIZATION_LEVEL          -> 7
    JIT_TARGET_FROM_CUCONTEXT       -> 8
    JIT_TARGET                      -> 9
    JIT_FALLBACK_STRATEGY           -> 10
    JIT_GENERATE_DEBUG_INFO         -> 11
    JIT_LOG_VERBOSE                 -> 12
    JIT_GENERATE_LINE_INFO          -> 13
    JIT_CACHE_MODE                  -> 14
    JIT_NUM_OPTIONS                 -> 15

  toEnum code = case code of
    0  -> JIT_MAX_REGISTERS
    1  -> JIT_THREADS_PER_BLOCK
    2  -> JIT_WALL_TIME
    3  -> JIT_INFO_LOG_BUFFER
    4  -> JIT_INFO_LOG_BUFFER_SIZE_BYTES
    5  -> JIT_ERROR_LOG_BUFFER
    6  -> JIT_ERROR_LOG_BUFFER_SIZE_BYTES
    7  -> JIT_OPTIMIZATION_LEVEL
    8  -> JIT_TARGET_FROM_CUCONTEXT
    9  -> JIT_TARGET
    10 -> JIT_FALLBACK_STRATEGY
    11 -> JIT_GENERATE_DEBUG_INFO
    12 -> JIT_LOG_VERBOSE
    13 -> JIT_GENERATE_LINE_INFO
    14 -> JIT_CACHE_MODE
    15 -> JIT_NUM_OPTIONS
    _  -> error ("JITOptionInternal.toEnum: Cannot match " ++ show code)
-- |
-- Targets that can be requested via the @JIT_TARGET@ option. The 'Enum'
-- instance maps each constructor @ComputeXY@ to the integer @XY@ (for example
-- 'Compute35' is 35).
--
data JITTarget
  = Compute10
  | Compute11
  | Compute12
  | Compute13
  | Compute20
  | Compute21
  | Compute30
  | Compute32
  | Compute35
  | Compute50
  deriving (Eq, Show)

-- Only 'fromEnum' and 'toEnum' are given; the numbering is not contiguous, so
-- the remaining class-default methods are of limited use.
instance Enum JITTarget where
  fromEnum target = case target of
    Compute10 -> 10
    Compute11 -> 11
    Compute12 -> 12
    Compute13 -> 13
    Compute20 -> 20
    Compute21 -> 21
    Compute30 -> 30
    Compute32 -> 32
    Compute35 -> 35
    Compute50 -> 50

  toEnum code = case code of
    10 -> Compute10
    11 -> Compute11
    12 -> Compute12
    13 -> Compute13
    20 -> Compute20
    21 -> Compute21
    30 -> Compute30
    32 -> Compute32
    35 -> Compute35
    50 -> Compute50
    _  -> error ("JITTarget.toEnum: Cannot match " ++ show code)
-- |
-- Value carried by the 'FallbackStrategy' option and sent to the driver as
-- the @JIT_FALLBACK_STRATEGY@ value: 'Ptx' is 0 and 'Binary' is 1.
--
data JITFallback
  = Ptx
  | Binary
  deriving (Eq, Show)

instance Enum JITFallback where
  fromEnum strategy = case strategy of
    Ptx    -> 0
    Binary -> 1

  toEnum code = case code of
    0 -> Ptx
    1 -> Binary
    _ -> error ("JITFallback.toEnum: Cannot match " ++ show code)
-- |
-- Look up the function named @fn@ in the given module.
--
-- The status/result pair from @cuModuleGetFunction@ is interpreted by
-- 'resultIfFound' (kind @"function"@), which raises an exception for any
-- status other than success.
--
getFun :: Module -> String -> IO Fun
getFun !mdl !fn =
  cuModuleGetFunction mdl fn >>= resultIfFound "function" fn
-- Marshalling wrapper around the foreign @cuModuleGetFunction@: allocates a
-- temporary slot for the output handle, passes the name as a C string, and
-- returns the decoded status together with whatever the slot holds afterwards.
cuModuleGetFunction :: Module -> String -> IO (Status, Fun)
cuModuleGetFunction mdl name =
  alloca $ \pFun ->
    withCString name $ \cname -> do
      rc  <- cuModuleGetFunction'_ pFun (useModule mdl) cname
      -- The slot is read regardless of the status; callers must check it.
      fun <- liftM Fun (peek pFun)
      return (cToEnum rc, fun)
-- |
-- Look up the global named @name@ in the given module, returning its device
-- pointer paired with the second output of @cuModuleGetGlobal@ (an 'Int').
--
-- All three components of the raw result are forced before the status is
-- interpreted by 'resultIfFound' (kind @"global"@).
--
getPtr :: Module -> String -> IO (DevicePtr a, Int)
getPtr !mdl !name =
  cuModuleGetGlobal mdl name >>= \(!status, !dptr, !bytes) ->
    resultIfFound "global" name (status, (dptr, bytes))
-- Marshalling wrapper around the foreign @cuModuleGetGlobal@: two temporary
-- output slots are allocated, the name is passed as a C string, and both slots
-- are read back after the call regardless of the returned status.
cuModuleGetGlobal :: Module -> String -> IO (Status, DevicePtr a, Int)
cuModuleGetGlobal mdl name =
  alloca $ \pDevPtr ->
  alloca $ \pBytes  ->
  withCString name $ \cname -> do
    rc    <- cuModuleGetGlobal'_ pDevPtr pBytes (useModule mdl) cname
    dptr  <- peekDeviceHandle pDevPtr
    bytes <- peekIntConv pBytes
    return (cToEnum rc, dptr, bytes)
-- |
-- Look up the texture reference named @name@ in the given module.
--
-- The status/result pair from @cuModuleGetTexRef@ is interpreted by
-- 'resultIfFound' (kind @"texture"@).
--
getTex :: Module -> String -> IO Texture
getTex !mdl !name =
  cuModuleGetTexRef mdl name >>= resultIfFound "texture" name
-- Marshalling wrapper around the foreign @cuModuleGetTexRef@: allocates a
-- temporary slot for the output handle, passes the name as a C string, and
-- reads the slot back with 'peekTex' after the call.
cuModuleGetTexRef :: Module -> String -> IO (Status, Texture)
cuModuleGetTexRef mdl name =
  alloca $ \pTex ->
    withCString name $ \cname -> do
      rc  <- cuModuleGetTexRef'_ pTex (useModule mdl) cname
      tex <- peekTex pTex
      return (cToEnum rc, tex)
-- |
-- Load a module from the file at the given path via @cuModuleLoad@. The
-- resulting status/module pair is handed to 'resultIfOk'.
--
loadFile :: FilePath -> IO Module
loadFile !ptx = cuModuleLoad ptx >>= resultIfOk
-- Marshalling wrapper around the foreign @cuModuleLoad@: the path is passed
-- as a C string and the module handle is read back from a temporary slot.
cuModuleLoad :: FilePath -> IO (Status, Module)
cuModuleLoad path =
  alloca $ \pMod ->
    withCString path $ \cpath -> do
      rc  <- cuModuleLoad'_ pMod cpath
      mdl <- peekMod pMod
      return (cToEnum rc, mdl)
-- |
-- Load a module from an image held in a 'ByteString'. The bytes are exposed
-- as a C string for the duration of the call and passed to 'loadDataFromPtr'.
--
loadData :: ByteString -> IO Module
loadData !img = B.useAsCString img (loadDataFromPtr . castPtr)
-- |
-- Load a module from an image at the given address via @cuModuleLoadData@.
-- The resulting status/module pair is handed to 'resultIfOk'.
--
loadDataFromPtr :: Ptr Word8 -> IO Module
loadDataFromPtr !img = cuModuleLoadData img >>= resultIfOk
-- Marshalling wrapper around the foreign @cuModuleLoadData@: the image
-- pointer is cast to an untyped pointer and the module handle is read back
-- from a temporary slot.
cuModuleLoadData :: Ptr Word8 -> IO (Status, Module)
cuModuleLoadData image =
  alloca $ \pMod -> do
    rc  <- cuModuleLoadData'_ pMod (castPtr image)
    mdl <- peekMod pMod
    return (cToEnum rc, mdl)
-- |
-- Load a module from an image held in a 'ByteString', with options. The bytes
-- are exposed as a C string for the duration of the call and passed, together
-- with the options, to 'loadDataFromPtrEx'.
--
loadDataEx :: ByteString -> [JITOption] -> IO JITResult
loadDataEx !img !options = B.useAsCString img withImage
  where
    withImage cstr = loadDataFromPtrEx (castPtr cstr) options
-- |
-- Load a module from an image at the given address via @cuModuleLoadDataEx@,
-- passing along the supplied options.
--
-- Five options are always sent ahead of the caller's: a wall-time slot, and
-- the address and size of an info-log and an error-log buffer ('logSize' bytes
-- each). On success the result carries the value read back from the wall-time
-- slot, the info log up to its first NUL byte, and the module handle. On any
-- other status a 'cudaError' is raised whose message is the status
-- description followed by the error log.
--
-- A 'Target' option whose compute capability has no matching 'JITTarget'
-- constructor results in a call to 'error'.
--
loadDataFromPtrEx :: Ptr Word8 -> [JITOption] -> IO JITResult
loadDataFromPtrEx !img !options = do
  fp_ilog <- B.mallocByteString logSize
  allocaArray logSize $ \p_elog ->
    withForeignPtr fp_ilog $ \p_ilog -> do
      -- Both buffers start out uninitialised. Make each an empty C string so
      -- that a log the driver never writes to reads back as "" rather than
      -- whatever happened to be in memory.
      poke (p_elog :: CString) 0
      poke (p_ilog :: CString) 0

      -- Option values travel in pointer-sized slots, so the Int values and the
      -- two buffer addresses are all coerced to a common representation.
      let (opt, val) = unzip $
            [ (JIT_WALL_TIME,                   0)
            , (JIT_INFO_LOG_BUFFER_SIZE_BYTES,  logSize)
            , (JIT_ERROR_LOG_BUFFER_SIZE_BYTES, logSize)
            , (JIT_INFO_LOG_BUFFER,             unsafeCoerce (p_ilog :: CString))
            , (JIT_ERROR_LOG_BUFFER,            unsafeCoerce (p_elog :: CString))
            ] ++ mapMaybe unpack options

      withArray (map cFromEnum opt) $ \p_opts ->
        withArray (map unsafeCoerce val) $ \p_vals -> do
          (s, mdl) <- cuModuleLoadDataEx img (length opt) p_opts p_vals
          case s of
            Success -> do
              -- JIT_WALL_TIME is the first option, so its slot is the first
              -- element of the values array.
              time    <- peek (castPtr p_vals)
              infoLog <- B.fromForeignPtr (castForeignPtr fp_ilog) 0 `fmap` c_strnlen p_ilog logSize
              return $! JITResult time infoLog mdl

            _       -> do
              -- Never read past the end of the buffer, even if the log is not
              -- NUL-terminated.
              errLen <- c_strnlen p_elog logSize
              errLog <- peekCStringLen (p_elog, errLen)
              cudaError (unlines [describe s, errLog])
  where
    -- Size in bytes of each of the two log buffers.
    logSize = 2048

    -- Translate a public option into its internal code and integer value.
    unpack (MaxRegisters x)      = Just (JIT_MAX_REGISTERS,       x)
    unpack (ThreadsPerBlock x)   = Just (JIT_THREADS_PER_BLOCK,   x)
    unpack (OptimisationLevel x) = Just (JIT_OPTIMIZATION_LEVEL,  x)
    unpack (Target x)            = Just (JIT_TARGET,              jitTargetOfCompute x)
    unpack (FallbackStrategy x)  = Just (JIT_FALLBACK_STRATEGY,   fromEnum x)
    unpack GenerateDebugInfo     = Just (JIT_GENERATE_DEBUG_INFO, fromEnum True)
    unpack GenerateLineInfo      = Just (JIT_GENERATE_LINE_INFO,  fromEnum True)
    unpack Verbose               = Just (JIT_LOG_VERBOSE,         fromEnum True)

    -- Map a compute capability onto the integer value of the corresponding
    -- 'JITTarget'. Every 'JITTarget' constructor has an entry here.
    jitTargetOfCompute (Compute x y)
      = fromEnum
      $ case (x, y) of
          (1, 0) -> Compute10
          (1, 1) -> Compute11
          (1, 2) -> Compute12
          (1, 3) -> Compute13
          (2, 0) -> Compute20
          (2, 1) -> Compute21
          (3, 0) -> Compute30
          (3, 2) -> Compute32
          (3, 5) -> Compute35
          (5, 0) -> Compute50
          _      -> error ("Unknown JIT Target for Compute " ++ show (Compute x y))
-- Marshalling wrapper around the foreign @cuModuleLoadDataEx@. The image
-- pointer is cast to an untyped pointer, the option count is converted with
-- 'fromIntegral', and the two option arrays are passed through untouched. The
-- module handle is read back from a temporary slot.
cuModuleLoadDataEx :: Ptr Word8 -> Int -> Ptr CInt -> Ptr (Ptr ()) -> IO (Status, Module)
cuModuleLoadDataEx image numOpts optCodes optVals =
  alloca $ \pMod -> do
    rc  <- cuModuleLoadDataEx'_ pMod (castPtr image) (fromIntegral numOpts) optCodes optVals
    mdl <- peekMod pMod
    return (cToEnum rc, mdl)
-- |
-- Unload the given module via @cuModuleUnload@. The returned status is handed
-- to 'nothingIfOk'.
--
unload :: Module -> IO ()
unload !m = cuModuleUnload m >>= nothingIfOk
-- Marshalling wrapper around the foreign @cuModuleUnload@: unwraps the handle
-- and decodes the integer return code into a 'Status'.
cuModuleUnload :: Module -> IO Status
cuModuleUnload mdl = cToEnum `fmap` cuModuleUnload'_ (useModule mdl)
-- Interpret the status of a lookup. Both halves of the pair are forced first.
--
--   * 'Success'  : the result is returned.
--   * 'NotFound' : a 'cudaError' is raised with the message
--                  @\<kind\> \<status description\>: \<name\>@.
--   * otherwise  : the status is thrown wrapped in 'ExitCode'.
resultIfFound :: String -> String -> (Status, a) -> IO a
resultIfFound kind name (!status, !result) =
  case status of
    Success  -> return result
    NotFound -> cudaError notFoundMsg
    _        -> throwIO (ExitCode status)
  where
    notFoundMsg = concat [kind, " ", describe status, ": ", name]
-- Read a raw module handle out of the given slot and wrap it as a 'Module'.
peekMod :: Ptr (Ptr ()) -> IO Module
peekMod slot = Module `fmap` peek slot
-- Raw binding to C's @strnlen@ from @string.h@.
foreign import ccall unsafe "string.h strnlen" c_strnlen'
  :: CString -> CSize -> IO CSize

-- 'Int'-typed convenience wrapper over the raw @strnlen@ binding: both the
-- length limit and the result are converted with 'cIntConv'.
c_strnlen :: CString -> Int -> IO Int
c_strnlen str maxlen = do
  len <- c_strnlen' str (cIntConv maxlen)
  return (cIntConv len)
-- Raw foreign bindings used by the marshalling wrappers in this file. Every
-- one returns the driver status as a 'CInt'; where the first argument is a
-- pointer to a pointer, the wrappers use it as the slot that receives the
-- output handle.

foreign import ccall unsafe "Foreign/CUDA/Driver/Module.chs.h cuModuleGetFunction"
  cuModuleGetFunction'_ :: Ptr (Ptr ()) -> Ptr () -> Ptr CChar -> IO CInt

foreign import ccall unsafe "Foreign/CUDA/Driver/Module.chs.h cuModuleGetGlobal"
  cuModuleGetGlobal'_ :: Ptr CULLong -> Ptr CULong -> Ptr () -> Ptr CChar -> IO CInt

foreign import ccall unsafe "Foreign/CUDA/Driver/Module.chs.h cuModuleGetTexRef"
  cuModuleGetTexRef'_ :: Ptr (Ptr ()) -> Ptr () -> Ptr CChar -> IO CInt

foreign import ccall unsafe "Foreign/CUDA/Driver/Module.chs.h cuModuleLoad"
  cuModuleLoad'_ :: Ptr (Ptr ()) -> Ptr CChar -> IO CInt

foreign import ccall unsafe "Foreign/CUDA/Driver/Module.chs.h cuModuleLoadData"
  cuModuleLoadData'_ :: Ptr (Ptr ()) -> Ptr () -> IO CInt

foreign import ccall unsafe "Foreign/CUDA/Driver/Module.chs.h cuModuleLoadDataEx"
  cuModuleLoadDataEx'_ :: Ptr (Ptr ()) -> Ptr () -> CUInt -> Ptr CInt -> Ptr (Ptr ()) -> IO CInt

foreign import ccall unsafe "Foreign/CUDA/Driver/Module.chs.h cuModuleUnload"
  cuModuleUnload'_ :: Ptr () -> IO CInt