-- GENERATED by C->Haskell Compiler, version 0.28.3 Switcheroo, 25 November 2017 (Haskell)
-- Edit the ORIGINAL .chs file instead!


{-# LINE 1 "./Foreign/NVVM/Compile.chs" #-}
{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_GHC -funbox-strict-fields #-}
--------------------------------------------------------------------------------
-- |
-- Module    : Foreign.NVVM.Compile
-- Copyright : [2016] Trevor L. McDonell
-- License   : BSD
--
-- Program compilation
--
--------------------------------------------------------------------------------

module Foreign.NVVM.Compile (

  Program,
  Result(..),
  CompileOption(..),

  compileModule, compileModules,

  create,
  destroy,
  addModule, addModuleFromPtr,
  compile,
  verify

) where
import qualified Foreign.C.Types as C2HSImp
import qualified Foreign.Ptr as C2HSImp



import Foreign.CUDA.Analysis
import Foreign.NVVM.Error
import Foreign.NVVM.Internal.C2HS

import Foreign.C
import Foreign.Marshal
import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.Storable

import Control.Exception
import Data.Word
import Data.ByteString                                              ( ByteString )
import Text.Printf
import qualified Data.ByteString.Char8                              as B
import qualified Data.ByteString.Unsafe                             as B
import qualified Data.ByteString.Internal                           as B




{-# LINE 49 "./Foreign/NVVM/Compile.chs" #-}



-- | Handle to an NVVM program.
--
-- A thin newtype over the untyped pointer that the C API hands back;
-- 'useProgram' unwraps it again when calling into the foreign functions.
--
newtype Program = Program { useProgram :: C2HSImp.Ptr () }
  deriving ( Eq, Show )

-- | The result of compiling an NVVM program.
--
-- Both fields are strict. 'compileModules' builds this value from the output
-- of 'compile': the compiled program goes into 'compileResult' and the
-- accompanying compiler/verifier log into 'compileLog'.
--
data Result = Result
  { compileResult :: !ByteString  -- ^ The compiled kernel, which can be loaded into the current program using 'Foreign.CUDA.Driver.loadData*'
  , compileLog    :: !ByteString  -- ^ Warning messages generated by the compiler/verifier
  }

-- | Program compilation options
--
-- Each option is turned into a flag string by @withCompileOptions@ before it
-- is passed to the compiler or verifier; the flag spelling is noted next to
-- each constructor.
--
data CompileOption
  = OptimisationLevel !Int        -- ^ optimisation level, from 0 (disable optimisations) to 3 (default); rendered as @-opt=N@
  | Target !Compute               -- ^ target architecture to compile for (default: compute 2.0); rendered as @-arch=compute_NM@
  | FlushToZero                   -- ^ flush denormal values to zero when performing single-precision floating-point operations (default: no); rendered as @-ftz=1@
  | NoFMA                         -- ^ disable fused-multiply-add instructions (default: enabled); rendered as @-fma=0@
  | FastSqrt                      -- ^ use a fast approximation for single-precision floating-point square root (default: no); rendered as @-prec-sqrt=0@
  | FastDiv                       -- ^ use a fast approximation for single-precision floating-point division and reciprocal (default: no); rendered as @-prec-div=0@
  | GenerateDebugInfo             -- ^ generate debugging information (-g) (default: no)
  deriving ( Eq, Show )


-- | Compile a single NVVM IR module, given in either bitcode or textual form,
-- into PTX code.
--
-- This is 'compileModules' applied to a one-element module list.
--
{-# INLINEABLE compileModule #-}
compileModule
    :: String                     -- ^ name of the module
    -> ByteString                 -- ^ NVVM IR in either textual or bitcode representation
    -> [CompileOption]            -- ^ compiler options
    -> IO Result
compileModule !name !bs !opts = compileModules singleton opts
  where
    -- the only module of the program, paired with its name
    singleton = [(name, bs)]


-- | Compile a collection of NVVM IR modules into PTX code
--
-- A fresh 'Program' is created for the duration of the call and released
-- again through 'bracket', whether or not compilation succeeds. Every module
-- is added to the program in list order before 'compile' is run. When
-- 'compile' yields no result, the compiler log is handed to 'nvvmErrorIO';
-- otherwise the compiled output and the log are returned as a 'Result'.
--
{-# INLINEABLE compileModules #-}
compileModules
    :: [(String, ByteString)]     -- ^ (module name, module NVVM IR) pairs to compile
    -> [CompileOption]            -- ^ compiler options
    -> IO Result
compileModules !bss !opts =
  bracket create destroy $ \prg -> do
    mapM_ (\(name, ir) -> addModule prg name ir) bss
    (messages, result) <- compile prg opts
    maybe (nvvmErrorIO (B.unpack messages))
          (\ptx -> return (Result ptx messages))
          result


-- | Create an empty 'Program'
--
-- The status returned by the foreign call is checked by 'resultIfOk', which
-- yields the new program handle.
--
-- <http://docs.nvidia.com/cuda/libnvvm-api/group__compilation.html#group__compilation_1g46a0ab04a063cba28bfbb41a1939e3f4>
--
{-# INLINEABLE create #-}
create :: IO Program
create = nvvmCreateProgram >>= resultIfOk
  where
    -- Call the C function with a temporary out-parameter slot, then read the
    -- handle back from that slot and pair it with the decoded status.
    nvvmCreateProgram :: IO (Status, Program)
    nvvmCreateProgram =
      alloca $ \slot -> do
        rc     <- nvvmCreateProgram'_ slot
        handle <- peek slot
        return (cToEnum rc, Program handle)

{-# LINE 120 "./Foreign/NVVM/Compile.chs" #-}



-- | Destroy a 'Program'
--
-- The status returned by the foreign call is checked by 'nothingIfOk'.
--
-- <http://docs.nvidia.com/cuda/libnvvm-api/group__compilation.html#group__compilation_1gfba94cab1224c0152841b80690d366aa>
--
{-# INLINEABLE destroy #-}
destroy :: Program -> IO ()
destroy !prg = nvvmDestroyProgram prg >>= nothingIfOk
  where
    -- The C function takes a pointer to the handle, so the handle is written
    -- to temporary storage for the duration of the call.
    nvvmDestroyProgram :: Program -> IO Status
    nvvmDestroyProgram (Program handle) =
      with handle $ \slot ->
        cToEnum `fmap` nvvmDestroyProgram'_ slot

{-# LINE 137 "./Foreign/NVVM/Compile.chs" #-}



-- | Add a module level NVVM IR to a program
--
-- The bytes of the 'ByteString' are passed to 'addModuleFromPtr' in place
-- (via 'B.unsafeUseAsCStringLen'), without making a copy first.
--
-- <http://docs.nvidia.com/cuda/libnvvm-api/group__compilation.html#group__compilation_1g0c22d2b9be033c165bc37b16f3ed75c6>
--
{-# INLINEABLE addModule #-}
addModule
    :: Program          -- ^ NVVM program to add to
    -> String           -- ^ Name of the module (defaults to \"@\<unnamed\>@\" if empty)
    -> ByteString       -- ^ NVVM IR module in either bitcode or textual representation
    -> IO ()
addModule !prg !name !bs = B.unsafeUseAsCStringLen bs withBuffer
  where
    -- forward the raw buffer and its length, viewing the chars as bytes
    withBuffer (chars, count) = addModuleFromPtr prg name count (castPtr chars)


-- | As with 'addModule', but read the specified number of bytes from the given
-- pointer.
--
-- The status returned by the foreign call is checked by 'nothingIfOk'.
--
{-# INLINEABLE addModuleFromPtr #-}
addModuleFromPtr
    :: Program          -- ^ NVVM program to add to
    -> String           -- ^ Name of the module (defaults to \"@\<unnamed\>@\" if empty)
    -> Int              -- ^ Number of bytes in the module
    -> Ptr Word8        -- ^ NVVM IR module in bitcode or textual representation
    -> IO ()
addModuleFromPtr !prg !name !size !buffer =
  nvvmAddModuleToProgram prg buffer size name >>= nothingIfOk
  where
    -- Marshal the arguments for the C call: the module name becomes a
    -- temporary C string, the buffer is viewed as chars and the size is
    -- converted to the C integer type expected by the import.
    nvvmAddModuleToProgram :: Program -> Ptr Word8 -> Int -> String -> IO Status
    nvvmAddModuleToProgram (Program handle) ir bytes label =
      withCString label $ \cname -> do
        rc <- nvvmAddModuleToProgram'_ handle (castPtr ir) (cIntConv bytes) cname
        return (cToEnum rc)

{-# LINE 176 "./Foreign/NVVM/Compile.chs" #-}



-- | Compile the NVVM program. Returns the log from the compiler/verifier and,
-- if successful, the compiled program.
--
-- The log is fetched after every compilation attempt; the compiled output is
-- only fetched when the compile call reported 'Success'.
--
-- <http://docs.nvidia.com/cuda/libnvvm-api/group__compilation.html#group__compilation_1g76ac1e23f5d0e2240e78be0e63450346>
--
{-# INLINEABLE compile #-}
compile :: Program -> [CompileOption] -> IO (ByteString, Maybe ByteString)
compile !prg !opts = do
  status   <- withCompileOptions opts (nvvmCompileProgram prg)
  messages <- retrieve (nvvmGetProgramLogSize prg) (nvvmGetProgramLog prg)
  ptx      <- case status of
                Success -> Just `fmap` retrieve (nvvmGetCompiledResultSize prg) (nvvmGetCompiledResult prg)
                _       -> return Nothing
  return (messages, ptx)
  where
    -- Run the compiler with the given number of options and option array.
    nvvmCompileProgram :: Program -> Int -> Ptr CString -> IO Status
    nvvmCompileProgram (Program handle) count flags =
      cToEnum `fmap` nvvmCompileProgram'_ handle (cIntConv count) flags

{-# LINE 200 "./Foreign/NVVM/Compile.chs" #-}


    -- Query the size of the compiled result through a temporary
    -- out-parameter.
    nvvmGetCompiledResultSize :: Program -> IO (Status, Int)
    nvvmGetCompiledResultSize (Program handle) =
      alloca $ \out -> do
        rc    <- nvvmGetCompiledResultSize'_ handle out
        bytes <- peekIntConv out
        return (cToEnum rc, bytes)

{-# LINE 207 "./Foreign/NVVM/Compile.chs" #-}


    -- Have the compiled result written into the supplied buffer.
    nvvmGetCompiledResult :: Program -> ForeignPtr Word8 -> IO Status
    nvvmGetCompiledResult (Program handle) fp =
      withForeignPtr' fp $ \dst ->
        cToEnum `fmap` nvvmGetCompiledResult'_ handle dst

{-# LINE 214 "./Foreign/NVVM/Compile.chs" #-}



-- | Verify the NVVM program. Returns whether compilation will succeed, together
-- with any error or warning messages.
--
-- The verifier status is returned to the caller as-is rather than being
-- turned into an error here; the log is fetched after the verifier has run.
--
{-# INLINEABLE verify #-}
verify :: Program -> [CompileOption] -> IO (Status, ByteString)
verify !prg !opts =
  withCompileOptions opts (nvvmVerifyProgram prg) >>= \status ->
    (,) status `fmap` retrieve (nvvmGetProgramLogSize prg) (nvvmGetProgramLog prg)
  where
    -- Run the verifier with the given number of options and option array.
    nvvmVerifyProgram :: Program -> Int -> Ptr CString -> IO Status
    nvvmVerifyProgram (Program handle) count flags =
      cToEnum `fmap` nvvmVerifyProgram'_ handle (cIntConv count) flags

{-# LINE 234 "./Foreign/NVVM/Compile.chs" #-}



-- | Query the size of the program's compiler/verifier log.
--
-- The size is read back from a temporary out-parameter and returned together
-- with the decoded status of the call.
--
nvvmGetProgramLogSize :: Program -> IO (Status, Int)
nvvmGetProgramLogSize (Program handle) =
  alloca $ \out -> do
    rc    <- nvvmGetProgramLogSize'_ handle out
    bytes <- peekIntConv out
    return (cToEnum rc, bytes)

{-# LINE 242 "./Foreign/NVVM/Compile.chs" #-}


-- | Have the program's compiler/verifier log written into the supplied
-- buffer, returning the decoded status of the call.
--
nvvmGetProgramLog :: Program -> ForeignPtr Word8 -> IO Status
nvvmGetProgramLog (Program handle) fp =
  withForeignPtr' fp $ \dst ->
    cToEnum `fmap` nvvmGetProgramLog'_ handle dst

{-# LINE 249 "./Foreign/NVVM/Compile.chs" #-}



-- Utilities
-- ---------

-- | Like 'withForeignPtr', but presents the byte buffer to the continuation
-- as a pointer to C characters.
--
{-# INLINEABLE withForeignPtr' #-}
withForeignPtr' :: ForeignPtr Word8 -> (Ptr CChar -> IO a) -> IO a
withForeignPtr' fp f =
  withForeignPtr fp $ \bytes ->
    f (castPtr bytes)


-- | Marshal a list of compile options into a temporary array of C strings and
-- run the continuation with the number of options and a pointer to the array.
--
-- The strings and the array are only valid while the continuation runs.
--
{-# INLINEABLE withCompileOptions #-}
withCompileOptions :: [CompileOption] -> (Int -> Ptr CString -> IO a) -> IO a
withCompileOptions opts next =
  withMany (withCString . render) opts $ \flags ->
    withArrayLen flags next
  where
    -- the flag string corresponding to each option
    render :: CompileOption -> String
    render opt = case opt of
      OptimisationLevel n  -> printf "-opt=%d" n
      Target (Compute n m) -> printf "-arch=compute_%d%d" n m
      FlushToZero          -> "-ftz=1"
      NoFMA                -> "-fma=0"
      FastSqrt             -> "-prec-sqrt=0"
      FastDiv              -> "-prec-div=0"
      GenerateDebugInfo    -> "-g"

-- | Fetch a variable-sized result in two steps: first ask for its size, then
-- allocate a buffer of that size and have the payload written into it.
--
-- A reported size of one or less yields an empty 'ByteString' without
-- fetching anything, since the size counts the NULL terminator. Otherwise the
-- returned 'ByteString' spans the entire buffer of the reported size.
--
{-# INLINEABLE retrieve #-}
retrieve :: IO (Status, Int) -> (ForeignPtr Word8 -> IO Status) -> IO ByteString
retrieve size payload = do
  bytes <- resultIfOk =<< size
  if bytes > 1                                      -- size includes NULL terminator
    then fetch bytes
    else return B.empty
  where
    -- allocate the buffer, let the payload action fill it, and wrap it up
    fetch n = do
      fp <- mallocForeignPtrBytes n
      _  <- nothingIfOk =<< payload fp
      return (B.fromForeignPtr fp 0 n)


-- Raw binding: takes a pointer to the slot that receives the new program
-- handle and returns a status code.
foreign import ccall unsafe "Foreign/NVVM/Compile.chs.h nvvmCreateProgram"
  nvvmCreateProgram'_ :: C2HSImp.Ptr (C2HSImp.Ptr ()) -> IO C2HSImp.CInt

-- Raw binding: takes a pointer to the program handle and returns a status
-- code.
foreign import ccall unsafe "Foreign/NVVM/Compile.chs.h nvvmDestroyProgram"
  nvvmDestroyProgram'_ :: C2HSImp.Ptr (C2HSImp.Ptr ()) -> IO C2HSImp.CInt

-- Raw binding: program handle, module buffer, buffer size and module name;
-- returns a status code.
foreign import ccall unsafe "Foreign/NVVM/Compile.chs.h nvvmAddModuleToProgram"
  nvvmAddModuleToProgram'_ :: C2HSImp.Ptr () -> C2HSImp.Ptr C2HSImp.CChar -> C2HSImp.CULong -> C2HSImp.Ptr C2HSImp.CChar -> IO C2HSImp.CInt

-- Raw binding: program handle, number of options and option array; returns a
-- status code.
--
-- Imported as a 'safe' call: compilation can run for a long time, and an
-- 'unsafe' foreign call would prevent garbage collection and stall the other
-- Haskell threads until it returns.
--
-- NOTE: this file is generated; the same change needs to be made to the
-- corresponding binding in the .chs source to survive regeneration.
foreign import ccall safe "Foreign/NVVM/Compile.chs.h nvvmCompileProgram"
  nvvmCompileProgram'_ :: C2HSImp.Ptr () -> C2HSImp.CInt -> C2HSImp.Ptr (C2HSImp.Ptr C2HSImp.CChar) -> IO C2HSImp.CInt

-- Raw binding: program handle and a pointer to the slot that receives the
-- size of the compiled result; returns a status code.
foreign import ccall unsafe "Foreign/NVVM/Compile.chs.h nvvmGetCompiledResultSize"
  nvvmGetCompiledResultSize'_ :: C2HSImp.Ptr () -> C2HSImp.Ptr C2HSImp.CULong -> IO C2HSImp.CInt

-- Raw binding: program handle and the buffer that receives the compiled
-- result; returns a status code.
foreign import ccall unsafe "Foreign/NVVM/Compile.chs.h nvvmGetCompiledResult"
  nvvmGetCompiledResult'_ :: C2HSImp.Ptr () -> C2HSImp.Ptr C2HSImp.CChar -> IO C2HSImp.CInt

-- Raw binding: program handle, number of options and option array; returns a
-- status code.
--
-- Imported as a 'safe' call: verification time depends on the size of the
-- input modules, and an 'unsafe' foreign call would prevent garbage
-- collection and stall the other Haskell threads until it returns.
--
-- NOTE: this file is generated; the same change needs to be made to the
-- corresponding binding in the .chs source to survive regeneration.
foreign import ccall safe "Foreign/NVVM/Compile.chs.h nvvmVerifyProgram"
  nvvmVerifyProgram'_ :: C2HSImp.Ptr () -> C2HSImp.CInt -> C2HSImp.Ptr (C2HSImp.Ptr C2HSImp.CChar) -> IO C2HSImp.CInt

-- Raw binding: program handle and a pointer to the slot that receives the
-- size of the log; returns a status code.
foreign import ccall unsafe "Foreign/NVVM/Compile.chs.h nvvmGetProgramLogSize"
  nvvmGetProgramLogSize'_ :: C2HSImp.Ptr () -> C2HSImp.Ptr C2HSImp.CULong -> IO C2HSImp.CInt

-- Raw binding: program handle and the buffer that receives the log; returns a
-- status code.
foreign import ccall unsafe "Foreign/NVVM/Compile.chs.h nvvmGetProgramLog"
  nvvmGetProgramLog'_ :: C2HSImp.Ptr () -> C2HSImp.Ptr C2HSImp.CChar -> IO C2HSImp.CInt