{-# LANGUAGE MultiParamTypeClasses #-}
module LLVM.Internal.OrcJIT where

import LLVM.Prelude

import Control.Exception
import Control.Monad.AnyCont
import Control.Monad.IO.Class
import Data.Bits
import Data.ByteString (packCString, useAsCString)
import Data.IORef
import Foreign.C.String
import Foreign.Ptr

import LLVM.Internal.Coding
import LLVM.Internal.Target
import qualified LLVM.Internal.FFI.DataLayout as FFI
import qualified LLVM.Internal.FFI.LLVMCTypes as FFI
import qualified LLVM.Internal.FFI.OrcJIT as FFI
import qualified LLVM.Internal.FFI.Target as FFI

-- | A symbol name that has already been mangled for the target, suitable
-- for passing to 'findSymbol'. Obtain one with 'mangleSymbol'.
newtype MangledSymbol = MangledSymbol ByteString
  deriving (Eq, Ord, Show)

-- | Expose the bytes of a 'MangledSymbol' as a temporary NUL-terminated C
-- string ('useAsCString') that stays valid for the rest of the enclosing
-- 'AnyContT' computation.
instance EncodeM (AnyContT IO) MangledSymbol CString where
  encodeM (MangledSymbol rawName) = anyContToM (useAsCString rawName)

-- | Copy a NUL-terminated C string into a 'MangledSymbol' ('packCString'
-- makes an independent copy of the bytes).
instance MonadIO m => DecodeM m MangledSymbol CString where
  decodeM cName = liftIO (fmap MangledSymbol (packCString cName))

-- | Handle wrapping the raw FFI pointer to an execution session. Create
-- one with 'createExecutionSession' or 'withExecutionSession'.
newtype ExecutionSession =
  ExecutionSession (Ptr FFI.ExecutionSession)

-- | Flags attached to a 'JITSymbol'.
--
-- Unlike the C++ interface, this record has no HasError flag. Instead,
-- decoding a 'JITSymbol' produces an 'Either', and its 'Left' case
-- ('JITSymbolError') is used when that flag is set.
data JITSymbolFlags = JITSymbolFlags
  { jitSymbolWeak :: !Bool
    -- ^ Is this a weak symbol?
  , jitSymbolCommon :: !Bool
    -- ^ Is this a common symbol?
  , jitSymbolAbsolute :: !Bool
    -- ^ Is this an absolute symbol? This will cause LLVM to use absolute
    -- relocations for the symbol even in position independent code.
  , jitSymbolExported :: !Bool
    -- ^ Is this symbol exported?
  }
  deriving (Eq, Ord, Show)

-- | Flags with every field set to 'False'.
defaultJITSymbolFlags :: JITSymbolFlags
defaultJITSymbolFlags =
  JITSymbolFlags
    { jitSymbolWeak = False
    , jitSymbolCommon = False
    , jitSymbolAbsolute = False
    , jitSymbolExported = False
    }

-- | A symbol resolved by the JIT: its address together with its flags.
data JITSymbol = JITSymbol
  { jitSymbolAddress :: !WordPtr
    -- ^ The address of the symbol. If you've looked up a function, you
    -- need to cast this to a 'FunPtr'.
  , jitSymbolFlags :: !JITSymbolFlags
    -- ^ The flags of this symbol.
  }
  deriving (Eq, Ord, Show)

-- | Failure result of a symbol lookup, carrying the error message text.
data JITSymbolError = JITSymbolError ShortByteString
  deriving (Eq, Show)

-- | Specifies how external symbols in a module added to a 'CompileLayer'
-- should be resolved. The function receives the mangled name and returns
-- either the resolved 'JITSymbol' or a 'JITSymbolError'.
newtype SymbolResolver = SymbolResolver (MangledSymbol -> IO (Either JITSymbolError JITSymbol))

-- | Create a 'FFI.SymbolResolver' that can be used with the JIT and pass it
-- to @f@.
--
-- The Haskell callback is turned into a C function pointer with 'encodeM'.
-- This is the same conversion the 'SymbolResolver' 'EncodeM' instance uses,
-- so the callback-marshalling logic exists in one place.
--
-- Both the function pointer and the resolver are released (resolver first)
-- when @f@ returns or throws.
withSymbolResolver :: ExecutionSession -> SymbolResolver -> (Ptr FFI.SymbolResolver -> IO a) -> IO a
withSymbolResolver (ExecutionSession es) (SymbolResolver resolverFn) f =
  bracket (encodeM resolverFn) freeHaskellFunPtr $ \resolverPtr ->
    bracket (FFI.createLambdaResolver es resolverPtr) FFI.disposeSymbolResolver f

-- | Pack a 'JITSymbolFlags' record into the raw FFI bit set. The result is
-- the bitwise OR of the FFI bit for every field that is 'True', and 0 when
-- no field is set.
instance Monad m => EncodeM m JITSymbolFlags FFI.JITSymbolFlags where
  encodeM flags = return (foldr addBit 0 fieldBits)
    where
      -- OR in the field's FFI bit only when the record field is set.
      addBit (isSet, flagBit) acc
        | isSet flags = flagBit .|. acc
        | otherwise = acc
      -- Each record accessor paired with the FFI bit it corresponds to.
      fieldBits =
        [ (jitSymbolWeak, FFI.jitSymbolFlagsWeak)
        , (jitSymbolCommon, FFI.jitSymbolFlagsCommon)
        , (jitSymbolAbsolute, FFI.jitSymbolFlagsAbsolute)
        , (jitSymbolExported, FFI.jitSymbolFlagsExported)
        ]

-- | Unpack the raw FFI bit set into a 'JITSymbolFlags' record. Each field
-- is 'True' exactly when its bit is set in the input. Only the four bits
-- below are inspected; any other bit (such as HasError) is ignored.
instance Monad m => DecodeM m JITSymbolFlags FFI.JITSymbolFlags where
  decodeM raw =
    return
      (JITSymbolFlags
        { jitSymbolWeak = isSet FFI.jitSymbolFlagsWeak
        , jitSymbolCommon = isSet FFI.jitSymbolFlagsCommon
        , jitSymbolAbsolute = isSet FFI.jitSymbolFlagsAbsolute
        , jitSymbolExported = isSet FFI.jitSymbolFlagsExported
        })
    where
      isSet flagBit = (flagBit .&. raw) /= 0

-- | Turn a resolver result into an action that writes it into an
-- 'FFI.JITSymbol'.
--
-- * 'Left': writes address 0 with only the HasError flag set. The error
--   message carried by the 'JITSymbolError' is discarded.
-- * 'Right': writes the symbol's address and its flags encoded as the FFI
--   bit set.
instance MonadIO m => EncodeM m (Either JITSymbolError JITSymbol) (Ptr FFI.JITSymbol -> IO ()) where
  encodeM (Left (JITSymbolError _)) =
    pure (\target -> FFI.setJITSymbol target (FFI.TargetAddress 0) FFI.jitSymbolFlagsHasError)
  encodeM (Right (JITSymbol addr flags)) =
    pure $ \target ->
      encodeM flags >>= FFI.setJITSymbol target (FFI.TargetAddress (fromIntegral addr))

-- | Read an 'FFI.JITSymbol' back into Haskell.
--
-- The result is 'Left' when the symbol's address is 0 or its HasError bit
-- is set; the message then comes from 'FFI.getErrorMsg'. Otherwise the
-- result is 'Right' with the address and the decoded flags.
--
-- NOTE(review): the slot @errMsgOut@ handed to 'FFI.getAddress' is never
-- read afterwards. Its element type (@OwnerTransfered CString@) suggests
-- any message written there is owned by the caller, so it may leak.
-- Confirm against the C side.
instance (MonadIO m, MonadAnyCont IO m) => DecodeM m (Either JITSymbolError JITSymbol) (Ptr FFI.JITSymbol) where
  decodeM jitSymbol = do
    errMsgOut <- alloca
    FFI.TargetAddress addr <- liftIO (FFI.getAddress jitSymbol errMsgOut)
    rawFlags <- liftIO (FFI.getFlags jitSymbol)
    let hasErrorBit = (rawFlags .&. FFI.jitSymbolFlagsHasError) /= 0
    if addr /= 0 && not hasErrorBit
      then Right . JITSymbol (fromIntegral addr) <$> decodeM rawFlags
      else do
        msg <- decodeM =<< liftIO (FFI.getErrorMsg jitSymbol)
        pure (Left (JITSymbolError msg))

-- | Turn a 'SymbolResolver' into an action that, given a cleanup list and
-- an execution session, creates an 'FFI.SymbolResolver'.
--
-- The callback 'FunPtr' and the resolver are both registered in the
-- cleanup list. The resolver is registered second, so it ends up in front
-- of the 'FunPtr' in the list.
instance MonadIO m =>
  EncodeM m SymbolResolver (IORef [IO ()] -> Ptr FFI.ExecutionSession -> IO (Ptr FFI.SymbolResolver)) where
  encodeM (SymbolResolver lookupFn) = pure register
    where
      register cleanups session = do
        callbackPtr <- allocFunPtr cleanups (encodeM lookupFn)
        allocWithCleanup cleanups
          (FFI.createLambdaResolver session callbackPtr)
          FFI.disposeSymbolResolver

-- | Wrap a Haskell resolver callback as a C function pointer of type
-- 'FFI.SymbolResolverFn'.
--
-- On each call, the C symbol name is decoded to a 'MangledSymbol' and the
-- callback is run on it. The result is then written into the output
-- 'FFI.JITSymbol'. Callers in this module release the returned 'FunPtr'
-- with 'freeHaskellFunPtr'.
instance MonadIO m => EncodeM m (MangledSymbol -> IO (Either JITSymbolError JITSymbol)) (FunPtr FFI.SymbolResolverFn) where
  encodeM callback = liftIO (FFI.wrapSymbolResolverFn trampoline)
    where
      trampoline :: FFI.SymbolResolverFn
      trampoline rawName out = do
        name <- decodeM rawName
        resolved <- callback name
        writeResult <- encodeM resolved
        writeResult out

-- | Allocate a resource with @alloc@, register @free@ (applied to it) in
-- @cleanups@, and return the resource.
--
-- The cleanup is prepended, so resources registered later appear earlier
-- in the list.
--
-- The whole sequence, including @alloc@, runs with asynchronous exceptions
-- masked, as the acquire step of 'bracket' does. If @alloc@ ran unmasked,
-- an asynchronous exception could arrive after the resource was created
-- but before it was registered, and the resource would leak.
-- Interruptible operations inside @alloc@ can still receive exceptions.
--
-- 'modifyIORef'' keeps the list evaluated instead of accumulating thunks
-- across many registrations.
allocWithCleanup :: IORef [IO ()] -> IO a -> (a -> IO ()) -> IO a
allocWithCleanup cleanups alloc free = mask_ $ do
  a <- alloc
  modifyIORef' cleanups (free a :)
  pure a
a

-- | Create a function pointer with @acquire@ and register
-- 'freeHaskellFunPtr' for it in @registry@. This is 'allocWithCleanup'
-- specialised to Haskell-allocated 'FunPtr's.
allocFunPtr :: IORef [IO ()] -> IO (FunPtr a) -> IO (FunPtr a)
allocFunPtr registry acquire = allocWithCleanup registry acquire freeHaskellFunPtr

-- | Create the data layout of a target machine.
--
-- The layout's disposal is pushed onto @cleanups@. Separately,
-- 'bracketOnError' disposes it immediately if the rest of the enclosing
-- 'AnyContT' computation throws.
--
-- NOTE(review): in that error case the layout is disposed here but its
-- entry remains in @cleanups@, so callers presumably must not run
-- @cleanups@ after a failed construction. Verify against the callers.
createRegisteredDataLayout :: (MonadAnyCont IO m) => TargetMachine -> IORef [IO ()] -> m (Ptr FFI.DataLayout)
createRegisteredDataLayout (TargetMachine tm) cleanups =
  anyContToM (bracketOnError acquire FFI.disposeDataLayout)
  where
    acquire = do
      layout <- FFI.createTargetDataLayout tm
      modifyIORef' cleanups (FFI.disposeDataLayout layout :)
      return layout

-- | Create a new 'ExecutionSession'. Release it with
-- 'disposeExecutionSession', or use 'withExecutionSession' instead.
createExecutionSession :: IO ExecutionSession
createExecutionSession = fmap ExecutionSession FFI.createExecutionSession

-- | Dispose of an 'ExecutionSession'. Call this once the session is no
-- longer needed.
disposeExecutionSession :: ExecutionSession -> IO ()
disposeExecutionSession session =
  case session of
    ExecutionSession sessionPtr -> FFI.disposeExecutionSession sessionPtr

-- | Run an action with a fresh 'ExecutionSession'. The session is disposed
-- when the action returns or throws ('bracket' around
-- 'createExecutionSession' and 'disposeExecutionSession').
withExecutionSession :: (ExecutionSession -> IO a) -> IO a
withExecutionSession action =
  bracket createExecutionSession disposeExecutionSession action

-- | Allocate a module key for a new module to add to the JIT. Return it
-- with 'releaseModuleKey' when done.
allocateModuleKey :: ExecutionSession -> IO FFI.ModuleKey
allocateModuleKey (ExecutionSession sessionPtr) = FFI.allocateVModule sessionPtr

-- | Hand a module key back to its 'ExecutionSession' so that it can be
-- reused.
releaseModuleKey :: ExecutionSession -> FFI.ModuleKey -> IO ()
releaseModuleKey (ExecutionSession sessionPtr) = FFI.releaseVModule sessionPtr

-- | Run an action with a freshly allocated module key. The key is released
-- when the action returns or throws ('bracket' around 'allocateModuleKey'
-- and 'releaseModuleKey').
withModuleKey :: ExecutionSession -> (FFI.ModuleKey -> IO a) -> IO a
withModuleKey session = bracket acquire release
  where
    acquire = allocateModuleKey session
    release = releaseModuleKey session