module LLVM.Internal.OrcJIT.IRTransformLayer where

import LLVM.Prelude

import Control.Exception
import Control.Monad.AnyCont
import Control.Monad.IO.Class
import Data.IORef
import Foreign.Ptr

import qualified LLVM.Internal.FFI.DataLayout as FFI
import qualified LLVM.Internal.FFI.Module as FFI
import qualified LLVM.Internal.FFI.OrcJIT.IRTransformLayer as FFI
import qualified LLVM.Internal.FFI.PtrHierarchy as FFI
import LLVM.Internal.OrcJIT
import LLVM.Internal.OrcJIT.CompileLayer
import LLVM.Internal.Target

-- | 'IRTransformLayer' allows transforming modules before handing off
-- compilation to the underlying 'CompileLayer'.
--
-- The @baseLayer@ type parameter does not appear in any field. It only
-- records, at the type level, which kind of layer this one was built on top
-- of.
data IRTransformLayer baseLayer = IRTransformLayer
  { compileLayer :: !(Ptr FFI.IRTransformLayer)
    -- ^ Handle to the native IR transform layer.
  , dataLayout :: !(Ptr FFI.DataLayout)
    -- ^ Data layout created from the 'TargetMachine' the layer was built with.
  , cleanupActions :: !(IORef [IO ()])
    -- ^ Cleanup actions registered while the layer was being built; exposed
    -- through 'getCleanups'.
  }
  deriving (Eq)

-- | An 'IRTransformLayer' can itself be used as a 'CompileLayer'.
--
-- The native handle is upcast to the generic compile-layer pointer. The data
-- layout and cleanup list stored in the record are returned unchanged.
instance CompileLayer (IRTransformLayer l) where
  getCompileLayer (IRTransformLayer nativeLayer _ _) = FFI.upCast nativeLayer
  getDataLayout (IRTransformLayer _ layout _) = layout
  getCleanups (IRTransformLayer _ _ pending) = pending

-- | Create a new 'IRTransformLayer' stacked on top of an existing
-- 'CompileLayer'.
--
-- Modules are passed through the supplied transformation before compilation
-- is handed off to the base layer.
--
-- When the layer is no longer needed, it should be disposed using
-- 'disposeCompileLayer'.
newIRTransformLayer
  :: CompileLayer l
  => l
  -> TargetMachine
  -> (Ptr FFI.Module -> IO (Ptr FFI.Module)) {- ^ module transformation -}
  -> IO (IRTransformLayer l)
newIRTransformLayer baseLayer tm moduleTransform =
  -- 'runAnyContT' has a nested @forall r.@ in its result type.
  -- @flip runAnyContT return@ therefore only typechecks when GHC performs
  -- deep instantiation (DeepSubsumption), which is absent in GHC 9.0-9.2.3
  -- and under GHC2021. Applying it to both arguments directly works with
  -- every compiler.
  runAnyContT acquire return
  where
    acquire = do
      cleanups <- liftIO (newIORef [])
      -- Presumably registers its own disposal in @cleanups@ (TODO confirm in
      -- 'createRegisteredDataLayout').
      dl <- createRegisteredDataLayout tm cleanups
      -- Wrap the Haskell callback as a C function pointer. The rest of this
      -- computation runs inside 'bracketOnError'. If a later step throws, the
      -- pointer is freed again. On success it stays allocated; presumably
      -- 'allocFunPtr' queued its release in @cleanups@ (TODO confirm).
      transformPtr <-
        anyContToM $
          bracketOnError
            (allocFunPtr cleanups (FFI.wrapModuleTransform moduleTransform))
            freeHaskellFunPtr
      nativeLayer <-
        liftIO $
          FFI.createIRTransformLayer (getCompileLayer baseLayer) transformPtr
      return (IRTransformLayer nativeLayer dl cleanups)

-- | Run an action with a freshly created 'IRTransformLayer'.
--
-- The layer is built with 'newIRTransformLayer' and released with
-- 'disposeCompileLayer' once the action finishes, including when the action
-- throws (via 'bracket').
withIRTransformLayer ::
     CompileLayer l
  => l
  -> TargetMachine
  -> (Ptr FFI.Module -> IO (Ptr FFI.Module)) {- ^ module transformation -}
  -> (IRTransformLayer l -> IO a)
  -> IO a
withIRTransformLayer baseLayer tm moduleTransform =
  bracket acquire disposeCompileLayer
  where
    acquire = newIRTransformLayer baseLayer tm moduleTransform