module LLVM.Internal.OrcJIT.CompileLayer
  ( module LLVM.Internal.OrcJIT.CompileLayer
  , FFI.ModuleKey
  ) where

import LLVM.Prelude

import Control.Exception
import Control.Monad.AnyCont
import Control.Monad.IO.Class
import Data.IORef
import qualified Foreign.C.String as CStr
import qualified Foreign.Marshal.Alloc as Alloc
import Foreign.Ptr
import qualified Foreign.Storable as Storable

import LLVM.Internal.Coding
import qualified LLVM.Internal.FFI.DataLayout as FFI
import qualified LLVM.Internal.FFI.OrcJIT as FFI
import qualified LLVM.Internal.FFI.OrcJIT.CompileLayer as FFI
import LLVM.Internal.Module hiding (getDataLayout)
import LLVM.Internal.OrcJIT

-- | Instances of 'CompileLayer' support two groups of operations:
--
-- 1. Adding and removing modules with 'addModule' \/ 'removeModule'.
--
-- 2. Looking up symbols among the previously added modules with
-- 'findSymbol' \/ 'findSymbolIn'.
class CompileLayer l where
  -- | Raw handle of the underlying native compile layer, handed to the FFI
  -- calls made by 'findSymbol', 'findSymbolIn', 'addModule',
  -- 'removeModule' and 'disposeCompileLayer'.
  getCompileLayer
    :: l
    -> Ptr FFI.CompileLayer
  -- | Data layout used by 'mangleSymbol' and passed along by 'addModule'.
  getDataLayout
    :: l
    -> Ptr FFI.DataLayout
  -- | Cleanup actions that 'disposeCompileLayer' runs after disposing the
  -- native layer.
  getCleanups
    :: l
    -> IORef [IO ()]

-- | Mangle @symbol@ using the data layout held by the 'CompileLayer'
-- (see 'getDataLayout').
mangleSymbol :: CompileLayer l => l -> ShortByteString -> IO MangledSymbol
mangleSymbol compileLayer symbol = runAnyContT mangling return
  where
    mangling :: AnyContT IO MangledSymbol
    mangling = do
      -- Out-parameter that the FFI call fills with the mangled C string.
      outSlot <- alloca
      cSymbol <- encodeM symbol
      -- The mangled string is released only after the remainder of this
      -- block (including the decode below) has finished with it.
      anyContToM $
        bracket
          (FFI.getMangledSymbol outSlot cSymbol (getDataLayout compileLayer))
          (\_ -> FFI.disposeMangledSymbol =<< peek outSlot)
      peek outSlot >>= decodeM

-- | @'findSymbol' layer symbol exportedSymbolsOnly@ looks up @symbol@
-- across every module that has been added to @layer@. When
-- @exportedSymbolsOnly@ is 'True', only exported symbols are considered.
findSymbol :: CompileLayer l => l -> MangledSymbol -> Bool -> IO (Either JITSymbolError JITSymbol)
findSymbol compileLayer symbol exportedSymbolsOnly = runAnyContT lookupAll return
  where
    lookupAll :: AnyContT IO (Either JITSymbolError JITSymbol)
    lookupAll = do
      cSymbol <- encodeM symbol
      cExportedOnly <- encodeM exportedSymbolsOnly
      -- The native symbol is disposed only after it has been decoded below.
      symbolPtr <- anyContToM $
        bracket
          (FFI.findSymbol (getCompileLayer compileLayer) cSymbol cExportedOnly)
          FFI.disposeSymbol
      decodeM symbolPtr

-- | @'findSymbolIn' layer handle symbol exportedSymbolsOnly@ looks up
-- @symbol@ only within the module identified by @handle@. When
-- @exportedSymbolsOnly@ is 'True', only exported symbols are considered.
findSymbolIn :: CompileLayer l => l -> FFI.ModuleKey -> MangledSymbol -> Bool -> IO (Either JITSymbolError JITSymbol)
findSymbolIn compileLayer handle symbol exportedSymbolsOnly = runAnyContT lookupInModule return
  where
    lookupInModule :: AnyContT IO (Either JITSymbolError JITSymbol)
    lookupInModule = do
      cSymbol <- encodeM symbol
      cExportedOnly <- encodeM exportedSymbolsOnly
      -- The native symbol is disposed only after it has been decoded below.
      symbolPtr <- anyContToM $
        bracket
          (FFI.findSymbolIn (getCompileLayer compileLayer) handle cSymbol cExportedOnly)
          FFI.disposeSymbol
      decodeM symbolPtr

-- | Add a module to the 'CompileLayer' under the key @k@.
--
-- No symbol resolver is passed to this function.
--
-- If the native call reports an error message, the message is freed and an
-- 'IOError' (built with 'userError') is thrown.
--
-- /Note:/ This function consumes the module passed to it and it must
-- not be used after calling this method.
addModule :: CompileLayer l => l -> FFI.ModuleKey -> Module -> IO ()
addModule compileLayer k mod = flip runAnyContT return $ do
  mod' <- liftIO $ readModule mod
  -- The raw module is handed over to the native layer below, so detach it
  -- from the Haskell 'Module' wrapper first.
  liftIO $ deleteModule mod
  errMsg <- alloca
  -- NOTE(review): assumes 'OwnerTransfered CString' is a newtype over
  -- 'CString' with the same Storable representation — confirm.
  let errSlot = castPtr errMsg :: Ptr CStr.CString
  -- Start from NULL so "no error" is well defined even if the callee only
  -- writes the slot on failure ('alloca' memory is not zeroed).
  liftIO $ Storable.poke errSlot nullPtr
  liftIO $
    FFI.addModule
      (getCompileLayer compileLayer)
      (getDataLayout compileLayer)
      k
      mod'
      errMsg
  -- Previously the error slot was never inspected: failures were silently
  -- ignored and the owner-transferred message leaked.
  liftIO $ do
    cMsg <- Storable.peek errSlot
    if cMsg == nullPtr
      then return ()
      else do
        -- NOTE(review): assumes the message was allocated with malloc, so
        -- 'Alloc.free' is the matching deallocator — confirm with the C side.
        msg <- CStr.peekCString cMsg `finally` Alloc.free cMsg
        throwIO (userError ("LLVM.Internal.OrcJIT.CompileLayer.addModule: " ++ msg))

-- | Remove a module that was previously added with 'addModule', identified
-- by the same key.
removeModule :: CompileLayer l => l -> FFI.ModuleKey -> IO ()
removeModule compileLayer = FFI.removeModule (getCompileLayer compileLayer)

-- | Run an action with a module temporarily added to the 'CompileLayer'.
-- The module is added with 'addModule' before the action runs and removed
-- with 'removeModule' afterwards, even if the action throws ('bracket_').
--
-- /Note:/ This function consumes the module passed to it and it must
-- not be used after calling this method.
withModule :: CompileLayer l => l -> FFI.ModuleKey -> Module -> IO a -> IO a
withModule compileLayer k mod action = bracket_ acquire release action
  where
    acquire = addModule compileLayer k mod
    release = removeModule compileLayer k

-- | Dispose of a 'CompileLayer'. This should be called once the
-- 'CompileLayer' is no longer needed.
--
-- The native layer is disposed first. Afterwards every action in the
-- layer's cleanup list ('getCleanups') is run, in list order.
disposeCompileLayer :: CompileLayer l => l -> IO ()
disposeCompileLayer l = do
  FFI.disposeCompileLayer (getCompileLayer l)
  cleanups <- readIORef (getCleanups l)
  sequence_ cleanups