{-
  Copyright (C) 2009 John Millikin

  This program is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  any later version.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program. If not, see <http://www.gnu.org/licenses/>.
-}

{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE BangPatterns #-}

-- | Marshaling of D-Bus values and whole messages to their binary wire
-- format, and unmarshaling of them back.
module DBus.Wire.Internal where

import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy as TL
import qualified Control.Monad.State as ST
import qualified Control.Monad.Error as E
import qualified Data.ByteString.Lazy as L
import qualified Data.Binary.Builder as B
import qualified Data.Binary.Get as G
import Data.Binary.Put (runPut)
import qualified Data.Binary.IEEE754 as IEEE
import Data.Text.Lazy.Encoding (encodeUtf8, decodeUtf8)
import Data.Text.Encoding.Error (UnicodeException)
import qualified Control.Exception as Exc
import System.IO.Unsafe (unsafePerformIO)
import qualified DBus.Constants as C
import qualified DBus.Message.Internal as M
import Data.Bits ((.|.), (.&.))
import qualified Data.Set as Set
import Control.Monad (when, unless)
import Data.Maybe (fromJust, listToMaybe, fromMaybe)
import Data.Word (Word8, Word32, Word64)
import Data.Int (Int16, Int32, Int64)
import qualified DBus.Types as T

------------------------------------------------------------------------------
-- Byte order and alignment
------------------------------------------------------------------------------

-- | Byte order used for multi-byte values in a message.
data Endianness = LittleEndian | BigEndian
    deriving (Show, Eq)

-- | Wire marker byte for a byte order: 108 (ASCII @l@) for little-endian,
-- 66 (ASCII @B@) for big-endian.
encodeEndianness :: Endianness -> Word8
encodeEndianness LittleEndian = 108
encodeEndianness BigEndian    = 66

-- | Inverse of 'encodeEndianness'; 'Nothing' for any other byte.
decodeEndianness :: Word8 -> Maybe Endianness
decodeEndianness 108 = Just LittleEndian
decodeEndianness 66  = Just BigEndian
decodeEndianness _   = Nothing

-- | Alignment, in bytes, that a value of the given type starts on.
alignment :: T.Type -> Word8
alignment T.DBusByte             = 1
alignment T.DBusWord16           = 2
alignment T.DBusWord32           = 4
alignment T.DBusWord64           = 8
alignment T.DBusInt16            = 2
alignment T.DBusInt32            = 4
alignment T.DBusInt64            = 8
alignment T.DBusDouble           = 8
alignment T.DBusBoolean          = 4
alignment T.DBusString           = 4
alignment T.DBusObjectPath       = 4
alignment T.DBusSignature        = 1
alignment (T.DBusArray _)        = 4
alignment (T.DBusDictionary _ _) = 4
alignment (T.DBusStructure _)    = 8
alignment T.DBusVariant          = 1

-- | Number of bytes needed to advance offset @current@ to the next multiple
-- of @count@; zero when @current@ is already aligned.
padding :: Word64 -> Word8 -> Word64
padding current count = required
  where
    count'   = fromIntegral count
    missing  = mod current count'
    required = if missing > 0
        then count' - missing
        else 0

------------------------------------------------------------------------------
-- Marshaling
------------------------------------------------------------------------------

-- | Marshaling state: output byte order, bytes emitted so far, and the
-- count of bytes emitted so far (used to compute alignment padding).
data MarshalState = MarshalState Endianness B.Builder !Word64

-- | Marshaling monad: 'MarshalState' threaded through, failing with
-- 'MarshalError'.
newtype MarshalM a = MarshalM (E.ErrorT MarshalError (ST.State MarshalState) a)
    deriving (Monad, E.MonadError MarshalError, ST.MonadState MarshalState)

type Marshal = MarshalM ()

-- | Run a marshaler from offset 0 in the given byte order, returning the
-- emitted bytes or the first error thrown.
runMarshal :: Marshal -> Endianness -> Either MarshalError L.ByteString
runMarshal (MarshalM m) e = case ST.runState (E.runErrorT m) initialState of
    (Right _, MarshalState _ builder _) -> Right (B.toLazyByteString builder)
    (Left x, _)                         -> Left x
  where
    initialState = MarshalState e B.empty 0

-- | Marshal the value held by a variant, dispatching on the variant's
-- D-Bus type. For 'T.DBusVariant' the inner variant's signature is written
-- before its value.
marshal :: T.Variant -> Marshal
marshal v = marshalType (T.variantType v)
  where
    -- The held value, at whichever Haskell type a branch below requests.
    -- 'fromJust' assumes each branch requests the type matching the
    -- variant's D-Bus type.
    x :: T.Variable a => a
    x = fromJust . T.fromVariant $ v

    marshalType :: T.Type -> Marshal
    marshalType T.DBusByte = append $ L.singleton x
    marshalType T.DBusWord16 = marshalBuilder 2 B.putWord16be B.putWord16le x
    marshalType T.DBusWord32 = marshalBuilder 4 B.putWord32be B.putWord32le x
    marshalType T.DBusWord64 = marshalBuilder 8 B.putWord64be B.putWord64le x
    -- Signed integers are written as their same-width unsigned bit pattern.
    marshalType T.DBusInt16 =
        marshalBuilder 2 B.putWord16be B.putWord16le $ fromIntegral (x :: Int16)
    marshalType T.DBusInt32 =
        marshalBuilder 4 B.putWord32be B.putWord32le $ fromIntegral (x :: Int32)
    marshalType T.DBusInt64 =
        marshalBuilder 8 B.putWord64be B.putWord64le $ fromIntegral (x :: Int64)
    marshalType T.DBusDouble = do
        pad 8
        (MarshalState e _ _) <- ST.get
        let put = case e of
                BigEndian    -> IEEE.putFloat64be
                LittleEndian -> IEEE.putFloat64le
        let bytes = runPut $ put x
        append bytes
    -- Booleans travel as a 32-bit 0 or 1.
    marshalType T.DBusBoolean = marshalWord32 $ if x then 1 else 0
    marshalType T.DBusString = marshalText x
    marshalType T.DBusObjectPath = marshalText . T.strObjectPath $ x
    marshalType T.DBusSignature = marshalSignature x
    marshalType (T.DBusArray _) = marshalArray x
    marshalType (T.DBusDictionary _ _) = marshalArray (T.dictionaryToArray x)
    marshalType (T.DBusStructure _) = do
        let T.Structure vs = x
        pad 8
        mapM_ marshal vs
    marshalType T.DBusVariant = do
        let rawSig = T.typeCode . T.variantType $ x
        sig <- case T.mkSignature rawSig of
            Just x' -> return x'
            Nothing -> E.throwError $ InvalidVariantSignature rawSig
        marshalSignature sig
        marshal x

-- | Append raw bytes to the output and advance the byte count by their length.
append :: L.ByteString -> Marshal
append bytes = do
    (MarshalState e builder count) <- ST.get
    let builder' = B.append builder $ B.fromLazyByteString bytes
        count'   = count + (fromIntegral $ L.length bytes)
    ST.put $ MarshalState e builder' count'

-- | Append zero bytes until the byte count is a multiple of @count@.
pad :: Word8 -> Marshal
pad count = do
    (MarshalState _ _ existing) <- ST.get
    let padding' = fromIntegral $ padding existing count
    append $ L.replicate padding' 0

-- | Align to @size@, then append @x@ using the big- or little-endian builder
-- according to the state's byte order. @size@ is added to the byte count
-- as-is, so it must equal the number of bytes the builders produce.
marshalBuilder :: Word8 -> (a -> B.Builder) -> (a -> B.Builder) -> a -> Marshal
marshalBuilder size be le x = do
    pad size
    (MarshalState e builder count) <- ST.get
    let builder' = B.append builder $ case e of
            BigEndian    -> be x
            LittleEndian -> le x
    let count' = count + (fromIntegral size)
    ST.put $ MarshalState e builder' count'

-- | Reasons marshaling can fail.
data MarshalError
    = MessageTooLong Word64
    | ArrayTooLong Word64
    | InvalidBodySignature Text
    | InvalidVariantSignature Text
    | InvalidText Text
    deriving (Eq)

instance Show MarshalError where
    show (MessageTooLong x) = concat ["Message too long (", show x, " bytes)."]
    show (ArrayTooLong x) = concat ["Array too long (", show x, " bytes)."]
    show (InvalidBodySignature x) = concat ["Invalid body signature: ", show x]
    show (InvalidVariantSignature x) = concat ["Invalid variant signature: ", show x]
    show (InvalidText x) = concat ["Text cannot be marshaled: ", show x]

-- Needed so 'MarshalError' can be the error type of 'E.ErrorT'.
-- NOTE(review): no methods are defined, so the class defaults apply if
-- 'fail' is ever reached in 'MarshalM' -- verify against the mtl version used.
instance E.Error MarshalError

------------------------------------------------------------------------------
-- Unmarshaling
------------------------------------------------------------------------------

-- | Unmarshaling state: input byte order, bytes not yet consumed, and the
-- offset of the next byte (used for alignment and error positions).
data UnmarshalState = UnmarshalState Endianness L.ByteString !Word64

-- | Unmarshaling monad: 'UnmarshalState' threaded through, failing with
-- 'UnmarshalError'.
newtype Unmarshal a = Unmarshal (E.ErrorT UnmarshalError (ST.State UnmarshalState) a)
    deriving (Monad, Functor, E.MonadError UnmarshalError, ST.MonadState UnmarshalState)

-- | Run an unmarshaler over @bytes@ starting at offset 0.
runUnmarshal :: Unmarshal a -> Endianness -> L.ByteString -> Either UnmarshalError a
runUnmarshal (Unmarshal m) e bytes = ST.evalState (E.runErrorT m) state
  where
    state = UnmarshalState e bytes 0

-- | Unmarshal one value per type in the signature, in order.
unmarshal :: T.Signature -> Unmarshal [T.Variant]
unmarshal = mapM unmarshalType . T.signatureTypes

-- | Unmarshal a single value of the given D-Bus type, wrapped in a variant.
unmarshalType :: T.Type -> Unmarshal T.Variant
unmarshalType T.DBusByte = fmap (T.toVariant . L.head) $ consume 1
unmarshalType T.DBusWord16 = unmarshalGet' 2 G.getWord16be G.getWord16le
unmarshalType T.DBusWord32 = unmarshalGet' 4 G.getWord32be G.getWord32le
unmarshalType T.DBusWord64 = unmarshalGet' 8 G.getWord64be G.getWord64le
-- Signed integers are read as same-width unsigned words, then converted.
unmarshalType T.DBusInt16 = do
    x <- unmarshalGet 2 G.getWord16be G.getWord16le
    return . T.toVariant $ (fromIntegral x :: Int16)
unmarshalType T.DBusInt32 = do
    x <- unmarshalGet 4 G.getWord32be G.getWord32le
    return . T.toVariant $ (fromIntegral x :: Int32)
unmarshalType T.DBusInt64 = do
    x <- unmarshalGet 8 G.getWord64be G.getWord64le
    return . T.toVariant $ (fromIntegral x :: Int64)
unmarshalType T.DBusDouble = unmarshalGet' 8 IEEE.getFloat64be IEEE.getFloat64le
unmarshalType T.DBusBoolean = unmarshalWord32 >>= fromMaybeU' "boolean" decodeBool
  where
    -- Only 0 and 1 are accepted; any other word is an error.
    decodeBool :: Word32 -> Maybe Bool
    decodeBool 0 = Just False
    decodeBool 1 = Just True
    decodeBool _ = Nothing
unmarshalType T.DBusString = fmap T.toVariant unmarshalText
unmarshalType T.DBusObjectPath =
    unmarshalText >>= fromMaybeU' "object path" T.mkObjectPath
unmarshalType T.DBusSignature = fmap T.toVariant unmarshalSignature
unmarshalType (T.DBusArray t) = T.toVariant `fmap` unmarshalArray t
-- A dictionary is read as an array of (key, value) structures.
unmarshalType (T.DBusDictionary kt vt) = do
    let pairType = T.DBusStructure [kt, vt]
    array <- unmarshalArray pairType
    fromMaybeU' "dictionary" T.arrayToDictionary array
unmarshalType (T.DBusStructure ts) = do
    skipPadding 8
    fmap (T.toVariant . T.Structure) $ mapM unmarshalType ts
-- A variant is its signature (which must name exactly one type) followed
-- by a value of that type.
unmarshalType T.DBusVariant = do
    let getType sig = case T.signatureTypes sig of
            [t] -> Just t
            _   -> Nothing
    t <- fromMaybeU "variant signature" getType =<< unmarshalSignature
    T.toVariant `fmap` unmarshalType t

-- | Take exactly @count@ bytes from the input. Fails with 'UnexpectedEOF'
-- at the current offset if fewer remain.
consume :: Word64 -> Unmarshal L.ByteString
consume count = do
    (UnmarshalState e bytes offset) <- ST.get
    let (x, bytes') = L.splitAt (fromIntegral count) bytes
    unless (L.length x == fromIntegral count) $
        E.throwError $ UnexpectedEOF offset
    ST.put $ UnmarshalState e bytes' (offset + count)
    return x

-- | Skip to the next multiple of @count@, requiring the skipped bytes to be
-- zero ('InvalidPadding' otherwise).
skipPadding :: Word8 -> Unmarshal ()
skipPadding count = do
    (UnmarshalState _ _ offset) <- ST.get
    bytes <- consume $ padding offset count
    unless (L.all (== 0) bytes) $
        E.throwError $ InvalidPadding offset

-- | Consume one byte, requiring it to be NUL ('MissingTerminator' otherwise).
skipTerminator :: Unmarshal ()
skipTerminator = do
    (UnmarshalState _ _ offset) <- ST.get
    bytes <- consume 1
    unless (L.all (== 0) bytes) $
        E.throwError $ MissingTerminator offset

-- | Apply a validating conversion, turning 'Nothing' into an 'Invalid' error
-- that carries @label@ and the 'show'n input.
fromMaybeU :: Show a => Text -> (a -> Maybe b) -> a -> Unmarshal b
fromMaybeU label f x = case f x of
    Just x' -> return x'
    Nothing -> E.throwError . Invalid label . TL.pack . show $ x

-- | Like 'fromMaybeU', but wraps the converted value in a variant.
fromMaybeU' :: (Show a, T.Variable b) => Text -> (a -> Maybe b) -> a -> Unmarshal T.Variant
fromMaybeU' label f x = do
    x' <- fromMaybeU label f x
    return $ T.toVariant x'

-- | Align to @count@, consume @count@ bytes, and decode them with the
-- big- or little-endian decoder according to the state's byte order.
unmarshalGet :: Word8 -> G.Get a -> G.Get a -> Unmarshal a
unmarshalGet count be le = do
    skipPadding count
    (UnmarshalState e _ _) <- ST.get
    bs <- consume . fromIntegral $ count
    let get' = case e of
            BigEndian    -> be
            LittleEndian -> le
    return $ G.runGet get' bs

-- | Like 'unmarshalGet', but wraps the result in a variant.
unmarshalGet' :: T.Variable a => Word8 -> G.Get a -> G.Get a -> Unmarshal T.Variant
unmarshalGet' count be le = T.toVariant `fmap` unmarshalGet count be le

-- | Run @comp@ repeatedly, collecting results, until @test@ yields 'True'.
-- @test@ is checked before every run, so the result may be empty.
untilM :: Monad m => m Bool -> m a -> m [a]
untilM test comp = do
    done <- test
    if done
        then return []
        else do
            x <- comp
            xs <- untilM test comp
            return $ x:xs

-- | Reasons unmarshaling can fail. Positions are byte offsets into the
-- data being unmarshaled.
data UnmarshalError
    = UnsupportedProtocolVersion Word8
    | UnexpectedEOF Word64
    | Invalid Text Text
    | MissingHeaderField Text
    | InvalidHeaderField Text T.Variant
    | InvalidPadding Word64
    | MissingTerminator Word64
    | ArraySizeMismatch
    deriving (Eq)

instance Show UnmarshalError where
    show (UnsupportedProtocolVersion x) = concat ["Unsupported protocol version: ", show x]
    show (UnexpectedEOF pos) = concat ["Unexpected EOF at position ", show pos]
    show (Invalid label x) = TL.unpack $ TL.concat ["Invalid ", label, ": ", x]
    show (MissingHeaderField x) = concat ["Required field " , show x , " is missing."]
    show (InvalidHeaderField x got) = concat [ "Invalid header field ", show x, ": ", show got]
    show (InvalidPadding pos) = concat ["Invalid padding at position ", show pos]
    show (MissingTerminator pos) = concat ["Missing NUL terminator at position ", show pos]
    show ArraySizeMismatch = "Array size mismatch"

-- Needed so 'UnmarshalError' can be the error type of 'E.ErrorT'.
instance E.Error UnmarshalError

------------------------------------------------------------------------------
-- Shared primitives
------------------------------------------------------------------------------

-- | Marshal a 32-bit word, 4-byte aligned.
marshalWord32 :: Word32 -> Marshal
marshalWord32 = marshalBuilder 4 B.putWord32be B.putWord32le

-- | Unmarshal a 32-bit word, 4-byte aligned.
unmarshalWord32 :: Unmarshal Word32
unmarshalWord32 = unmarshalGet 4 G.getWord32be G.getWord32le

-- | Evaluate @x@ to WHNF, returning 'Nothing' if that throws a
-- 'UnicodeException'. Only WHNF is forced: callers holding lazy structures
-- must force the part they need checked before passing it in.
--
-- 'unsafePerformIO' is used only to evaluate a pure value and catch a
-- pure exception; the wrapped action has no other effects.
excToMaybe :: a -> Maybe a
excToMaybe x = unsafePerformIO $ fmap Just (Exc.evaluate x) `Exc.catch` unicodeError

-- | Handler for 'excToMaybe': any 'UnicodeException' becomes 'Nothing'.
unicodeError :: UnicodeException -> IO (Maybe a)
unicodeError = const $ return Nothing

-- | UTF-8 encode, returning 'Nothing' if encoding throws a
-- 'UnicodeException'. The whole lazy result is forced so that a failure in
-- any chunk is caught here, not just one in the first chunk.
maybeEncodeUtf8 :: Text -> Maybe L.ByteString
maybeEncodeUtf8 = excToMaybe . forceBytes . encodeUtf8
  where
    forceBytes bs = L.length bs `seq` bs

-- | UTF-8 decode, returning 'Nothing' on malformed input. Lazy decoding
-- only reports errors when a chunk is demanded, so the entire result is
-- forced (via 'TL.length') before 'excToMaybe' evaluates it. Otherwise an
-- invalid sequence past the first chunk would escape as an exception
-- later on.
maybeDecodeUtf8 :: L.ByteString -> Maybe Text
maybeDecodeUtf8 = excToMaybe . forceText . decodeUtf8
  where
    forceText t = TL.length t `seq` t

-- | Marshal text as: 32-bit UTF-8 byte length, the UTF-8 bytes, then a NUL.
-- Fails with 'InvalidText' if encoding fails or the text contains NUL.
marshalText :: Text -> Marshal
marshalText x = do
    bytes <- case maybeEncodeUtf8 x of
        Just x' -> return x'
        Nothing -> E.throwError $ InvalidText x
    -- An embedded NUL would be indistinguishable from the terminator.
    when (L.any (== 0) bytes) $
        E.throwError $ InvalidText x
    marshalWord32 . fromIntegral . L.length $ bytes
    append bytes
    append (L.singleton 0)

-- | Inverse of 'marshalText': length word, UTF-8 bytes, NUL terminator.
unmarshalText :: Unmarshal Text
unmarshalText = do
    byteCount <- unmarshalWord32
    bytes <- consume . fromIntegral $ byteCount
    skipTerminator
    fromMaybeU "text" maybeDecodeUtf8 bytes

-- | Marshal a signature as: 1-byte length, the signature's UTF-8 bytes,
-- then a NUL.
marshalSignature :: T.Signature -> Marshal
marshalSignature x = do
    let bytes = encodeUtf8 . T.strSignature $ x
    let size = fromIntegral . L.length $ bytes
    append (L.singleton size)
    append bytes
    append (L.singleton 0)

-- | Inverse of 'marshalSignature'; the decoded text must also be accepted
-- by 'T.mkSignature'.
unmarshalSignature :: Unmarshal T.Signature
unmarshalSignature = do
    byteCount <- L.head `fmap` consume 1
    bytes <- consume $ fromIntegral byteCount
    sigText <- fromMaybeU "text" maybeDecodeUtf8 bytes
    skipTerminator
    fromMaybeU "signature" T.mkSignature sigText

-- | Marshal an array as: 32-bit byte length of the items, padding to the
-- item alignment, then the items. Fails with 'ArrayTooLong' when the item
-- bytes exceed 'C.arrayMaximumLength'.
marshalArray :: T.Array -> Marshal
marshalArray x = do
    (arrayPadding, arrayBytes) <- getArrayBytes (T.arrayType x) x
    let arrayLen = L.length arrayBytes
    when (arrayLen > fromIntegral C.arrayMaximumLength)
        (E.throwError $ ArrayTooLong $ fromIntegral arrayLen)
    marshalWord32 $ fromIntegral arrayLen
    append $ L.replicate arrayPadding 0
    append arrayBytes

-- | Marshal an array's items into a separate buffer without emitting
-- anything. Returns the number of padding bytes needed between the length
-- word and the first item, together with the item bytes. The state is
-- restored afterwards; the caller emits the length, padding and items.
getArrayBytes :: T.Type -> T.Array -> MarshalM (Int64, L.ByteString)
-- Byte arrays are copied directly; bytes need no alignment.
getArrayBytes T.DBusByte x = return (0, bytes)
  where
    Just bytes = T.arrayToBytes x
getArrayBytes itemType x = do
    let vs = T.arrayItems x
    s <- ST.get
    -- Simulate writing the length word and the alignment padding so the
    -- items are laid out at the offsets they will really occupy.
    (MarshalState _ _ afterLength) <- marshalWord32 0 >> ST.get
    (MarshalState e _ afterPadding) <- pad (alignment itemType) >> ST.get
    ST.put $ MarshalState e B.empty afterPadding
    (MarshalState _ itemBuilder _) <- mapM_ marshal vs >> ST.get
    let itemBytes   = B.toLazyByteString itemBuilder
        paddingSize = fromIntegral $ afterPadding - afterLength
    ST.put s
    return (paddingSize, itemBytes)

-- | Inverse of 'marshalArray'. Items are read until the declared byte
-- length is reached. If the last item runs past it, the result is
-- 'ArraySizeMismatch'.
unmarshalArray :: T.Type -> Unmarshal T.Array
unmarshalArray T.DBusByte = do
    byteCount <- unmarshalWord32
    T.arrayFromBytes `fmap` consume (fromIntegral byteCount)
unmarshalArray itemType = do
    let getOffset = do
            (UnmarshalState _ _ o) <- ST.get
            return o
    byteCount <- unmarshalWord32
    skipPadding (alignment itemType)
    start <- getOffset
    let end = start + fromIntegral byteCount
    vs <- untilM (fmap (>= end) getOffset) (unmarshalType itemType)
    end' <- getOffset
    when (end' > end) $
        E.throwError ArraySizeMismatch
    fromMaybeU "array" (T.arrayFromItems itemType) vs

------------------------------------------------------------------------------
-- Message headers
------------------------------------------------------------------------------

-- | Combine message flags into the header flag byte.
encodeFlags :: Set.Set M.Flag -> Word8
encodeFlags flags = foldr (.|.) 0 $ map flagValue $ Set.toList flags
  where
    flagValue M.NoReplyExpected = 0x1
    flagValue M.NoAutoStart     = 0x2

-- | Inverse of 'encodeFlags'; unknown bits are ignored.
decodeFlags :: Word8 -> Set.Set M.Flag
decodeFlags word = Set.fromList flags
  where
    flagSet = [ (0x1, M.NoReplyExpected)
              , (0x2, M.NoAutoStart)
              ]
    flags = flagSet >>= \(x, y) -> [y | word .&. x > 0]

-- | Encode a header field as a @(byte code, variant value)@ structure.
encodeField :: M.HeaderField -> T.Structure
encodeField (M.Path x)        = encodeField' 1 x
encodeField (M.Interface x)   = encodeField' 2 x
encodeField (M.Member x)      = encodeField' 3 x
encodeField (M.ErrorName x)   = encodeField' 4 x
encodeField (M.ReplySerial x) = encodeField' 5 x
encodeField (M.Destination x) = encodeField' 6 x
encodeField (M.Sender x)      = encodeField' 7 x
encodeField (M.Signature x)   = encodeField' 8 x

-- | Build a header-field structure. The value is wrapped twice because the
-- structure's second member has type 'T.DBusVariant'; it therefore holds a
-- variant which in turn holds the value.
encodeField' :: T.Variable a => Word8 -> a -> T.Structure
encodeField' code x = T.Structure
    [ T.toVariant code
    , T.toVariant $ T.toVariant x
    ]

-- | Decode one header-field structure. Unknown codes yield no fields; a
-- known code whose value has the wrong type throws 'InvalidHeaderField'.
decodeField :: Monad m => T.Structure -> E.ErrorT UnmarshalError m [M.HeaderField]
decodeField struct = case unpackField struct of
    (1, x) -> decodeField' x M.Path "path"
    (2, x) -> decodeField' x M.Interface "interface"
    (3, x) -> decodeField' x M.Member "member"
    (4, x) -> decodeField' x M.ErrorName "error name"
    (5, x) -> decodeField' x M.ReplySerial "reply serial"
    (6, x) -> decodeField' x M.Destination "destination"
    (7, x) -> decodeField' x M.Sender "sender"
    (8, x) -> decodeField' x M.Signature "signature"
    _      -> return []

-- | Convert a field's variant value with @f@, or throw 'InvalidHeaderField'
-- naming @label@ if the variant holds the wrong type.
decodeField' :: (Monad m, T.Variable a) => T.Variant -> (a -> b) -> Text
             -> E.ErrorT UnmarshalError m [b]
decodeField' x f label = case T.fromVariant x of
    Just x' -> return [f x']
    Nothing -> E.throwError $ InvalidHeaderField label x

-- | Split a header-field structure into its code and inner variant. The
-- patterns assume the structure came from parsing @(yv)@, so it has
-- exactly a byte and a variant.
unpackField :: T.Structure -> (Word8, T.Variant)
unpackField struct = (c', v')
  where
    T.Structure [c, v] = struct
    c' = fromJust . T.fromVariant $ c
    v' = fromJust . T.fromVariant $ v

------------------------------------------------------------------------------
-- Whole messages
------------------------------------------------------------------------------

-- | Marshal a complete message in the given byte order with the given
-- serial. Fails if the body's types do not form a valid signature, or if
-- the message exceeds 'C.messageMaximumLength'.
marshalMessage :: M.Message a => Endianness -> M.Serial -> a
               -> Either MarshalError L.ByteString
marshalMessage e serial msg = runMarshal marshaler e
  where
    body = M.messageBody msg
    marshaler = do
        sig <- checkBodySig body
        empty <- ST.get
        -- Marshal the body first, from offset 0, to learn its length. The
        -- alignment stays valid once the body is moved: it is emitted after
        -- 'pad 8', and no type in 'alignment' needs more than 8.
        mapM_ marshal body
        (MarshalState _ bodyBytesB _) <- ST.get
        ST.put empty
        marshalEndianness e
        let bodyBytes = B.toLazyByteString bodyBytesB
        marshalHeader msg serial sig
            $ fromIntegral . L.length $ bodyBytes
        pad 8
        append bodyBytes
        checkMaximumSize

-- | Signature formed by concatenating the body values' type codes; fails
-- with 'InvalidBodySignature' if 'T.mkSignature' rejects it.
checkBodySig :: [T.Variant] -> MarshalM T.Signature
checkBodySig vs = case T.mkSignature sigStr of
    Just x  -> return x
    Nothing -> E.throwError $ InvalidBodySignature sigStr
  where
    sigStr = TL.concat . map (T.typeCode . T.variantType) $ vs

-- | Marshal the header after the endianness byte: type code, flags,
-- protocol version, body length, serial, then the field array. The field
-- array always starts with a 'M.Signature' field for the body.
marshalHeader :: M.Message a => a -> M.Serial -> T.Signature -> Word32 -> Marshal
marshalHeader msg serial bodySig bodyLength = do
    let fields = M.Signature bodySig : M.messageHeaderFields msg
    marshal . T.toVariant . M.messageTypeCode $ msg
    marshal . T.toVariant . encodeFlags . M.messageFlags $ msg
    marshal . T.toVariant $ C.protocolVersion
    marshalWord32 bodyLength
    marshal . T.toVariant $ serial
    let fieldType = T.DBusStructure [T.DBusByte, T.DBusVariant]
    marshal . T.toVariant . fromJust . T.toArray fieldType
        $ map encodeField fields

-- | Marshal the byte-order marker byte.
marshalEndianness :: Endianness -> Marshal
marshalEndianness = marshal . T.toVariant . encodeEndianness

-- | Fail with 'MessageTooLong' if more than 'C.messageMaximumLength' bytes
-- have been emitted.
checkMaximumSize :: Marshal
checkMaximumSize = do
    (MarshalState _ _ messageLength) <- ST.get
    when (messageLength > fromIntegral C.messageMaximumLength)
        (E.throwError $ MessageTooLong $ fromIntegral messageLength)

-- | Read and decode one message. @getBytes'@ is asked, in order, for:
-- the 16-byte fixed header, the header-field array, the padding before the
-- body, and the body. Short reads are reported as 'UnexpectedEOF'.
unmarshalMessage :: Monad m => (Word32 -> m L.ByteString)
                 -> m (Either UnmarshalError M.ReceivedMessage)
unmarshalMessage getBytes' = E.runErrorT $ do
    let getBytes = E.lift . getBytes'

    -- Fixed part: endianness, type, flags, version, body length, serial,
    -- and the byte length of the header-field array.
    let fixedSig = T.mkSignature' "yyyyuuu"
    fixedBytes <- getBytes 16

    -- Guard the raw 'L.index' lookups below: a short read would otherwise
    -- crash with an index error instead of returning an error value.
    let fixedLength = L.length fixedBytes
    when (fixedLength < 16) $
        E.throwError $ UnexpectedEOF (fromIntegral fixedLength)

    let messageVersion = L.index fixedBytes 3
    when (messageVersion /= C.protocolVersion) $
        E.throwError $ UnsupportedProtocolVersion messageVersion

    let eByte = L.index fixedBytes 0
    endianness <- case decodeEndianness eByte of
        Just x' -> return x'
        Nothing -> E.throwError . Invalid "endianness" . TL.pack . show $ eByte

    let unmarshal' x bytes = case runUnmarshal (unmarshal x) endianness bytes of
            Right x' -> return x'
            Left e   -> E.throwError e
    fixed <- unmarshal' fixedSig fixedBytes
    let typeCode = fromJust . T.fromVariant $ fixed !! 1
    let flags = decodeFlags . fromJust . T.fromVariant $ fixed !! 2
    let bodyLength = fromJust . T.fromVariant $ fixed !! 4
    let serial = fromJust . T.fromVariant $ fixed !! 5
    let fieldByteCount = fromJust . T.fromVariant $ fixed !! 6

    -- Header fields: re-parse fixed part + field array in one pass so the
    -- array is unmarshaled at its true offsets.
    let headerSig = T.mkSignature' "yyyyuua(yv)"
    fieldBytes <- getBytes fieldByteCount
    let headerBytes = L.append fixedBytes fieldBytes
    header <- unmarshal' headerSig headerBytes
    let fieldArray = fromJust . T.fromVariant $ header !! 6
    let fieldStructures = fromJust . T.fromArray $ fieldArray
    fields <- concat `fmap` mapM decodeField fieldStructures

    -- The body starts at the next 8-byte boundary after the header; the
    -- padding bytes are read and discarded.
    let bodyPadding = padding (fromIntegral fieldByteCount + 16) 8
    getBytes . fromIntegral $ bodyPadding

    let bodySig = findBodySignature fields
    bodyBytes <- getBytes bodyLength
    body <- unmarshal' bodySig bodyBytes

    y <- case buildReceivedMessage typeCode fields of
        Right x -> return x
        Left x  -> E.throwError $ MissingHeaderField x
    return $ y serial flags body

-- | The first 'M.Signature' header field, or the empty signature if none.
findBodySignature :: [M.HeaderField] -> T.Signature
findBodySignature fields = fromMaybe empty signature
  where
    empty = T.mkSignature' ""
    signature = listToMaybe [x | M.Signature x <- fields]

-- | Select a message constructor from the header type code, checking the
-- fields that type requires. Returns 'Left' with the name of the first
-- missing required field. Codes:
--
-- * 1: method call; requires path and member.
-- * 2: method return; requires reply serial.
-- * 3: error; requires error name and reply serial.
-- * 4: signal; requires path, member and interface.
--
-- Any other code yields 'M.ReceivedUnknown'.
buildReceivedMessage
    :: Word8
    -> [M.HeaderField]
    -> Either Text (M.Serial -> (Set.Set M.Flag) -> [T.Variant] -> M.ReceivedMessage)
buildReceivedMessage 1 fields = do
    path <- require "path" [x | M.Path x <- fields]
    member <- require "member name" [x | M.Member x <- fields]
    return $ \serial flags body ->
        let iface  = listToMaybe [x | M.Interface x <- fields]
            dest   = listToMaybe [x | M.Destination x <- fields]
            sender = listToMaybe [x | M.Sender x <- fields]
            msg    = M.MethodCall path member iface dest flags body
        in M.ReceivedMethodCall serial sender msg
buildReceivedMessage 2 fields = do
    replySerial <- require "reply serial" [x | M.ReplySerial x <- fields]
    return $ \serial _ body ->
        let dest   = listToMaybe [x | M.Destination x <- fields]
            sender = listToMaybe [x | M.Sender x <- fields]
            msg    = M.MethodReturn replySerial dest body
        in M.ReceivedMethodReturn serial sender msg
buildReceivedMessage 3 fields = do
    name <- require "error name" [x | M.ErrorName x <- fields]
    replySerial <- require "reply serial" [x | M.ReplySerial x <- fields]
    return $ \serial _ body ->
        let dest   = listToMaybe [x | M.Destination x <- fields]
            sender = listToMaybe [x | M.Sender x <- fields]
            msg    = M.Error name replySerial dest body
        in M.ReceivedError serial sender msg
buildReceivedMessage 4 fields = do
    path <- require "path" [x | M.Path x <- fields]
    member <- require "member name" [x | M.Member x <- fields]
    iface <- require "interface" [x | M.Interface x <- fields]
    return $ \serial _ body ->
        let dest   = listToMaybe [x | M.Destination x <- fields]
            sender = listToMaybe [x | M.Sender x <- fields]
            msg    = M.Signal path member iface dest body
        in M.ReceivedSignal serial sender msg
buildReceivedMessage typeCode fields = return $ \serial flags body ->
    let sender = listToMaybe [x | M.Sender x <- fields]
        msg    = M.Unknown typeCode flags body
    in M.ReceivedUnknown serial sender msg

-- | First element of the list, or 'Left' with @label@ when it is empty.
require :: Text -> [a] -> Either Text a
require _ (x:_) = Right x
require label _ = Left label

-- Orphan instance; presumably needed so 'Either Text' can be used as a
-- monad (as in 'buildReceivedMessage') under the mtl version in use.
instance E.Error Text where
    strMsg = TL.pack