{-# LANGUAGE TemplateHaskell, QuasiQuotes, MultiParamTypeClasses, ScopedTypeVariables #-}
-- | Encoding of 'A.Constant' AST values into @Ptr FFI.Constant@ and decoding
-- of @Ptr FFI.Constant@ back into 'A.Constant'.
--
-- A handful of constructors are handled by hand in each direction; the
-- remaining alternatives are generated by Template Haskell splices driven by
-- the tables exported from "LLVM.Internal.InstructionDefs".
module LLVM.Internal.Constant where

import LLVM.Prelude

import qualified Language.Haskell.TH as TH
import qualified Language.Haskell.TH.Quote as TH
import qualified LLVM.Internal.InstructionDefs as ID

import Data.Bits
import Control.Monad.AnyCont
import Control.Monad.Catch
import Control.Monad.IO.Class
import Control.Monad.State (get, gets, modify, evalState)

import qualified Data.Map as Map
import Foreign.Ptr
import Foreign.Storable (Storable, sizeOf)

import qualified LLVM.Internal.FFI.Constant as FFI
import qualified LLVM.Internal.FFI.GlobalValue as FFI
import qualified LLVM.Internal.FFI.Instruction as FFI
import qualified LLVM.Internal.FFI.LLVMCTypes as FFI
import LLVM.Internal.FFI.LLVMCTypes (valueSubclassIdP)
import qualified LLVM.Internal.FFI.PtrHierarchy as FFI
import qualified LLVM.Internal.FFI.Type as FFI
import qualified LLVM.Internal.FFI.User as FFI
import qualified LLVM.Internal.FFI.Value as FFI
import qualified LLVM.Internal.FFI.BinaryOperator as FFI

import qualified LLVM.AST.Constant as A (Constant)
import qualified LLVM.AST.Constant as A.C hiding (Constant)
import qualified LLVM.AST.Type as A
import qualified LLVM.AST.IntegerPredicate as A (IntegerPredicate)
import qualified LLVM.AST.FloatingPointPredicate as A (FloatingPointPredicate)
import qualified LLVM.AST.Float as A.F

import LLVM.Exception
import LLVM.Internal.Coding
import LLVM.Internal.Context
import LLVM.Internal.DecodeAST
import LLVM.Internal.EncodeAST
import LLVM.Internal.FloatingPointPredicate ()
import LLVM.Internal.IntegerPredicate ()
import LLVM.Internal.Type ()
import LLVM.Internal.Value

-- | Allocate, via 'allocaArray', an array of @a@ with enough elements to hold
-- @nBits@ bits: @((nBits - 1) `div` (8 * sizeOf a)) + 1@ elements, i.e. the bit
-- count divided by the element width in bits, rounded up.
--
-- The element type is only used for its 'sizeOf'; @undefined :: a@ refers to
-- the @a@ bound by the explicit @forall@ in the signature.
--
-- NOTE(review): @nBits@ is a 'Word32', so @nBits - 1@ looks like it would wrap
-- around for @nBits == 0@. The call sites in this module pass either a
-- multiple of 'sizeOf' or one of the literals 16..128 - confirm that no caller
-- can pass zero.
allocaWords :: forall a m . (Storable a, MonadAnyCont IO m, Monad m, MonadIO m) => Word32 -> m (Ptr a)
allocaWords nBits = do
  allocaArray (((nBits-1) `div` (8*(fromIntegral (sizeOf (undefined :: a))))) + 1)

-- | Abort with an \"llvm-hs internal error\" message naming the encoding
-- (@name@) and showing the offending value (@attr@). The result type is fully
-- polymorphic because the function never returns normally.
inconsistentCases :: Show a => String -> a -> b
inconsistentCases name attr =
  error $ "llvm-hs internal error: cases inconstistent in " ++ name ++ " encoding for " ++ show attr

-- | Encode an AST constant as an FFI constant pointer.
--
-- 'A.C.Int', 'A.C.Float', 'A.C.GlobalReference', 'A.C.BlockAddress',
-- 'A.C.Struct' and 'A.C.TokenNone' are written out by hand; every other
-- constructor is dispatched by the Template Haskell splice in the final
-- alternative.
instance EncodeM EncodeAST A.Constant (Ptr FFI.Constant) where
  encodeM c = scopeAnyCont $ case c of
    A.C.Int { A.C.integerBits = bits, A.C.integerValue = v } -> do
      t <- encodeM (A.IntegerType bits)
      -- Split the value into 64-bit words: word @w@ carries bits
      -- [64*w .. 64*w+63] of @v@, for w from 0 up to (bits-1) `div` 64.
      words <- encodeM
        [ fromIntegral ((v `shiftR` (w*64)) .&. 0xffffffffffffffff) :: Word64
        | w <- [0 .. ((fromIntegral bits-1) `div` 64)] ]
      liftIO $ FFI.constantIntOfArbitraryPrecision t words
    A.C.Float { A.C.floatValue = v } -> do
      Context context <- gets encodeStateContext
      -- poke1: store a single value at the start of a freshly allocated
      -- buffer and return the buffer together with its size in bits.
      let poke1 f = do
            let nBits = fromIntegral $ 8*(sizeOf f)
            words <- allocaWords nBits
            poke (castPtr words) f
            return (nBits, words)
          -- poke2: store a value given as two parts. The second argument
          -- (@fl@) goes to byte offset 0 and the first (@fh@) directly after
          -- it, at byte offset @sizeOf fl@.
          poke2 fh fl = do
            let nBits = fromIntegral $ 8*(sizeOf fh) + 8*(sizeOf fl)
            words <- allocaWords nBits
            pokeByteOff (castPtr words) 0 fl
            pokeByteOff (castPtr words) (sizeOf fl) fh
            return (nBits, words)
      (nBits, words) <- case v of
        A.F.Half f -> poke1 f
        A.F.Single f -> poke1 f
        A.F.Double f -> poke1 f
        A.F.X86_FP80 high low -> poke2 high low
        A.F.Quadruple high low -> poke2 high low
        A.F.PPC_FP128 high low -> poke2 high low
      -- Pick the FFI float-semantics value that corresponds to the
      -- constructor of @v@.
      let fpSem = case v of
            A.F.Half _ -> FFI.floatSemanticsIEEEhalf
            A.F.Single _ -> FFI.floatSemanticsIEEEsingle
            A.F.Double _ -> FFI.floatSemanticsIEEEdouble
            A.F.Quadruple _ _ -> FFI.floatSemanticsIEEEquad
            A.F.X86_FP80 _ _ -> FFI.floatSemanticsx87DoubleExtended
            A.F.PPC_FP128 _ _ -> FFI.floatSemanticsPPCDoubleDouble
      -- Shadow the Haskell-side bit count with its encoded form for the call.
      nBits <- encodeM nBits
      liftIO $ FFI.constantFloatOfArbitraryPrecision context nBits words fpSem
    A.C.GlobalReference ty n -> do
      ref <- FFI.upCast <$> referGlobal n
      -- Decode the type of the referenced global and insist that it equals
      -- the type recorded in the AST node.
      ty' <- (liftIO . runDecodeAST . typeOf) ref
      if ty /= ty'
        then throwM (EncodeException ("The serialized GlobalReference has type " ++ show ty' ++ " but should have type " ++ show ty))
        else return ref
    A.C.BlockAddress f b -> do
      f' <- referGlobal f
      b' <- getBlockForAddress f b
      liftIO $ FFI.blockAddress (FFI.upCast f') b'
    A.C.Struct nm p ms -> do
      -- Shadow the packed flag and the members with their encoded forms.
      p <- encodeM p
      ms <- encodeM ms
      case nm of
        Nothing -> do
          Context context <- gets encodeStateContext
          liftIO $ FFI.constStructInContext context ms p
        Just nm -> do
          -- For a named struct the type is looked up by name; the encoded
          -- packed flag @p@ is not passed to the FFI call in this branch.
          t <- lookupNamedType nm
          liftIO $ FFI.constNamedStruct t ms
    A.C.TokenNone -> do
      Context context <- gets encodeStateContext
      liftIO $ FFI.getConstTokenNone context
    -- Every remaining constructor: a compile-time generated case expression
    -- over @o@.
    o -> $(do
      let constExprInfo = ID.outerJoin ID.astConstantRecs (ID.innerJoin ID.astInstructionRecs ID.instructionDefs)
      TH.caseE [| o |] $
        -- The constructors listed here are matched by the earlier
        -- alternatives of the enclosing case before @o@ is tried, so they map
        -- to 'inconsistentCases' in the generated expression.
        map (\p -> TH.match p (TH.normalB [|inconsistentCases "Constant" o|]) [])
            [[p|A.C.Int{}|],
             [p|A.C.Float{}|],
             [p|A.C.Struct{}|],
             [p|A.C.BlockAddress{}|],
             [p|A.C.GlobalReference{}|],
             [p|A.C.TokenNone{}|]]
        ++
        -- List-monad block: yields one generated alternative per entry of
        -- @constExprInfo@ that has a record constructor. Entries for which
        -- the pattern below fails, or for which @core@ is the empty list,
        -- contribute no alternative.
        (do
          (name, (Just (TH.RecC n fs), instrInfo)) <- Map.toList constExprInfo
          let -- Local variable names: the unqualified base names of the
              -- record fields.
              fns = [ TH.mkName . TH.nameBase $ fn | (fn, _, _) <- fs ]
              -- Reference, by name, to @FFI.constant<suffix>@.
              coreCall n = TH.dyn $ "FFI.constant" ++ n
              -- Statements of the generated do-block: rebind every field to
              -- its encoded form, then apply @c@ to all fields in
              -- declaration order inside 'liftIO'.
              buildBody c = [ TH.bindS (TH.varP fn) [| encodeM $(TH.varE fn) |] | fn <- fns ]
                            ++ [ TH.noBindS [| liftIO $(foldl TH.appE c (map TH.varE fns)) |] ]
              -- True when some field has type 'Bool'.
              hasFlags = any (== ''Bool) [ h | (_, _, TH.ConT h) <- fs ]
          core <- case instrInfo of
            Just (_, iDef) -> do
              let opcode = TH.dataToExpQ (const Nothing) (ID.cppOpcode iDef)
              case ID.instructionKind iDef of
                -- Binary instructions with Bool fields use their own
                -- @FFI.constant<name>@; those without go through
                -- @FFI.constantBinaryOperator@ applied to the opcode.
                ID.Binary | hasFlags -> return $ coreCall name
                          | True -> return [| $(coreCall "BinaryOperator") $(opcode) |]
                ID.Cast -> return [| $(coreCall "Cast") $(opcode) |]
                _ -> return $ coreCall name
            -- Constructors with no instruction definition are only generated
            -- for the names listed; the empty list skips all others.
            Nothing -> if (name `elem` ["Vector", "Null", "Array", "Undef"])
                         then return $ coreCall name
                         else []
          return $ TH.match
            (TH.recP n [(fn,) <$> (TH.varP . TH.mkName . TH.nameBase $ fn) | (fn, _, _) <- fs])
            (TH.normalB (TH.doE (buildBody core)))
            [])
     )

-- | Decode an FFI constant pointer into an AST constant, dispatching on the
-- value-subclass id reported by 'FFI.getValueSubclassId'. Constant
-- expressions are further dispatched on their opcode by a Template Haskell
-- splice.
instance DecodeM DecodeAST A.Constant (Ptr FFI.Constant) where
  decodeM c = scopeAnyCont $ do
    -- The same pointer viewed as a Value and as a User.
    let v = FFI.upCast c :: Ptr FFI.Value
        u = FFI.upCast c :: Ptr FFI.User
    ft <- liftIO (FFI.typeOf v)
    t <- decodeM ft
    valueSubclassId <- liftIO $ FFI.getValueSubclassId v
    nOps <- liftIO $ FFI.getNumOperands u
    let -- A 'A.C.GlobalReference' built from the decoded type and the name
        -- of the global behind @v@.
        globalRef = return A.C.GlobalReference
                    `ap` (return t)
                    `ap` (getGlobalName =<< liftIO (FFI.isAGlobalValue v))
        -- Fetch the operand at the given index and decode it.
        op = decodeM <=< liftIO . FFI.getConstantOperand c
        -- All operands, indices 0 .. nOps-1.
        -- NOTE(review): if the type returned by FFI.getNumOperands is
        -- unsigned, @nOps-1@ would wrap when nOps is 0 - TODO confirm the
        -- type and whether a zero-operand constant can reach the callers.
        getConstantOperands = mapM op [0..nOps-1]
        -- All elements of a constant-data sequence, decoded one by one.
        getConstantData = do
          let -- Element count taken from the decoded type. Array lengths are
              -- accepted only when they do not exceed maxBound :: Word32;
              -- larger arrays fall through to the same error as non-sequence
              -- types.
              nElements =
                case t of
                  A.VectorType n _ -> n
                  A.ArrayType n _ | n <= (fromIntegral (maxBound :: Word32)) -> fromIntegral n
                  _ -> error "getConstantData can only be applied to vectors and arrays"
          forM [0..nElements-1] $ do
            decodeM <=< liftIO . FFI.getConstantDataSequentialElementAsConstant c . fromIntegral

    case valueSubclassId of
      [valueSubclassIdP|Function|] -> globalRef
      [valueSubclassIdP|GlobalAlias|] -> globalRef
      [valueSubclassIdP|GlobalVariable|] -> globalRef
      [valueSubclassIdP|ConstantInt|] -> do
        -- @np@ is passed to FFI.getConstantIntWords as an out-parameter and
        -- read back as @n@; (n, wsp) is then decoded into a list of words.
        np <- alloca
        wsp <- liftIO $ FFI.getConstantIntWords c np
        n <- peek np
        words <- decodeM (n, wsp)
        -- Reassemble the integer: the first word of the list ends up in the
        -- lowest 64 bits, each following word 64 bits higher.
        return $ A.C.Int (A.typeBits t) (foldr (\b a -> (a `shiftL` 64) .|. fromIntegral b) 0 (words :: [Word64]))
      [valueSubclassIdP|ConstantFP|] -> do
        -- Partial pattern: fails if @t@ is not a 'A.FloatingPointType'.
        let A.FloatingPointType fpt = t
        let nBits = case fpt of
                      A.HalfFP -> 16
                      A.FloatFP -> 32
                      A.DoubleFP -> 64
                      A.FP128FP -> 128
                      A.X86_FP80FP -> 80
                      A.PPC_FP128FP -> 128
        ws <- allocaWords nBits
        liftIO $ FFI.getConstantFloatWords c ws
        -- Single-part formats are read from the start of the buffer; for
        -- two-part formats the first constructor field is read from byte
        -- offset 8 and the second from byte offset 0.
        A.C.Float <$> (
          case fpt of
            A.HalfFP -> A.F.Half <$> peek (castPtr ws)
            A.FloatFP -> A.F.Single <$> peek (castPtr ws)
            A.DoubleFP -> A.F.Double <$> peek (castPtr ws)
            A.FP128FP -> A.F.Quadruple <$> peekByteOff (castPtr ws) 8 <*> peekByteOff (castPtr ws) 0
            A.X86_FP80FP -> A.F.X86_FP80 <$> peekByteOff (castPtr ws) 8 <*> peekByteOff (castPtr ws) 0
            A.PPC_FP128FP -> A.F.PPC_FP128 <$> peekByteOff (castPtr ws) 8 <*> peekByteOff (castPtr ws) 0
          )
      -- Both null pointers and aggregate zeros decode to 'A.C.Null'.
      [valueSubclassIdP|ConstantPointerNull|] -> return $ A.C.Null t
      [valueSubclassIdP|ConstantAggregateZero|] -> return $ A.C.Null t
      [valueSubclassIdP|UndefValue|] -> return $ A.C.Undef t
      [valueSubclassIdP|BlockAddress|] ->
        return A.C.BlockAddress
          `ap` (getGlobalName =<< do liftIO $ FFI.isAGlobalValue =<< FFI.getBlockAddressFunction c)
          `ap` (getLocalName =<< do liftIO $ FFI.getBlockAddressBlock c)
      [valueSubclassIdP|ConstantStruct|] -> do
        -- The struct name is kept only when the decoded type is a named type
        -- reference; the packed flag is read from the FFI type @ft@.
        return A.C.Struct
          `ap` (return $ case t of A.NamedTypeReference n -> Just n; _ -> Nothing)
          `ap` (decodeM =<< liftIO (FFI.isPackedStruct ft))
          `ap` getConstantOperands
      -- The "Data" variants read their elements through getConstantData, the
      -- plain variants through their operands.
      [valueSubclassIdP|ConstantDataArray|] ->
        return A.C.Array `ap` (return $ A.elementType t) `ap` getConstantData
      [valueSubclassIdP|ConstantArray|] ->
        return A.C.Array `ap` (return $ A.elementType t) `ap` getConstantOperands
      [valueSubclassIdP|ConstantDataVector|] ->
        return A.C.Vector `ap` getConstantData
      [valueSubclassIdP|ConstantVector|] ->
        A.C.Vector <$> getConstantOperands
      [valueSubclassIdP|ConstantExpr|] -> do
        cppOpcode <- liftIO $ FFI.getConstantCPPOpcode c
        -- Compile-time generated case expression over @cppOpcode@: one
        -- alternative per constant constructor present in all three joined
        -- tables, followed by a catch-all that raises a runtime error.
        $(
          TH.caseE [| cppOpcode |] $
            (do
              (_, ((TH.RecC n fs, _), iDef)) <- Map.toList $
                    ID.innerJoin (ID.innerJoin ID.astConstantRecs ID.astInstructionRecs) ID.instructionDefs
              -- apWrapper extends the expression built so far (@o@) with the
              -- decoder for one record field, chosen by the field's type and,
              -- for some types, by its name. It runs in a State monad whose
              -- state is the index of the next operand to consume.
              let apWrapper o (fn, _, ct) = do
                    a <- case ct of
                           TH.ConT h
                             -- A constant field consumes the next operand.
                             | h == ''A.Constant -> do
                                 operandNumber <- get
                                 modify (+1)
                                 return [| op $(TH.litE . TH.integerL $ operandNumber) |]
                             -- A type field receives the decoded type of the
                             -- constant itself.
                             | h == ''A.Type -> return [| pure t |]
                             | h == ''A.IntegerPredicate ->
                               return [| liftIO $ decodeM =<< FFI.getConstantICmpPredicate c |]
                             | h == ''A.FloatingPointPredicate ->
                               return [| liftIO $ decodeM =<< FFI.getConstantFCmpPredicate c |]
                             -- Bool fields are resolved by field name; an
                             -- unknown name raises an error while the splice
                             -- is being evaluated.
                             | h == ''Bool -> case TH.nameBase fn of
                                                "inBounds" -> return [| liftIO $ decodeM =<< FFI.getInBounds v |]
                                                "exact" -> return [| liftIO $ decodeM =<< FFI.isExact v |]
                                                "nsw" -> return [| liftIO $ decodeM =<< FFI.hasNoSignedWrap v |]
                                                "nuw" -> return [| liftIO $ decodeM =<< FFI.hasNoUnsignedWrap v |]
                                                x -> error $ "constant bool field " ++ show x ++ " not handled yet"
                           TH.AppT TH.ListT (TH.ConT h)
                             -- A [Word32] field is read through
                             -- FFI.getConstantIndices, with @np@ as the
                             -- out-parameter for the count.
                             | h == ''Word32 ->
                                return [|
                                      do
                                        np <- alloca
                                        isp <- liftIO $ FFI.getConstantIndices c np
                                        n <- peek np
                                        decodeM (n, isp)
                                      |]
                             -- A [Constant] field named "indices" takes every
                             -- operand from the current index up to nOps-1;
                             -- the operand counter is not advanced.
                             | h == ''A.Constant && TH.nameBase fn == "indices" -> do
                                operandNumber <- get
                                return [| mapM op [$(TH.litE . TH.integerL $ operandNumber)..nOps-1] |]
                           _ -> error $ "unhandled constant expr field type: " ++ show fn ++ " - " ++ show ct
                    return [| $(o) `ap` $(a) |]
              -- The alternative's pattern is the opcode of @iDef@; its body is
              -- the constructor applied field by field, with the operand
              -- counter starting at 0.
              return $ TH.match
                        (TH.dataToPatQ (const Nothing) (ID.cppOpcode iDef))
                        (TH.normalB (evalState (foldM apWrapper [| return $(TH.conE n) |] fs) 0))
                        [])
            ++ [TH.match TH.wildP (TH.normalB [|error ("Unknown constant opcode: " <> show cppOpcode)|]) []]
         )
      [valueSubclassIdP|ConstantTokenNone|] -> return A.C.TokenNone
      _ -> error $ "unhandled constant valueSubclassId: " ++ show valueSubclassId