{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE OverloadedStrings #-}

-- | Tests for the ORC JIT bindings: eager, IR-transformed and lazy
-- (compile-on-demand) compilation of a small IR module, plus symbol lookup
-- in an object file produced by the system @gcc@.
module LLVM.Test.OrcJIT where

import Test.Tasty
import Test.Tasty.HUnit

import LLVM.Test.Support

import qualified Data.Map.Strict as Map
import Control.Applicative
import Data.ByteString (ByteString)
import Data.Foldable
import Data.IORef
import Data.Word
import Foreign.Ptr
import System.Process (callProcess)
import System.IO.Temp (withSystemTempFile)
import System.IO

import LLVM.Internal.PassManager
import LLVM.Internal.ObjectFile (withObjectFile)
import qualified LLVM.Internal.FFI.PassManager as FFI
import LLVM.Context
import LLVM.Module
import qualified LLVM.Internal.FFI.Module as FFI
import LLVM.OrcJIT
import qualified LLVM.Internal.OrcJIT.CompileLayer as CL
import qualified LLVM.Internal.OrcJIT.LinkingLayer as LL
import LLVM.Target

-- | Minimal LLVM IR module: @main@ calls the external @testFunc@ and returns
-- its result. The tests satisfy @testFunc@ through 'resolver'.
testModule :: ByteString
testModule =
  "; ModuleID = ''\n\
  \source_filename = \"\"\n\
  \\n\
  \declare i32 @testFunc()\n\
  \define i32 @main(i32, i8**) {\n\
  \ %3 = call i32 @testFunc()\n\
  \ ret i32 %3\n\
  \}\n"

-- | Parse 'testModule' inside a fresh 'Context' and hand the resulting
-- module to the continuation.
withTestModule :: (Module -> IO a) -> IO a
withTestModule f =
  withContext $ \context -> withModuleFromLLVMAssembly' context testModule f

-- | Haskell implementation of the external @testFunc@; always returns 42.
myTestFuncImpl :: IO Word32
myTestFuncImpl = return 42

-- | Wrap a Haskell action as a C-callable function pointer.
foreign import ccall "wrapper"
  wrapTestFunc :: IO Word32 -> IO (FunPtr (IO Word32))

-- | Call a C function pointer as a nullary function returning a 32-bit word.
foreign import ccall "dynamic"
  mkMain :: FunPtr (IO Word32) -> IO Word32

-- | Call a JIT-compiled @main@ through the address returned by a symbol
-- lookup. No arguments are passed; the @main@ bodies used here never read
-- their parameters.
callJITMain :: WordPtr -> IO Word32
callJITMain = mkMain . castPtrToFunPtr . wordPtrToPtr

-- | Symbol resolver for the JIT'd test module.
--
-- A request for the (mangled) @testFunc@ symbol is answered with a pointer to
-- a fresh C-callable wrapper around 'myTestFuncImpl'; every other symbol is
-- delegated to 'CL.findSymbol' on the given layer.
resolver
  :: CompileLayer l
  => MangledSymbol -- ^ mangled name of @testFunc@
  -> l             -- ^ layer used to resolve every other symbol
  -> MangledSymbol -- ^ symbol being resolved
  -> IO (Either JITSymbolError JITSymbol)
resolver testFunc compileLayer symbol
  | symbol /= testFunc = CL.findSymbol compileLayer symbol True
  | otherwise = do
      -- NOTE(review): each lookup of testFunc allocates a new wrapper that is
      -- never released with 'freeHaskellFunPtr'. JIT'd code may call it at
      -- any later point, so freeing is not obviously safe; in a test the
      -- small leak is tolerated.
      funPtr <- wrapTestFunc myTestFuncImpl
      let addr = ptrToWordPtr (castFunPtrToPtr funPtr)
      pure (Right (JITSymbol addr defaultJITSymbolFlags))

-- | Resolver that knows no symbols: prints "nullresolver" to stdout and
-- reports an "unknown symbol" error for whatever it is asked about.
nullResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
nullResolver _ =
  putStrLn "nullresolver" >> return (Left (JITSymbolError "unknown symbol"))

-- | IR transform that runs the curated pass set with @optLevel = Just 2@
-- over the module in place. It records in the 'IORef' whether
-- 'FFI.runPassManager' returned a nonzero (true) result, then returns the
-- same module pointer.
--
-- NOTE(review): LLVM's C API documents LLVMRunPassManager's result as "a
-- pass modified the module" rather than "the passes succeeded"; confirm that
-- is the intent of the flag name.
moduleTransform :: IORef Bool -> Ptr FFI.Module -> IO (Ptr FFI.Module)
moduleTransform passmanagerSuccessful modulePtr =
  withPassManager defaultCuratedPassSetSpec { optLevel = Just 2 } $ \(PassManager pm) -> do
    -- C boolean: any nonzero value is true. Unlike 'toEnum' this conversion
    -- is total, whatever the FFI call returns.
    success <- (/= (0 :: Integer)) . fromIntegral <$> FFI.runPassManager pm modulePtr
    writeIORef passmanagerSuccessful success
    pure modulePtr

-- | All ORC JIT test cases.
tests :: TestTree
tests =
  testGroup "OrcJit"
    [ testCase "eager compilation" $ do
        resolvers <- newIORef Map.empty
        withTestModule $ \llvmMod ->
          withHostTargetMachine $ \tm ->
          withExecutionSession $ \es ->
          withObjectLinkingLayer es (\k -> fmap (Map.! k) (readIORef resolvers)) $ \linkingLayer ->
          withIRCompileLayer linkingLayer tm $ \compileLayer -> do
            testFunc <- mangleSymbol compileLayer "testFunc"
            withModuleKey es $ \k ->
              withSymbolResolver es (SymbolResolver (resolver testFunc compileLayer)) $ \symResolver -> do
                -- The linking layer's callback above looks resolvers up by
                -- key with the partial 'Map.!', so register before adding.
                modifyIORef' resolvers (Map.insert k symResolver)
                withModule compileLayer k llvmMod $ do
                  mainSymbol <- mangleSymbol compileLayer "main"
                  -- Lookup across everything added to the layer.
                  Right (JITSymbol mainFn _) <- CL.findSymbol compileLayer mainSymbol True
                  callJITMain mainFn >>= (@?= 42)
                  -- Lookup restricted to the module added under key k.
                  Right (JITSymbol mainFnIn _) <- CL.findSymbolIn compileLayer k mainSymbol True
                  callJITMain mainFnIn >>= (@?= 42)
                  -- A symbol no module defines must not resolve.
                  unknownSymbol <- mangleSymbol compileLayer "unknownSymbol"
                  unknownSymbolRes <- CL.findSymbol compileLayer unknownSymbol True
                  unknownSymbolRes @?= Left (JITSymbolError mempty)
    , testCase "IRTransformLayer" $ do
        passmanagerSuccessful <- newIORef False
        resolvers <- newIORef Map.empty
        withTestModule $ \llvmMod ->
          withHostTargetMachine $ \tm ->
          withExecutionSession $ \es ->
          withObjectLinkingLayer es (\k -> fmap (Map.! k) (readIORef resolvers)) $ \linkingLayer ->
          withIRCompileLayer linkingLayer tm $ \irLayer ->
          withIRTransformLayer irLayer tm (moduleTransform passmanagerSuccessful) $ \compileLayer ->
          withModuleKey es $ \k -> do
            testFunc <- mangleSymbol compileLayer "testFunc"
            withSymbolResolver es (SymbolResolver (resolver testFunc compileLayer)) $ \symResolver -> do
              modifyIORef' resolvers (Map.insert k symResolver)
              withModule compileLayer k llvmMod $ do
                mainSymbol <- mangleSymbol compileLayer "main"
                Right (JITSymbol mainFn _) <- CL.findSymbol compileLayer mainSymbol True
                callJITMain mainFn >>= (@?= 42)
                readIORef passmanagerSuccessful @? "passmanager failed"
    , testCase "lazy compilation" $ do
        resolvers <- newIORef Map.empty
        let getResolver k = fmap (Map.! k) (readIORef resolvers)
            setResolver k r = modifyIORef' resolvers (Map.insert k r)
        withTestModule $ \llvmMod ->
          withHostTargetMachine $ \tm -> do
            triple <- getTargetMachineTriple tm
            withExecutionSession $ \es ->
              withObjectLinkingLayer es getResolver $ \linkingLayer ->
              withIRCompileLayer linkingLayer tm $ \baseLayer ->
              withIndirectStubsManagerBuilder triple $ \stubsMgr ->
              withJITCompileCallbackManager es triple Nothing $ \callbackMgr ->
              -- NOTE(review): the partition function presumably makes each
              -- requested function its own compilation unit; confirm.
              withCompileOnDemandLayer es baseLayer tm getResolver setResolver (\x -> return [x]) callbackMgr stubsMgr False $ \compileLayer -> do
                testFunc <- mangleSymbol compileLayer "testFunc"
                withModuleKey es $ \k ->
                  withSymbolResolver es (SymbolResolver (resolver testFunc baseLayer)) $ \symResolver -> do
                    setResolver k symResolver
                    withModule compileLayer k llvmMod $ do
                      mainSymbol <- mangleSymbol compileLayer "main"
                      Right (JITSymbol mainFn _) <- CL.findSymbol compileLayer mainSymbol True
                      callJITMain mainFn >>= (@?= 42)
    , testCase "finding symbols in linking layer" $
        withExecutionSession $ \es ->
        withModuleKey es $ \k ->
        withSymbolResolver es (SymbolResolver nullResolver) $ \symResolver ->
        withObjectLinkingLayer es (\_ -> pure symResolver) $ \linkingLayer -> do
          let inputPath = "./test/main_return_38.c"
          withSystemTempFile "main.o" $ \outputPath outputHandle -> do
            -- Release our handle before gcc rewrites the file by path; a
            -- handle still held open can block that on some platforms (e.g.
            -- Windows). hClose on an already-closed handle is a no-op, so the
            -- temp-file cleanup is unaffected.
            hClose outputHandle
            callProcess "gcc" ["-c", inputPath, "-o", outputPath]
            withObjectFile outputPath $ \objFile -> do
              addObjectFile linkingLayer k objFile
              -- Find main symbol by looking into the global linking context.
              Right (JITSymbol mainFn _) <- LL.findSymbol linkingLayer "main" True
              callJITMain mainFn >>= (@?= 38)
              -- Find main symbol specifically in the object added under key k.
              Right (JITSymbol mainFnIn _) <- LL.findSymbolIn linkingLayer k "main" True
              callJITMain mainFnIn >>= (@?= 38)
    ]