-- Copyright (C) 2009-2010 John Millikin
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this program. If not, see <http://www.gnu.org/licenses/>.

{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module DBus.Wire.Unmarshal where

import Control.Applicative (Applicative (..))
import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy as TL
import qualified Control.Monad.State as State
import Control.Monad.Trans.Class (lift)
import qualified DBus.Util.MonadError as E
import qualified Data.ByteString.Lazy as L
import qualified Data.Binary.Get as G
import qualified Data.Binary.IEEE754 as IEEE
import DBus.Wire.Unicode (maybeDecodeUtf8)
import qualified DBus.Message.Internal as M
import Data.Bits ((.&.))
import qualified Data.Set as Set
import qualified DBus.Constants as C
import Control.Monad (when, unless, liftM)
import Data.Maybe (fromJust, listToMaybe, fromMaybe)
import Data.Word (Word8, Word32, Word64)
import Data.Int (Int16, Int32, Int64)
import DBus.Wire.Internal
import qualified DBus.Types as T

-- | Parser state: the byte order in use, the bytes not yet consumed, and
-- the number of bytes consumed so far (kept strict).
data UnmarshalState = UnmarshalState Endianness L.ByteString !Word64

-- | The unmarshalling monad: state over 'UnmarshalState', with failures
-- reported as 'UnmarshalError'.
type Unmarshal = E.ErrorT UnmarshalError (State.State UnmarshalState)

-- | Run an 'Unmarshal' computation over @bytes@, starting at offset 0
-- with the given byte order.
runUnmarshal :: Unmarshal a -> Endianness -> L.ByteString -> Either UnmarshalError a
runUnmarshal m e bytes = State.evalState (E.runErrorT m) state
  where
    state = UnmarshalState e bytes 0

-- | Unmarshal one value for each type listed in the signature, in order.
unmarshal :: T.Signature -> Unmarshal [T.Variant]
unmarshal = mapM unmarshalType . T.signatureTypes

-- | Unmarshal a single value of the given type and wrap it in a
-- 'T.Variant'.
unmarshalType :: T.Type -> Unmarshal T.Variant
unmarshalType T.DBusByte = fmap (T.toVariant . L.head) $ consume 1
unmarshalType T.DBusWord16 = unmarshalGet' 2 G.getWord16be G.getWord16le
unmarshalType T.DBusWord32 = unmarshalGet' 4 G.getWord32be G.getWord32le
unmarshalType T.DBusWord64 = unmarshalGet' 8 G.getWord64be G.getWord64le

-- Signed integers are read as the unsigned word of the same width and
-- then converted with 'fromIntegral'.
unmarshalType T.DBusInt16 = do
    x <- unmarshalGet 2 G.getWord16be G.getWord16le
    return . T.toVariant $ (fromIntegral x :: Int16)
unmarshalType T.DBusInt32 = do
    x <- unmarshalGet 4 G.getWord32be G.getWord32le
    return . T.toVariant $ (fromIntegral x :: Int32)
unmarshalType T.DBusInt64 = do
    x <- unmarshalGet 8 G.getWord64be G.getWord64le
    return . T.toVariant $ (fromIntegral x :: Int64)

unmarshalType T.DBusDouble = unmarshalGet' 8 IEEE.getFloat64be IEEE.getFloat64le

-- Booleans are a 32-bit word; only 0 and 1 are accepted.
unmarshalType T.DBusBoolean = unmarshalWord32 >>=
    fromMaybeU' "boolean" (\x -> case x of
        0 -> Just False
        1 -> Just True
        _ -> Nothing)

unmarshalType T.DBusString = fmap T.toVariant unmarshalText

unmarshalType T.DBusObjectPath = unmarshalText >>=
    fromMaybeU' "object path" T.mkObjectPath

unmarshalType T.DBusSignature = fmap T.toVariant unmarshalSignature

unmarshalType (T.DBusArray t) = T.toVariant `fmap` unmarshalArray t

-- A dictionary is read as an array of (key, value) structures and then
-- converted.
unmarshalType (T.DBusDictionary kt vt) = do
    let pairType = T.DBusStructure [kt, vt]
    array <- unmarshalArray pairType
    fromMaybeU' "dictionary" T.arrayToDictionary array

unmarshalType (T.DBusStructure ts) = do
    skipPadding 8
    fmap (T.toVariant . T.Structure) $ mapM unmarshalType ts

-- A variant carries its own signature, which must name exactly one type.
unmarshalType T.DBusVariant = do
    let getType sig = case T.signatureTypes sig of
            [t] -> Just t
            _   -> Nothing
    t <- fromMaybeU "variant signature" getType =<< unmarshalSignature
    T.toVariant `fmap` unmarshalType t

-- | Take exactly @count@ bytes from the input and advance the offset.
-- Fails with 'UnexpectedEOF' (carrying the offset before the read) if
-- fewer bytes remain.
consume :: Word64 -> Unmarshal L.ByteString
consume count = do
    (UnmarshalState e bytes offset) <- State.get
    let (x, bytes') = L.splitAt (fromIntegral count) bytes
    unless (L.length x == fromIntegral count) $
        E.throwError $ UnexpectedEOF offset
    State.put $ UnmarshalState e bytes' (offset + count)
    return x

-- | Consume the number of bytes that 'padding' reports for the current
-- offset and the alignment @count@. Fails with 'InvalidPadding' if any
-- of those bytes is non-zero.
skipPadding :: Word8 -> Unmarshal ()
skipPadding count = do
    (UnmarshalState _ _ offset) <- State.get
    bytes <- consume $ padding offset count
    unless (L.all (== 0) bytes) $
        E.throwError $ InvalidPadding offset

-- | Consume one byte, which must be zero; otherwise fail with
-- 'MissingTerminator'.
skipTerminator :: Unmarshal ()
skipTerminator = do
    (UnmarshalState _ _ offset) <- State.get
    bytes <- consume 1
    unless (L.all (== 0) bytes) $
        E.throwError $ MissingTerminator offset

-- | Apply a partial conversion. On 'Nothing', fail with 'Invalid',
-- carrying @label@ and the 'show'n input.
fromMaybeU :: Show a => Text -> (a -> Maybe b) -> a -> Unmarshal b
fromMaybeU label f x = case f x of
    Just x' -> return x'
    Nothing -> E.throwError . Invalid label . TL.pack . show $ x

-- | Like 'fromMaybeU', but wraps the converted value in a 'T.Variant'.
fromMaybeU' :: (Show a, T.Variable b) => Text -> (a -> Maybe b) -> a
            -> Unmarshal T.Variant
fromMaybeU' label f x = do
    x' <- fromMaybeU label f x
    return $ T.toVariant x'

-- | Skip padding for alignment @count@, consume @count@ bytes, and decode
-- them with the big-endian or little-endian decoder according to the
-- current byte order.
unmarshalGet :: Word8 -> G.Get a -> G.Get a -> Unmarshal a
unmarshalGet count be le = do
    skipPadding count
    (UnmarshalState e _ _) <- State.get
    bs <- consume . fromIntegral $ count
    let get' = case e of
            BigEndian    -> be
            LittleEndian -> le
    return $ G.runGet get' bs

-- | Like 'unmarshalGet', but wraps the result in a 'T.Variant'.
unmarshalGet' :: T.Variable a => Word8 -> G.Get a -> G.Get a
              -> Unmarshal T.Variant
unmarshalGet' count be le = T.toVariant `fmap` unmarshalGet count be le

-- | Run @comp@ repeatedly, collecting results, until @test@ yields 'True'.
-- The test is evaluated before each run of @comp@.
untilM :: Monad m => m Bool -> m a -> m [a]
untilM test comp = do
    done <- test
    if done
        then return []
        else do
            x <- comp
            xs <- untilM test comp
            return $ x:xs

-- | Reasons unmarshalling can fail.
data UnmarshalError
    = UnsupportedProtocolVersion Word8
    | UnexpectedEOF Word64
    | Invalid Text Text
    | MissingHeaderField Text
    | InvalidHeaderField Text T.Variant
    | InvalidPadding Word64
    | MissingTerminator Word64
    | ArraySizeMismatch
    deriving (Eq)

instance Show UnmarshalError where
    show (UnsupportedProtocolVersion x) = concat
        ["Unsupported protocol version: ", show x]
    show (UnexpectedEOF pos) = concat
        ["Unexpected EOF at position ", show pos]
    show (Invalid label x) = TL.unpack $ TL.concat
        ["Invalid ", label, ": ", x]
    show (MissingHeaderField x) = concat
        ["Required field " , show x , " is missing."]
    show (InvalidHeaderField x got) = concat
        [ "Invalid header field ", show x, ": ", show got]
    show (InvalidPadding pos) = concat
        ["Invalid padding at position ", show pos]
    show (MissingTerminator pos) = concat
        ["Missing NUL terminator at position ", show pos]
    show ArraySizeMismatch = "Array size mismatch"

-- | Unmarshal a 4-byte-aligned 32-bit unsigned word.
unmarshalWord32 :: Unmarshal Word32
unmarshalWord32 = unmarshalGet 4 G.getWord32be G.getWord32le

-- | Unmarshal a string: a 32-bit byte count, that many bytes, and a zero
-- terminator byte. The bytes are decoded with 'maybeDecodeUtf8'.
unmarshalText :: Unmarshal Text
unmarshalText = do
    byteCount <- unmarshalWord32
    bytes <- consume . fromIntegral $ byteCount
    skipTerminator
    fromMaybeU "text" maybeDecodeUtf8 bytes

-- | Unmarshal a signature: a single-byte length, that many bytes, and a
-- zero terminator byte. The text is validated with 'T.mkSignature'.
unmarshalSignature :: Unmarshal T.Signature
unmarshalSignature = do
    byteCount <- L.head `fmap` consume 1
    bytes <- consume $ fromIntegral byteCount
    sigText <- fromMaybeU "text" maybeDecodeUtf8 bytes
    skipTerminator
    fromMaybeU "signature" T.mkSignature sigText

-- | Unmarshal an array whose items have the given type. Byte arrays take
-- a fast path that copies the payload directly.
unmarshalArray :: T.Type -> Unmarshal T.Array
unmarshalArray T.DBusByte = do
    byteCount <- unmarshalWord32
    T.arrayFromBytes `fmap` consume (fromIntegral byteCount)
unmarshalArray itemType = do
    let getOffset = do
            (UnmarshalState _ _ o) <- State.get
            return o
    byteCount <- unmarshalWord32
    -- The byte count is measured from after the padding that aligns the
    -- first item, so the padding is skipped before taking the start offset.
    skipPadding (alignment itemType)
    start <- getOffset
    let end = start + fromIntegral byteCount
    vs <- untilM (fmap (>= end) getOffset) (unmarshalType itemType)
    -- The last item must stop exactly at the declared end of the array.
    end' <- getOffset
    when (end' > end) $
        E.throwError ArraySizeMismatch
    fromMaybeU "array" (T.arrayFromItems itemType) vs

-- | Decode the header flags byte: bit 0x1 is 'M.NoReplyExpected' and
-- bit 0x2 is 'M.NoAutoStart'. Other bits are ignored.
decodeFlags :: Word8 -> Set.Set M.Flag
decodeFlags word = Set.fromList flags
  where
    flagSet =
        [ (0x1, M.NoReplyExpected)
        , (0x2, M.NoAutoStart)
        ]
    flags = flagSet >>= \(x, y) -> [y | word .&. x > 0]

-- | Decode one header field structure. Field codes 1 through 8 are
-- recognised; any other code yields an empty list.
decodeField :: Monad m => T.Structure
            -> E.ErrorT UnmarshalError m [M.HeaderField]
decodeField struct = case unpackField struct of
    (1, x) -> decodeField' x M.Path "path"
    (2, x) -> decodeField' x M.Interface "interface"
    (3, x) -> decodeField' x M.Member "member"
    (4, x) -> decodeField' x M.ErrorName "error name"
    (5, x) -> decodeField' x M.ReplySerial "reply serial"
    (6, x) -> decodeField' x M.Destination "destination"
    (7, x) -> decodeField' x M.Sender "sender"
    (8, x) -> decodeField' x M.Signature "signature"
    _      -> return []

-- | Extract a value of the expected type from the variant and wrap it
-- with @f@. Fails with 'InvalidHeaderField' if the variant holds a value
-- of another type.
decodeField' :: (Monad m, T.Variable a) => T.Variant -> (a -> b) -> Text
             -> E.ErrorT UnmarshalError m [b]
decodeField' x f label = case T.fromVariant x of
    Just x' -> return [f x']
    Nothing -> E.throwError $ InvalidHeaderField label x

-- | Split a header field structure into its code and value.
--
-- Partial: this expects a two-element structure holding a byte and a
-- variant. The only caller visible in this module passes structures
-- unmarshalled with the signature @a(yv)@.
unpackField :: T.Structure -> (Word8, T.Variant)
unpackField struct = (c', v')
  where
    T.Structure [c, v] = struct
    c' = fromJust . T.fromVariant $ c
    v' = fromJust . T.fromVariant $ v

-- | Read bytes from a monad until a complete message has been received.
--
-- The argument is called with a byte count and returns the bytes read. It
-- is called for the 16-byte fixed header, the header field array, the
-- padding before the body, and the body itself.
--
-- If the fixed-header read comes back short, the result is
-- 'Left' 'UnexpectedEOF' carrying the number of bytes actually received.
unmarshalMessage :: Monad m => (Word32 -> m L.ByteString)
                 -> m (Either UnmarshalError M.ReceivedMessage)
unmarshalMessage getBytes' = E.runErrorT $ do
    let getBytes = lift . getBytes'

    let fixedSig = "yyyyuuu"
    fixedBytes <- getBytes 16

    -- 'L.index' is partial, so a truncated read must be rejected before
    -- the version and endianness bytes are inspected.
    let fixedLength = L.length fixedBytes
    when (fixedLength < 16) $
        E.throwError $ UnexpectedEOF (fromIntegral fixedLength)

    let messageVersion = L.index fixedBytes 3
    when (messageVersion /= C.protocolVersion) $
        E.throwError $ UnsupportedProtocolVersion messageVersion

    let eByte = L.index fixedBytes 0
    endianness <- case decodeEndianness eByte of
        Just x' -> return x'
        Nothing -> E.throwError . Invalid "endianness" . TL.pack . show $ eByte

    let unmarshal' x bytes = case runUnmarshal (unmarshal x) endianness bytes of
            Right x' -> return x'
            Left e   -> E.throwError e
    fixed <- unmarshal' fixedSig fixedBytes

    -- The indexes and 'fromJust's below rely on @fixed@ having exactly
    -- the seven values described by @fixedSig@.
    let typeCode = fromJust . T.fromVariant $ fixed !! 1
    let flags = decodeFlags . fromJust . T.fromVariant $ fixed !! 2
    let bodyLength = fromJust . T.fromVariant $ fixed !! 4
    let serial = fromJust . T.fromVariant $ fixed !! 5
    let fieldByteCount = fromJust . T.fromVariant $ fixed !! 6

    -- The last word of the fixed header is the header field array's
    -- length, so the full header is re-parsed with the array appended.
    let headerSig = "yyyyuua(yv)"
    fieldBytes <- getBytes fieldByteCount
    let headerBytes = L.append fixedBytes fieldBytes
    header <- unmarshal' headerSig headerBytes

    let fieldArray = fromJust . T.fromVariant $ header !! 6
    let fieldStructures = fromJust . T.fromArray $ fieldArray
    fields <- concat `liftM` mapM decodeField fieldStructures

    -- Read and discard the padding that aligns the body to 8 bytes.
    let bodyPadding = padding (fromIntegral fieldByteCount + 16) 8
    _ <- getBytes . fromIntegral $ bodyPadding

    let bodySig = findBodySignature fields
    bodyBytes <- getBytes bodyLength
    body <- unmarshal' bodySig bodyBytes

    y <- case buildReceivedMessage typeCode fields of
        EitherM (Right x) -> return x
        EitherM (Left x)  -> E.throwError $ MissingHeaderField x
    return $ y serial flags body

-- | The body signature from the header fields, or the empty signature if
-- no signature field is present.
findBodySignature :: [M.HeaderField] -> T.Signature
findBodySignature fields = fromMaybe "" signature
  where
    signature = listToMaybe [x | M.Signature x <- fields]

-- | A wrapper around 'Either' with a 'Monad' instance that stops at the
-- first 'Left'.
newtype EitherM a b = EitherM (Either a b)

-- 'Functor' and 'Applicative' are superclasses of 'Monad' from GHC 7.10
-- onwards, so both instances are required for the 'Monad' instance below.
instance Functor (EitherM a) where
    fmap _ (EitherM (Left x))  = EitherM (Left x)
    fmap f (EitherM (Right x)) = EitherM (Right (f x))

instance Applicative (EitherM a) where
    pure = EitherM . Right
    (EitherM (Left x))  <*> _ = EitherM (Left x)
    (EitherM (Right f)) <*> m = fmap f m

instance Monad (EitherM a) where
    return = pure
    (EitherM (Left x))  >>= _ = EitherM (Left x)
    (EitherM (Right x)) >>= k = k x

-- | Pick the message constructor for a type code. The result is a
-- function awaiting the serial, flags and body.
--
-- Codes 1 to 4 build a method call, method return, error and signal
-- respectively; any other code builds an unknown message. A missing
-- required header field yields 'Left' with that field's label.
buildReceivedMessage :: Word8 -> [M.HeaderField] -> EitherM Text
                        (M.Serial -> (Set.Set M.Flag) -> [T.Variant]
                         -> M.ReceivedMessage)
buildReceivedMessage 1 fields = do
    path <- require "path" [x | M.Path x <- fields]
    member <- require "member name" [x | M.Member x <- fields]
    return $ \serial flags body -> let
        iface = listToMaybe [x | M.Interface x <- fields]
        dest = listToMaybe [x | M.Destination x <- fields]
        sender = listToMaybe [x | M.Sender x <- fields]
        msg = M.MethodCall path member iface dest flags body
        in M.ReceivedMethodCall serial sender msg

buildReceivedMessage 2 fields = do
    replySerial <- require "reply serial" [x | M.ReplySerial x <- fields]
    return $ \serial _ body -> let
        dest = listToMaybe [x | M.Destination x <- fields]
        sender = listToMaybe [x | M.Sender x <- fields]
        msg = M.MethodReturn replySerial dest body
        in M.ReceivedMethodReturn serial sender msg

buildReceivedMessage 3 fields = do
    name <- require "error name" [x | M.ErrorName x <- fields]
    replySerial <- require "reply serial" [x | M.ReplySerial x <- fields]
    return $ \serial _ body -> let
        dest = listToMaybe [x | M.Destination x <- fields]
        sender = listToMaybe [x | M.Sender x <- fields]
        msg = M.Error name replySerial dest body
        in M.ReceivedError serial sender msg

buildReceivedMessage 4 fields = do
    path <- require "path" [x | M.Path x <- fields]
    member <- require "member name" [x | M.Member x <- fields]
    iface <- require "interface" [x | M.Interface x <- fields]
    return $ \serial _ body -> let
        dest = listToMaybe [x | M.Destination x <- fields]
        sender = listToMaybe [x | M.Sender x <- fields]
        msg = M.Signal path member iface dest body
        in M.ReceivedSignal serial sender msg

buildReceivedMessage typeCode fields = return $ \serial flags body -> let
    sender = listToMaybe [x | M.Sender x <- fields]
    msg = M.Unknown typeCode flags body
    in M.ReceivedUnknown serial sender msg

-- | Take the first element of the list, or fail with @label@ if the list
-- is empty.
require :: Text -> [a] -> EitherM Text a
require _ (x:_) = return x
require label _ = EitherM $ Left label