module LLVM.Internal.OrcJIT.LinkingLayer where

import LLVM.Prelude

import Control.Exception
import Control.Monad.AnyCont
import Control.Monad.IO.Class
import Data.IORef
import Foreign.Ptr

import LLVM.Internal.OrcJIT
import LLVM.Internal.Coding
import LLVM.Internal.ObjectFile
import qualified LLVM.Internal.FFI.ShortByteString as SBS
import qualified LLVM.Internal.FFI.PtrHierarchy as FFI
import qualified LLVM.Internal.FFI.OrcJIT as FFI
import qualified LLVM.Internal.FFI.OrcJIT.LinkingLayer as FFI

-- | Interface to a linking layer. Once a 'CompileLayer' has turned modules
-- into object code, the resulting object files are handed to a
-- 'LinkingLayer'.
class LinkingLayer l where
  -- | Mutable list of cleanup actions owned by the layer;
  -- 'disposeLinkingLayer' runs them after disposing the native layer.
  getCleanups :: l -> IORef [IO ()]
  -- | Raw FFI handle of the native linking layer.
  getLinkingLayer :: l -> Ptr FFI.LinkingLayer

-- | Dispose of a 'LinkingLayer'.
--
-- The native layer is disposed first. Only afterwards are the actions in
-- 'getCleanups' run, in list order.
disposeLinkingLayer :: LinkingLayer l => l -> IO ()
disposeLinkingLayer l = do
  FFI.disposeLinkingLayer (getLinkingLayer l)
  -- Cleanups run only once the native layer is gone. Presumably they free
  -- resources (e.g. callback FunPtrs) the layer could still reach before
  -- disposal — TODO confirm.
  sequence_ =<< readIORef (getCleanups l)

-- | Add an object file to the 'LinkingLayer' under module key @k@.
--
-- NOTE(review): the error-message out-parameter handed to the FFI call is
-- never read, so any failure reported through it is silently ignored (and
-- any message written there is not freed). Confirm the FFI contract — whether
-- the slot is written on success, and who owns the string — before adding
-- error handling.
addObjectFile :: LinkingLayer l => l -> FFI.ModuleKey -> ObjectFile -> IO ()
addObjectFile layer k (ObjectFile obj) = flip runAnyContT return $ do
  -- Scratch slot for the FFI error-message out-parameter, scoped to this
  -- 'runAnyContT' block.
  errMsg <- alloca
  liftIO $ FFI.addObjectFile (getLinkingLayer layer) k obj errMsg

-- | Bare bones implementation of a 'LinkingLayer'.
data ObjectLinkingLayer = ObjectLinkingLayer
  { linkingLayer :: !(Ptr FFI.ObjectLinkingLayer)
    -- ^ Handle to the native object linking layer.
  , cleanupActions :: !(IORef [IO ()])
    -- ^ Actions run by 'disposeLinkingLayer' after the native layer has
    -- been disposed.
  }

-- | Exposes the native handle upcast to the generic 'FFI.LinkingLayer'
-- pointer type, together with the layer's cleanup list.
instance LinkingLayer ObjectLinkingLayer where
  getLinkingLayer = FFI.upCast . linkingLayer
  getCleanups = cleanupActions

-- | Create a new 'ObjectLinkingLayer'. This should be disposed using
-- 'disposeLinkingLayer' when it is no longer needed.
--
-- @getResolver@ is wrapped as a 'FunPtr' so the native layer can call back
-- into Haskell. The wrapper is allocated via 'allocFunPtr' against the new
-- layer's cleanup list. Presumably this registers the wrapper's release there,
-- so 'disposeLinkingLayer' frees it — TODO confirm against 'allocFunPtr'.
newObjectLinkingLayer
  :: ExecutionSession
  -> (FFI.ModuleKey -> IO (Ptr FFI.SymbolResolver))
  -> IO ObjectLinkingLayer
newObjectLinkingLayer (ExecutionSession es) getResolver = do
  cleanups <- newIORef []
  getResolver' <- allocFunPtr cleanups (FFI.wrapGetSymbolResolver getResolver)
  layerPtr <- FFI.createObjectLinkingLayer es getResolver'
  return (ObjectLinkingLayer layerPtr cleanups)

-- | 'bracket'-style wrapper around 'newObjectLinkingLayer' and
-- 'disposeLinkingLayer': the layer is disposed when the body returns or
-- throws.
withObjectLinkingLayer
  :: ExecutionSession
  -> (FFI.ModuleKey -> IO (Ptr FFI.SymbolResolver))
  -> (ObjectLinkingLayer -> IO a)
  -> IO a
withObjectLinkingLayer es resolver =
  bracket (newObjectLinkingLayer es resolver) disposeLinkingLayer

-- | @'findSymbol' layer symbol exportedSymbolsOnly@ searches for
-- @symbol@ in all modules added to @layer@. If @exportedSymbolsOnly@
-- is 'True' only exported symbols are searched.
--
-- The native symbol returned by the lookup is decoded and then released
-- with 'FFI.disposeSymbol'.
findSymbol
  :: LinkingLayer l
  => l -> ShortByteString -> Bool -> IO (Either JITSymbolError JITSymbol)
findSymbol layer symbol exportedSymbolsOnly =
  SBS.useAsCString symbol $ \symbol' -> flip runAnyContT return $ do
    exportedSymbolsOnly' <- encodeM exportedSymbolsOnly
    -- The bracket's release runs when the rest of this block finishes,
    -- i.e. after 'decodeM' has read the native symbol.
    symbolPtr <- anyContToM $ bracket
      (FFI.findSymbol (getLinkingLayer layer) symbol' exportedSymbolsOnly')
      FFI.disposeSymbol
    decodeM symbolPtr

-- | @'findSymbolIn' layer k symbol exportedSymbolsOnly@ searches for
-- @symbol@ in the context of the module represented by the key @k@. If
-- @exportedSymbolsOnly@ is 'True' only exported symbols are searched.
--
-- The native symbol returned by the lookup is decoded and then released
-- with 'FFI.disposeSymbol'.
findSymbolIn
  :: LinkingLayer l
  => l
  -> FFI.ModuleKey
  -> ShortByteString
  -> Bool
  -> IO (Either JITSymbolError JITSymbol)
findSymbolIn layer k symbol exportedSymbolsOnly =
  SBS.useAsCString symbol $ \symbol' -> flip runAnyContT return $ do
    exportedSymbolsOnly' <- encodeM exportedSymbolsOnly
    -- The bracket's release runs when the rest of this block finishes,
    -- i.e. after 'decodeM' has read the native symbol.
    symbolPtr <- anyContToM $ bracket
      (FFI.findSymbolIn (getLinkingLayer layer) k symbol' exportedSymbolsOnly')
      FFI.disposeSymbol
    decodeM symbolPtr