{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_GHC -funbox-strict-fields #-}
--------------------------------------------------------------------------------
-- |
-- Module      : Foreign.NVVM.Compile
-- Copyright   : [2016] Trevor L. McDonell
-- License     : BSD
--
-- Program compilation: build an NVVM program from one or more NVVM IR
-- modules (textual or bitcode) and compile or verify it, obtaining PTX code
-- together with the compiler/verifier log.
--
--------------------------------------------------------------------------------

module Foreign.NVVM.Compile (

  Program,
  Result(..),
  CompileOption(..),
  compileModule,
  compileModules,
  create,
  destroy,
  addModule,
  addModuleFromPtr,
  compile,
  verify

) where

import Foreign.CUDA.Analysis
import Foreign.NVVM.Error
import Foreign.NVVM.Internal.C2HS

import Foreign.C
import Foreign.Marshal
import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.Storable

import Control.Exception
import Data.Word
import Data.ByteString                                    ( ByteString )
import Text.Printf
import qualified Data.ByteString.Char8                    as B
import qualified Data.ByteString.Unsafe                   as B
import qualified Data.ByteString.Internal                 as B

#include "cbits/stubs.h"
{# context lib="nvvm" #}


-- | An NVVM program: a handle wrapping the C @nvvmProgram@ type.
--
-- Obtain one with 'create', populate it with 'addModule' or
-- 'addModuleFromPtr', and release it with 'destroy'.
--
newtype Program = Program { useProgram :: {# type nvvmProgram #} }
  deriving ( Eq, Show )
-- | What a successful NVVM compilation produces: the generated code plus
-- whatever diagnostics the compiler/verifier emitted.
--
data Result = Result
  { -- | The compiled kernel, which can be loaded into the current program
    -- using 'Foreign.CUDA.Driver.loadData*'.
    compileResult :: !ByteString
  , -- | Warning messages generated by the compiler/verifier.
    compileLog    :: !ByteString
  }


-- | Options that control how an NVVM program is compiled. Each option is
-- rendered as a single libNVVM command-line flag (noted in parentheses).
--
data CompileOption
  = -- | Optimisation level, from 0 (disable optimisations) to 3 (default).
    -- (@-opt=N@)
    OptimisationLevel !Int
  | -- | Target architecture to compile for (default: compute 2.0).
    -- (@-arch=compute_XY@)
    Target !Compute
  | -- | Flush denormal values to zero when performing single-precision
    -- floating-point operations (default: no). (@-ftz=1@)
    FlushToZero
  | -- | Disable fused-multiply-add instructions (default: enabled).
    -- (@-fma=0@)
    NoFMA
  | -- | Use a fast approximation for single-precision floating-point square
    -- root (default: no). (@-prec-sqrt=0@)
    FastSqrt
  | -- | Use a fast approximation for single-precision floating-point
    -- division and reciprocal (default: no). (@-prec-div=0@)
    FastDiv
  | -- | Generate debugging information (default: no). (@-g@)
    GenerateDebugInfo
  deriving ( Eq, Show )
-- | Compile a single NVVM IR module, given in either bitcode or textual
-- form, into PTX code. This is 'compileModules' specialised to one module.
--
{-# INLINEABLE compileModule #-}
compileModule
    :: String             -- ^ name of the module
    -> ByteString         -- ^ NVVM IR in either textual or bitcode representation
    -> [CompileOption]    -- ^ compiler options
    -> IO Result
compileModule !name !bs !opts =
  let singleton = [(name, bs)]
  in  compileModules singleton opts


-- | Compile a collection of NVVM IR modules into PTX code.
--
-- A fresh 'Program' is created for the call and destroyed afterwards; the
-- release runs even if an exception escapes, because it is done with
-- 'bracket'. If compilation does not succeed, the compiler log is raised as
-- an error through 'nvvmErrorIO'.
--
{-# INLINEABLE compileModules #-}
compileModules
    :: [(String, ByteString)]   -- ^ (module name, module NVVM IR) pairs to compile
    -> [CompileOption]          -- ^ compiler options
    -> IO Result
compileModules !bss !opts =
  bracket create destroy $ \program -> do
    mapM_ (\(modName, modIR) -> addModule program modName modIR) bss
    (logText, compiled) <- compile program opts
    maybe (nvvmErrorIO (B.unpack logText))
          (\ptx -> return (Result ptx logText))
          compiled


-- | Create an empty 'Program' by calling @nvvmCreateProgram@. The returned
-- status is checked with 'resultIfOk'.
--
{-# INLINEABLE create #-}
create :: IO Program
create = nvvmCreateProgram >>= resultIfOk
  where
    -- read the handle written by the C call and wrap it
    peekProgram ptr = fmap Program (peek ptr)

    {# fun unsafe nvvmCreateProgram
      { alloca- `Program' peekProgram*
      }
      -> `Status' cToEnum #}


-- | Destroy a 'Program' by calling @nvvmDestroyProgram@. The returned status
-- is checked with 'nothingIfOk'.
--
{-# INLINEABLE destroy #-}
destroy :: Program -> IO ()
destroy !prg = nvvmDestroyProgram prg >>= nothingIfOk
  where
    -- the C function takes a pointer to the handle, so pass it by reference
    withProgram prog = with (useProgram prog)

    {# fun unsafe nvvmDestroyProgram
      { withProgram* `Program'
      }
      -> `Status' cToEnum #}


-- | Add a module of NVVM IR to a program.
--
-- The 'ByteString' contents are handed to 'addModuleFromPtr' without being
-- copied, via 'B.unsafeUseAsCStringLen'.
--
{-# INLINEABLE addModule #-}
addModule
    :: Program            -- ^ NVVM program to add to
    -> String             -- ^ Name of the module (defaults to "" if empty)
    -> ByteString         -- ^ NVVM IR module in either bitcode or textual representation
    -> IO ()
addModule !prg !name !bs =
  B.unsafeUseAsCStringLen bs $ \(irPtr, irLen) ->
    addModuleFromPtr prg name irLen (castPtr irPtr)
-- | As with 'addModule', but read the specified number of bytes from the
-- given pointer. The status returned by @nvvmAddModuleToProgram@ is checked
-- with 'nothingIfOk'.
--
{-# INLINEABLE addModuleFromPtr #-}
addModuleFromPtr
    :: Program            -- ^ NVVM program to add to
    -> String             -- ^ Name of the module (defaults to "" if empty)
    -> Int                -- ^ Number of bytes in the module
    -> Ptr Word8          -- ^ NVVM IR module in bitcode or textual representation
    -> IO ()
addModuleFromPtr !prg !name !size !buffer =
  nvvmAddModuleToProgram prg buffer size name >>= nothingIfOk
  where
    {# fun unsafe nvvmAddModuleToProgram
      { useProgram    `Program'
      , castPtr       `Ptr Word8'
      , cIntConv      `Int'
      , withCString*  `String'
      }
      -> `Status' cToEnum #}


-- | Compile the NVVM program with the given options.
--
-- The compiler/verifier log is always returned as the first component. The
-- compiled program is fetched only when @nvvmCompileProgram@ reports
-- 'Success'; for any other status the second component is 'Nothing'.
--
{-# INLINEABLE compile #-}
compile :: Program -> [CompileOption] -> IO (ByteString, Maybe ByteString)
compile !prg !opts = do
  st      <- withCompileOptions opts (nvvmCompileProgram prg)
  logText <- retrieve (nvvmGetProgramLogSize prg) (nvvmGetProgramLog prg)
  output  <- case st of
               Success -> Just `fmap` retrieve (nvvmGetCompiledResultSize prg)
                                               (nvvmGetCompiledResult prg)
               _       -> return Nothing
  return (logText, output)
  where
    {# fun unsafe nvvmCompileProgram
      { useProgram `Program'
      , cIntConv   `Int'
      , id         `Ptr CString'
      }
      -> `Status' cToEnum #}

    {# fun unsafe nvvmGetCompiledResultSize
      { useProgram `Program'
      , alloca-    `Int' peekIntConv*
      }
      -> `Status' cToEnum #}

    {# fun unsafe nvvmGetCompiledResult
      { useProgram       `Program'
      , withForeignPtr'* `ForeignPtr Word8'
      }
      -> `Status' cToEnum #}
-- | Verify the NVVM program. Returns whether compilation will succeed, together
-- with any error or warning messages.
--
{-# INLINEABLE verify #-}
verify :: Program -> [CompileOption] -> IO (Status, ByteString)
verify !prg !opts = do
  status   <- withCompileOptions opts (nvvmVerifyProgram prg)
  messages <- retrieve (nvvmGetProgramLogSize prg) (nvvmGetProgramLog prg)
  return (status, messages)
  where
    {# fun unsafe nvvmVerifyProgram
      { useProgram `Program'
      , cIntConv   `Int'
      , id         `Ptr CString'
      }
      -> `Status' cToEnum #}


-- Log retrieval is shared by both 'compile' and 'verify', so these bindings
-- must live at the top level rather than inside either function's @where@.

{# fun unsafe nvvmGetProgramLogSize
  { useProgram `Program'
  , alloca-    `Int' peekIntConv*
  }
  -> `Status' cToEnum #}

{# fun unsafe nvvmGetProgramLog
  { useProgram       `Program'
  , withForeignPtr'* `ForeignPtr Word8'
  }
  -> `Status' cToEnum #}


-- Utilities
-- ---------

-- | 'withForeignPtr' for a byte buffer, presenting it to the continuation as
-- a C character pointer.
--
{-# INLINEABLE withForeignPtr' #-}
withForeignPtr' :: ForeignPtr Word8 -> (Ptr CChar -> IO a) -> IO a
withForeignPtr' fp f = withForeignPtr fp (f . castPtr)

-- | Render the compile options as C strings and pass the continuation the
-- number of options together with a pointer to the array of strings.
--
{-# INLINEABLE withCompileOptions #-}
withCompileOptions :: [CompileOption] -> (Int -> Ptr CString -> IO a) -> IO a
withCompileOptions opts next =
  withMany withCString (map toStr opts) $ \cs -> withArrayLen cs next
  where
    toStr :: CompileOption -> String
    toStr (OptimisationLevel n)   = printf "-opt=%d" n
    toStr (Target (Compute n m))  = printf "-arch=compute_%d%d" n m
    toStr FlushToZero             = "-ftz=1"
    toStr NoFMA                   = "-fma=0"
    toStr FastSqrt                = "-prec-sqrt=0"
    toStr FastDiv                 = "-prec-div=0"
    toStr GenerateDebugInfo       = "-g"

-- | Fetch a NUL-terminated string result from libNVVM (the program log or the
-- compiled output).
--
-- The first action reports the buffer size; a non-success status is checked
-- with 'resultIfOk'. The second action fills a buffer of that size, and its
-- status is checked with 'nothingIfOk'.
--
-- The size reported by libNVVM includes the trailing NUL terminator. That
-- terminator is excluded from the returned 'ByteString' so that it does not
-- leak into logs, error messages, or PTX consumers. It is still present in
-- the underlying buffer, immediately after the last visible byte.
--
{-# INLINEABLE retrieve #-}
retrieve :: IO (Status, Int) -> (ForeignPtr Word8 -> IO Status) -> IO ByteString
retrieve size payload = do
  bytes <- resultIfOk =<< size
  if bytes <= 1     -- size includes NULL terminator: nothing but the NUL
    then return B.empty
    else do
      fp <- mallocForeignPtrBytes bytes
      nothingIfOk =<< payload fp
      -- drop the NUL terminator from the visible length
      return (B.fromForeignPtr fp 0 (bytes - 1))