{-
  Copyright (C) 2009 John Millikin <jmillikin@gmail.com>
  
  This program is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  any later version.
  
  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
  
  You should have received a copy of the GNU General Public License
  along with this program.  If not, see <http://www.gnu.org/licenses/>.
-}

{-# LANGUAGE OverloadedStrings #-}

{-# LANGUAGE GeneralizedNewtypeDeriving #-}

{-# LANGUAGE DeriveDataTypeable #-}
module DBus.Wire (  Endianness (..)

                  , alignment

                  , MarshalM
                  , Marshal
                  , marshal
                  , runMarshal

                  , MarshalError (..)

                  , Unmarshal
                  , unmarshal
                  , runUnmarshal

                  , UnmarshalError (..)

                  , marshalMessage

                  , unmarshalMessage
) where
import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy as TL

import qualified Control.Monad.State as ST
import qualified Control.Monad.Error as E
import qualified Data.ByteString.Lazy as L

import qualified Data.Binary.Put as P

import qualified Data.Binary.Get as G

import qualified Data.Binary.IEEE754 as IEEE

import Data.Text.Lazy.Encoding (encodeUtf8, decodeUtf8)

import qualified DBus.Constants as C

import qualified DBus.Message.Internal as M

import Data.Bits ((.|.), (.&.))
import qualified Data.Set as Set

import Control.Monad (when, unless)
import Data.Maybe (fromJust, listToMaybe, fromMaybe)
import Data.Word (Word8, Word32, Word64)
import Data.Int (Int16, Int32, Int64)
import Data.Typeable (Typeable)

import qualified DBus.Types as T

-- | Byte order used for every multi-byte value in a message. It is
-- announced by the first byte of the message header (see
-- 'encodeEndianness' and 'decodeEndianness').
data Endianness
        = LittleEndian
        | BigEndian
        deriving (Show, Eq)

-- | The header byte announcing a byte order.
encodeEndianness :: Endianness -> Word8
encodeEndianness order = case order of
        LittleEndian -> 108 -- ASCII 'l'
        BigEndian    -> 66  -- ASCII 'B'

-- | Inverse of 'encodeEndianness'. Any byte other than 108 ('l') or
-- 66 ('B') yields 'Nothing'.
decodeEndianness :: Word8 -> Maybe Endianness
decodeEndianness byte = lookup byte [(108, LittleEndian), (66, BigEndian)]

-- | Boundary, in bytes, on which a value of the given type is aligned.
-- The marshaler uses it for array items ('getArrayBytes'), and the
-- unmarshaler uses it when skipping padding before array items
-- ('unmarshalArray').
alignment :: T.Type -> Word8
alignment t = case t of
        -- 1-byte aligned
        T.DBusByte           -> 1
        T.DBusSignature      -> 1
        T.DBusVariant        -> 1
        -- 2-byte aligned
        T.DBusWord16         -> 2
        T.DBusInt16          -> 2
        -- 4-byte aligned
        T.DBusWord32         -> 4
        T.DBusInt32          -> 4
        T.DBusBoolean        -> 4
        T.DBusString         -> 4
        T.DBusObjectPath     -> 4
        T.DBusArray _        -> 4
        T.DBusDictionary _ _ -> 4
        -- 8-byte aligned
        T.DBusWord64         -> 8
        T.DBusInt64          -> 8
        T.DBusDouble         -> 8
        T.DBusStructure _    -> 8


-- | Number of filler bytes needed to advance offset @current@ to the next
-- multiple of @count@. The result is 0 when @current@ is already aligned.
padding :: Word64 -> Word8 -> Word64
padding current count = (boundary - current `mod` boundary) `mod` boundary where
        boundary = fromIntegral count

-- | Marshaling state: the byte order being written and every byte emitted
-- so far. Alignment padding is computed from the length of these bytes.
data MarshalState = MarshalState Endianness L.ByteString

-- | Marshaling monad: a 'MarshalState' accumulator plus short-circuiting
-- 'MarshalError' failures.
newtype MarshalM a = MarshalM (E.ErrorT MarshalError (ST.State MarshalState) a)
        deriving (ST.MonadState MarshalState, E.MonadError MarshalError, Monad)

-- | A marshaling step that is run only for its effect on the output.
type Marshal = MarshalM ()

-- | Run a marshaling action from an empty buffer in the given byte order.
-- Returns the bytes written, or the first error raised; any partial output
-- is discarded on error.
runMarshal :: Marshal -> Endianness -> Either MarshalError L.ByteString
runMarshal (MarshalM m) e = finish (ST.runState (E.runErrorT m) start) where
        start = MarshalState e L.empty
        finish (Left err, _)                   = Left err
        finish (Right _, MarshalState _ bytes) = Right bytes

-- | Append one value to the output, dispatching on the variant's type.
--
-- Fixed-width numbers go through 'marshalPut', which aligns them to their
-- own width. Structures are aligned to 8 and then written field by field.
-- A nested variant is written as its signature followed by its contents.
-- Fails with 'InvalidVariantSignature' if a nested variant's type code is
-- not a valid signature.
marshal :: T.Variant -> Marshal
marshal v = case T.variantType v of
                T.DBusByte -> append (L.singleton inner)
                T.DBusWord16 -> marshalPut P.putWord16be inner
                T.DBusWord32 -> marshalPut P.putWord32be inner
                T.DBusWord64 -> marshalPut P.putWord64be inner
                -- Signed integers are written as their two's-complement
                -- unsigned bit pattern.
                T.DBusInt16 -> marshalPut P.putWord16be (fromIntegral (inner :: Int16))
                T.DBusInt32 -> marshalPut P.putWord32be (fromIntegral (inner :: Int32))
                T.DBusInt64 -> marshalPut P.putWord64be (fromIntegral (inner :: Int64))
                T.DBusDouble -> marshalPut IEEE.putFloat64be inner
                T.DBusBoolean -> marshalWord32 (if inner then 1 else 0)
                T.DBusString -> marshalText inner
                T.DBusObjectPath -> marshalText (T.strObjectPath inner)
                T.DBusSignature -> marshalSignature inner
                T.DBusArray _ -> marshalArray inner
                T.DBusDictionary _ _ -> marshalArray (T.dictionaryToArray inner)
                T.DBusStructure _ -> do
                        let T.Structure items = inner
                        pad 8
                        mapM_ marshal items
                T.DBusVariant -> do
                        let code = T.typeCode (T.variantType inner)
                        sig <- maybe (E.throwError (InvalidVariantSignature code))
                                     return (T.mkSignature code)
                        marshalSignature sig
                        marshal inner
        where
                -- The variant's payload. It is used at whatever type each
                -- branch requires, which is the type named by 'T.variantType'.
                inner :: T.Variable a => a
                inner = fromJust (T.fromVariant v)


-- | Add raw bytes to the end of the output without any alignment.
append :: L.ByteString -> Marshal
append bs = ST.modify (\(MarshalState e acc) -> MarshalState e (L.append acc bs))

-- | Write zero bytes until the output length is a multiple of @count@.
pad :: Word8 -> Marshal
pad count = do
        size <- ST.gets (\(MarshalState _ bytes) -> L.length bytes)
        let fill = padding (fromIntegral size) count
        append (L.replicate (fromIntegral fill) 0)

-- | Serialize a fixed-width value with a big-endian 'P.Put' writer and
-- append it. The output is first aligned to the encoded value's own width.
-- For little-endian output the encoded bytes are reversed, so @put@ must
-- produce a single big-endian value.
marshalPut :: (a -> P.Put) -> a -> Marshal
marshalPut put value = do
        MarshalState order _ <- ST.get
        let encoded = P.runPut (put value)
            wire = if order == BigEndian then encoded else L.reverse encoded
        pad (fromIntegral (L.length encoded))
        append wire

-- | Reasons marshaling can fail.
data MarshalError
        = MessageTooLong Word64
          -- ^ Total message size in bytes exceeded the protocol maximum.
        | ArrayTooLong Word64
          -- ^ An array's item bytes exceeded the protocol maximum.
        | InvalidBodySignature Text
          -- ^ The body's concatenated type codes are not a valid signature.
        | InvalidVariantSignature Text
          -- ^ A nested variant's type code is not a valid signature.
        deriving (Eq, Typeable)

instance Show MarshalError where
        show err = case err of
                MessageTooLong n ->
                        "Message too long (" ++ show n ++ " bytes)."
                ArrayTooLong n ->
                        "Array too long (" ++ show n ++ " bytes)."
                InvalidBodySignature sig ->
                        "Invalid body signature: " ++ show sig
                InvalidVariantSignature sig ->
                        "Invalid variant signature: " ++ show sig

-- | Uses the class defaults, so it can serve as the error type of 'E.ErrorT'.
instance E.Error MarshalError

-- | Unmarshaling state: the byte order, the complete input buffer, and the
-- current read offset into that buffer. The buffer itself is never
-- shortened; only the offset advances.
data UnmarshalState = UnmarshalState Endianness L.ByteString Word64

-- | Unmarshaling monad: an 'UnmarshalState' reader/cursor plus
-- short-circuiting 'UnmarshalError' failures.
newtype Unmarshal a = Unmarshal (E.ErrorT UnmarshalError (ST.State UnmarshalState) a)
        deriving ( Functor
                 , Monad
                 , ST.MonadState UnmarshalState
                 , E.MonadError UnmarshalError
                 )

-- | Run an unmarshaler over @bytes@ from offset 0 in the given byte order.
runUnmarshal :: Unmarshal a -> Endianness -> L.ByteString -> Either UnmarshalError a
runUnmarshal (Unmarshal m) e bytes =
        ST.evalState (E.runErrorT m) (UnmarshalState e bytes 0)

-- | Read one value per type in the signature, in order.
unmarshal :: T.Signature -> Unmarshal [T.Variant]
unmarshal sig = mapM unmarshalType (T.signatureTypes sig)

-- | Read one value of the given type at the current offset.
--
-- Booleans must be exactly 0 or 1, and object paths, dictionaries and
-- variant signatures are validated. Failures are reported through
-- 'fromMaybeU' as 'Invalid'. A variant's signature must describe exactly
-- one complete type.
unmarshalType :: T.Type -> Unmarshal T.Variant
unmarshalType t = case t of
                T.DBusByte -> consume 1 >>= return . T.toVariant . L.head
                T.DBusWord16 -> unmarshalGet' 2 G.getWord16be G.getWord16le
                T.DBusWord32 -> unmarshalGet' 4 G.getWord32be G.getWord32le
                T.DBusWord64 -> unmarshalGet' 8 G.getWord64be G.getWord64le
                -- Signed integers are read as unsigned and reinterpreted.
                T.DBusInt16 -> do
                        w <- unmarshalGet 2 G.getWord16be G.getWord16le
                        return (T.toVariant (fromIntegral w :: Int16))
                T.DBusInt32 -> do
                        w <- unmarshalGet 4 G.getWord32be G.getWord32le
                        return (T.toVariant (fromIntegral w :: Int32))
                T.DBusInt64 -> do
                        w <- unmarshalGet 8 G.getWord64be G.getWord64le
                        return (T.toVariant (fromIntegral w :: Int64))
                T.DBusDouble -> unmarshalGet' 8 IEEE.getFloat64be IEEE.getFloat64le
                T.DBusBoolean -> unmarshalWord32 >>= fromMaybeU' "boolean" wordToBool
                T.DBusString -> fmap T.toVariant unmarshalText
                T.DBusObjectPath ->
                        unmarshalText >>= fromMaybeU' "object path" T.mkObjectPath
                T.DBusSignature -> fmap T.toVariant unmarshalSignature
                T.DBusArray itemType -> fmap T.toVariant (unmarshalArray itemType)
                -- Dictionaries are arrays of (key, value) structures.
                T.DBusDictionary kt vt -> do
                        pairs <- unmarshalArray (T.DBusStructure [kt, vt])
                        fromMaybeU' "dictionary" T.arrayToDictionary pairs
                T.DBusStructure ts -> do
                        skipPadding 8
                        members <- mapM unmarshalType ts
                        return (T.toVariant (T.Structure members))
                T.DBusVariant -> do
                        sig <- unmarshalSignature
                        inner <- fromMaybeU "variant signature" singleType sig
                        fmap T.toVariant (unmarshalType inner)
        where
                wordToBool :: Word32 -> Maybe Bool
                wordToBool 0 = Just False
                wordToBool 1 = Just True
                wordToBool _ = Nothing

                singleType sig = case T.signatureTypes sig of
                        [one] -> Just one
                        _     -> Nothing


-- | Take exactly @count@ bytes at the current offset and advance past them.
-- Fails with 'UnexpectedEOF', carrying the starting offset, if fewer than
-- @count@ bytes remain. The offset is left unchanged on failure.
consume :: Word64 -> Unmarshal L.ByteString
consume count = do
        UnmarshalState e bytes offset <- ST.get
        let chunk = L.take (fromIntegral count) (L.drop (fromIntegral offset) bytes)
        if L.length chunk /= fromIntegral count
                then E.throwError (UnexpectedEOF offset)
                else do
                        ST.put (UnmarshalState e bytes (offset + count))
                        return chunk

-- | Advance the offset to the next multiple of @count@. The skipped bytes
-- must all be zero; otherwise fails with 'InvalidPadding' at the offset
-- where the padding began.
skipPadding :: Word8 -> Unmarshal ()
skipPadding count = do
        start <- ST.gets (\(UnmarshalState _ _ offset) -> offset)
        filler <- consume (padding start count)
        when (L.any (/= 0) filler) $
                E.throwError (InvalidPadding start)

-- | Consume the single NUL byte that ends a string or signature. Fails with
-- 'MissingTerminator' at that byte's offset if it is non-zero.
skipTerminator :: Unmarshal ()
skipTerminator = do
        pos <- ST.gets (\(UnmarshalState _ _ offset) -> offset)
        byte <- consume 1
        when (L.any (/= 0) byte) $
                E.throwError (MissingTerminator pos)

-- | Apply a validating conversion. On 'Nothing', fail with 'Invalid',
-- tagged with @label@ and the 'show'n input.
fromMaybeU :: Show a => Text -> (a -> Maybe b) -> a -> Unmarshal b
fromMaybeU label f x =
        maybe (E.throwError (Invalid label (TL.pack (show x)))) return (f x)

-- | Like 'fromMaybeU', but wraps the successful result in a variant.
fromMaybeU' :: (Show a, T.Variable b) => Text -> (a -> Maybe b) -> a
           -> Unmarshal T.Variant
fromMaybeU' label f x = fmap T.toVariant (fromMaybeU label f x)

-- | Read a fixed-width value of @count@ bytes. The offset is first aligned
-- to @count@. Decoding uses @be@ for big-endian input and @le@ for
-- little-endian input.
unmarshalGet :: Word8 -> G.Get a -> G.Get a -> Unmarshal a
unmarshalGet count be le = do
        skipPadding count
        UnmarshalState order _ _ <- ST.get
        raw <- consume (fromIntegral count)
        return (G.runGet (if order == BigEndian then be else le) raw)

-- | 'unmarshalGet', with the result wrapped in a variant.
unmarshalGet' :: T.Variable a => Word8 -> G.Get a -> G.Get a
              -> Unmarshal T.Variant
unmarshalGet' count be le = unmarshalGet count be le >>= return . T.toVariant

-- | Repeatedly run @test@. While it yields 'False', run @comp@ and collect
-- its result. Stops at the first 'True', so @comp@ may run zero times.
untilM :: Monad m => m Bool -> m a -> m [a]
untilM test comp = go where
        go = test >>= \done -> if done
                then return []
                else do
                        x <- comp
                        xs <- go
                        return (x : xs)

-- | Reasons unmarshaling can fail.
data UnmarshalError
        = UnsupportedProtocolVersion Word8
          -- ^ The header's protocol version byte is not the one supported.
        | UnexpectedEOF Word64
          -- ^ Input ran out; carries the offset where the read started.
        | Invalid Text Text
          -- ^ A value failed validation: a label and the offending input.
        | MissingHeaderField Text
          -- ^ A header field required for the message type is absent.
        | InvalidHeaderField Text T.Variant
          -- ^ A header field holds a value of the wrong type.
        | InvalidPadding Word64
          -- ^ Non-zero padding; carries the offset where padding began.
        | MissingTerminator Word64
          -- ^ A string or signature is not followed by a NUL byte.
        | ArraySizeMismatch
          -- ^ Array items ran past the array's declared byte length.
        deriving (Eq, Typeable)

instance Show UnmarshalError where
        show err = case err of
                UnsupportedProtocolVersion v ->
                        "Unsupported protocol version: " ++ show v
                UnexpectedEOF pos ->
                        "Unexpected EOF at position " ++ show pos
                Invalid label x ->
                        "Invalid " ++ TL.unpack label ++ ": " ++ TL.unpack x
                MissingHeaderField x ->
                        "Required field " ++ show x ++ " is missing."
                InvalidHeaderField x got ->
                        "Invalid header field " ++ show x ++ ": " ++ show got
                InvalidPadding pos ->
                        "Invalid padding at position " ++ show pos
                MissingTerminator pos ->
                        "Missing NUL terminator at position " ++ show pos
                ArraySizeMismatch ->
                        "Array size mismatch"

-- | Uses the class defaults, so it can serve as the error type of 'E.ErrorT'.
instance E.Error UnmarshalError

-- | Write a 4-byte-aligned unsigned 32-bit integer.
marshalWord32 :: Word32 -> Marshal
marshalWord32 w = marshalPut P.putWord32be w

-- | Read a 4-byte-aligned unsigned 32-bit integer.
unmarshalWord32 :: Unmarshal Word32
unmarshalWord32 = unmarshalGet width G.getWord32be G.getWord32le where
        width = 4

-- | Write a string: a 32-bit UTF-8 byte count, the UTF-8 bytes, then a NUL.
marshalText :: Text -> Marshal
marshalText txt = do
        let utf8 = encodeUtf8 txt
        marshalWord32 (fromIntegral (L.length utf8))
        append (L.snoc utf8 0)

-- | Read a string written by 'marshalText': a 32-bit byte count, that many
-- bytes, and a NUL terminator.
-- NOTE(review): 'decodeUtf8' throws on malformed UTF-8 instead of failing
-- in 'Unmarshal'. Consider a non-throwing decoder if the text version in
-- use provides one.
unmarshalText :: Unmarshal Text
unmarshalText = do
        len <- unmarshalWord32
        raw <- consume (fromIntegral len)
        skipTerminator
        return (decodeUtf8 raw)

-- | Write a signature: a one-byte length, the UTF-8 bytes, then a NUL.
-- No alignment is applied.
marshalSignature :: T.Signature -> Marshal
marshalSignature sig = append (L.concat [L.singleton size, utf8, L.singleton 0]) where
        utf8 = encodeUtf8 (T.strSignature sig)
        size = fromIntegral (L.length utf8)

-- | Read a signature written by 'marshalSignature' and validate it. Fails
-- with 'Invalid' if the text is not a valid signature.
-- NOTE(review): 'decodeUtf8' throws on malformed UTF-8.
unmarshalSignature :: Unmarshal T.Signature
unmarshalSignature = do
        len <- consume 1 >>= return . L.head
        txt <- consume (fromIntegral len) >>= return . decodeUtf8
        skipTerminator
        fromMaybeU "signature" T.mkSignature txt

-- | Write an array: a 32-bit length, padding to the item alignment, then
-- the items.
--
-- The length counts only the item bytes, not the padding. Fails with
-- 'ArrayTooLong' if the items exceed 'C.arrayMaximumLength' bytes.
marshalArray :: T.Array -> Marshal
marshalArray arr = do
        (padBytes, itemBytes) <- getArrayBytes arr
        let size = L.length itemBytes
        if size > fromIntegral C.arrayMaximumLength
                then E.throwError (ArrayTooLong (fromIntegral size))
                else do
                        marshalWord32 (fromIntegral size)
                        append padBytes
                        append itemBytes

-- | Dry-run an array's encoding at the current position and return the
-- padding and item bytes it would produce.
--
-- The length slot, padding and items are written so that alignment is
-- computed at the true absolute offsets. The saved state is then
-- restored, so the output is left untouched.
getArrayBytes :: T.Array -> MarshalM (L.ByteString, L.ByteString)
getArrayBytes arr = do
        saved <- ST.get
        lengthEnd <- marshalWord32 0 >> written
        paddingEnd <- pad (alignment (T.arrayType arr)) >> written
        itemsEnd <- mapM_ marshal (T.arrayItems arr) >> written
        ST.put saved
        return ( L.drop (L.length lengthEnd) paddingEnd
               , L.drop (L.length paddingEnd) itemsEnd
               )
    where
        written :: MarshalM L.ByteString
        written = ST.gets (\(MarshalState _ bytes) -> bytes)

-- | Read an array of @itemType@: a 32-bit byte length, padding to the item
-- alignment, then items until that many bytes have been consumed.
--
-- Fails with 'ArraySizeMismatch' if the last item runs past the declared
-- length, and with 'Invalid' if 'T.arrayFromItems' rejects the items.
unmarshalArray :: T.Type -> Unmarshal T.Array
unmarshalArray itemType = do
        byteCount <- unmarshalWord32
        skipPadding (alignment itemType)
        start <- position
        let end = start + fromIntegral byteCount
        items <- untilM (fmap (>= end) position) (unmarshalType itemType)
        finish <- position
        when (finish > end) $
                E.throwError ArraySizeMismatch
        fromMaybeU "array" (T.arrayFromItems itemType) items
    where
        position :: Unmarshal Word64
        position = ST.gets (\(UnmarshalState _ _ offset) -> offset)

-- | Pack message flags into the header flag byte (bitwise OR of each
-- flag's bit).
encodeFlags :: Set.Set M.Flag -> Word8
encodeFlags flags = foldr (\flag acc -> flagValue flag .|. acc) 0 (Set.toList flags) where
        flagValue M.NoReplyExpected = 0x1
        flagValue M.NoAutoStart     = 0x2

-- | Unpack the header flag byte. Bits without a known flag are ignored.
decodeFlags :: Word8 -> Set.Set M.Flag
decodeFlags word = Set.fromList
        [ flag
        | (mask, flag) <- [(0x1, M.NoReplyExpected), (0x2, M.NoAutoStart)]
        , word .&. mask > 0
        ]

-- | Encode a header field as a (code, variant) structure. 'decodeField'
-- uses the same numeric codes.
encodeField :: M.HeaderField -> T.Structure
encodeField field = case field of
        M.Path x        -> encodeField' 1 x
        M.Interface x   -> encodeField' 2 x
        M.Member x      -> encodeField' 3 x
        M.ErrorName x   -> encodeField' 4 x
        M.ReplySerial x -> encodeField' 5 x
        M.Destination x -> encodeField' 6 x
        M.Sender x      -> encodeField' 7 x
        M.Signature x   -> encodeField' 8 x

-- | Build a header-field structure from a code byte and a value. The value
-- is wrapped twice, so the second member is a variant whose contents are
-- the variant holding @x@; 'unpackField' unwraps the outer layer.
encodeField' :: T.Variable a => Word8 -> a -> T.Structure
encodeField' code x = T.Structure [T.toVariant code, T.toVariant inner] where
        inner = T.toVariant x

-- | Decode one header-field structure. The result is a singleton list for
-- a known code, and an empty list for an unknown code, so unknown fields
-- are skipped. A known code with a wrong-typed value fails with
-- 'InvalidHeaderField'.
decodeField :: Monad m => T.Structure
            -> E.ErrorT UnmarshalError m [M.HeaderField]
decodeField struct = dispatch (unpackField struct) where
        dispatch (code, value) = case code of
                1 -> decodeField' value M.Path        "path"
                2 -> decodeField' value M.Interface   "interface"
                3 -> decodeField' value M.Member      "member"
                4 -> decodeField' value M.ErrorName   "error name"
                5 -> decodeField' value M.ReplySerial "reply serial"
                6 -> decodeField' value M.Destination "destination"
                7 -> decodeField' value M.Sender      "sender"
                8 -> decodeField' value M.Signature   "signature"
                _ -> return []

-- | Extract a field value of the type @f@ expects and wrap it with @f@.
-- On a type mismatch, fail with 'InvalidHeaderField' naming @label@.
decodeField' :: (Monad m, T.Variable a) => T.Variant -> (a -> b) -> Text
             -> E.ErrorT UnmarshalError m [b]
decodeField' x f label =
        maybe (E.throwError (InvalidHeaderField label x))
              (\v -> return [f v])
              (T.fromVariant x)

-- | Split a header-field structure into its code byte and its value,
-- removing the outer variant layer added by 'encodeField''.
-- The lazy pattern and 'fromJust' assume a well-formed (byte, variant)
-- pair; a malformed structure is a runtime error once either component
-- is forced.
unpackField :: T.Structure -> (Word8, T.Variant)
unpackField struct =
        let T.Structure [code, value] = struct
        in (fromJust (T.fromVariant code), fromJust (T.fromVariant value))

-- | Serialize a complete message: the header, padding to an 8-byte
-- boundary, then the body.
--
-- Fails with 'InvalidBodySignature' if the body's types do not form a
-- valid signature, and with 'MessageTooLong' if the whole message exceeds
-- the maximum size.
marshalMessage :: M.Message a => Endianness -> M.Serial -> a
               -> Either MarshalError L.ByteString
marshalMessage e serial msg = flip runMarshal e $ do
        let body = M.messageBody msg
        sig <- checkBodySig body
        -- Encode the body first, because the header needs its length. It
        -- is encoded from an empty buffer and the state is then restored.
        -- The body is later written after 8-byte padding, and no type aligns
        -- to more than 8, so its internal padding comes out the same.
        start <- ST.get
        mapM_ marshal body
        MarshalState _ bodyBytes <- ST.get
        ST.put start
        marshalEndianness e
        marshalHeader msg serial sig (fromIntegral (L.length bodyBytes))
        pad 8
        append bodyBytes
        checkMaximumSize

-- | Build the body signature from the type codes of the body values.
-- Fails with 'InvalidBodySignature' if the result is not a valid signature.
checkBodySig :: [T.Variant] -> MarshalM T.Signature
checkBodySig vs = maybe rejected return (T.mkSignature code) where
        code = TL.concat [T.typeCode (T.variantType v) | v <- vs]
        rejected = E.throwError (InvalidBodySignature code)

-- | Write the header after the endianness byte, in this order: type code,
-- flags, protocol version, body length, serial, and the header-field
-- array.
--
-- The body signature is always included as the first header field.
marshalHeader :: M.Message a => a -> M.Serial -> T.Signature -> Word32
              -> Marshal
marshalHeader msg serial bodySig bodyLength = do
        let fieldType = T.DBusStructure [T.DBusByte, T.DBusVariant]
            fields = M.Signature bodySig : M.messageHeaderFields msg
            fieldArray = fromJust (T.toArray fieldType (map encodeField fields))
        mapM_ marshal
                [ T.toVariant (M.messageTypeCode msg)
                , T.toVariant (encodeFlags (M.messageFlags msg))
                , T.toVariant C.protocolVersion
                ]
        marshalWord32 bodyLength
        marshal (T.toVariant serial)
        marshal (T.toVariant fieldArray)

-- | Write the byte-order marker byte.
marshalEndianness :: Endianness -> Marshal
marshalEndianness e = marshal (T.toVariant (encodeEndianness e))

-- | Fail with 'MessageTooLong' if the output so far exceeds
-- 'C.messageMaximumLength' bytes.
checkMaximumSize :: Marshal
checkMaximumSize = do
        size <- ST.gets (\(MarshalState _ bytes) -> L.length bytes)
        when (size > fromIntegral C.messageMaximumLength) $
                E.throwError (MessageTooLong (fromIntegral size))

-- | Read and decode one complete message, pulling raw bytes on demand.
--
-- @getBytes'@ is asked for bytes in four stages:
--
--   1. the 16-byte fixed header;
--   2. the header-field array;
--   3. the padding that aligns the body to an 8-byte boundary;
--   4. the body.
--
-- All failures are returned as 'Left'. The callback may return fewer bytes
-- than requested, for example at end of stream. A short fixed header is
-- reported as 'UnexpectedEOF' instead of crashing in the partial
-- 'L.index'. Short header-field or body reads are reported by the
-- unmarshaler when the missing bytes are needed. Body padding must be
-- zero, otherwise the result is 'InvalidPadding'.
unmarshalMessage :: Monad m => (Word32 -> m L.ByteString)
                 -> m (Either UnmarshalError M.ReceivedMessage)
unmarshalMessage getBytes' = E.runErrorT $ do
        let getBytes = E.lift . getBytes'
        
        let fixedSig = T.mkSignature' "yyyyuuu"
        fixedBytes <- getBytes 16

        -- 'L.index' is partial, so reject a truncated fixed header before
        -- peeking at the version and endianness bytes.
        let fixedLength = L.length fixedBytes
        when (fixedLength < 16) $
                E.throwError $ UnexpectedEOF (fromIntegral fixedLength)

        let messageVersion = L.index fixedBytes 3
        when (messageVersion /= C.protocolVersion) $
                E.throwError $ UnsupportedProtocolVersion messageVersion

        let eByte = L.index fixedBytes 0
        endianness <- case decodeEndianness eByte of
                Just x' -> return x'
                Nothing -> E.throwError . Invalid "endianness" . TL.pack . show $ eByte

        let unmarshal' x bytes = case runUnmarshal (unmarshal x) endianness bytes of
                Right x' -> return x'
                Left  e  -> E.throwError e
        fixed <- unmarshal' fixedSig fixedBytes
        let typeCode = fromJust . T.fromVariant $ fixed !! 1
        let flags = decodeFlags . fromJust . T.fromVariant $ fixed !! 2
        let bodyLength = fromJust . T.fromVariant $ fixed !! 4
        let serial = fromJust . T.fromVariant $ fixed !! 5

        -- The last fixed-header word is the byte length of the field array.
        let fieldByteCount = fromJust . T.fromVariant $ fixed !! 6

        -- Re-parse the fixed bytes together with the field array, so that
        -- array alignment is computed from the start of the message.
        let headerSig  = T.mkSignature' "yyyyuua(yv)"
        fieldBytes <- getBytes fieldByteCount
        let headerBytes = L.append fixedBytes fieldBytes
        header <- unmarshal' headerSig headerBytes

        let fieldArray = fromJust . T.fromVariant $ header !! 6
        let fieldStructures = fromJust . T.fromArray $ fieldArray
        fields <- fmap concat $ mapM decodeField fieldStructures

        -- The body starts at the next 8-byte boundary after the header,
        -- which is 16 fixed bytes plus the field array. As in
        -- 'skipPadding', the padding bytes must be zero.
        let bodyPaddingStart = fromIntegral fieldByteCount + 16 :: Word64
        paddingBytes <- getBytes . fromIntegral $ padding bodyPaddingStart 8
        unless (L.all (== 0) paddingBytes) $
                E.throwError $ InvalidPadding bodyPaddingStart

        let bodySig = findBodySignature fields

        bodyBytes <- getBytes bodyLength
        body <- unmarshal' bodySig bodyBytes

        y <- case buildReceivedMessage typeCode fields of
                Right x -> return x
                Left x -> E.throwError $ MissingHeaderField x

        return $ y serial flags body


-- | The signature from the first 'M.Signature' header field. Defaults to
-- the empty signature when no such field is present.
findBodySignature :: [M.HeaderField] -> T.Signature
findBodySignature fields = case [x | M.Signature x <- fields] of
        sig : _ -> sig
        []      -> T.mkSignature' ""

-- | Choose a received-message constructor from the header type code and
-- check that the fields required for that type are present.
--
-- On success, returns a function that completes the message from its
-- serial, flags and body. On failure, returns the label of the first
-- missing field. Unrecognised type codes become 'M.ReceivedUnknown'.
buildReceivedMessage :: Word8 -> [M.HeaderField] -> Either Text 
                        (M.Serial -> (Set.Set M.Flag) -> [T.Variant]
                         -> M.ReceivedMessage)
buildReceivedMessage typeCode fields = build typeCode where
        paths        = [x | M.Path x <- fields]
        members      = [x | M.Member x <- fields]
        replySerials = [x | M.ReplySerial x <- fields]
        dest         = listToMaybe [x | M.Destination x <- fields]
        sender       = listToMaybe [x | M.Sender x <- fields]
        optIface     = listToMaybe [x | M.Interface x <- fields]

        -- Method call: needs a path and a member.
        build 1 = do
                path <- require "path" paths
                member <- require "member name" members
                return $ \serial flags body ->
                        M.ReceivedMethodCall serial sender
                                (M.MethodCall path member optIface dest flags body)
        -- Method return: needs a reply serial.
        build 2 = do
                replySerial <- require "reply serial" replySerials
                return $ \serial flags body ->
                        M.ReceivedMethodReturn serial sender
                                (M.MethodReturn replySerial dest flags body)
        -- Error: needs an error name and a reply serial.
        build 3 = do
                name <- require "error name" [x | M.ErrorName x <- fields]
                replySerial <- require "reply serial" replySerials
                return $ \serial flags body ->
                        M.ReceivedError serial sender
                                (M.Error name replySerial dest flags body)
        -- Signal: needs a path, a member and an interface.
        build 4 = do
                path <- require "path" paths
                member <- require "member name" members
                iface <- require "interface" [x | M.Interface x <- fields]
                return $ \serial flags body ->
                        M.ReceivedSignal serial sender
                                (M.Signal path member iface dest flags body)
        build _ = return $ \serial flags body ->
                M.ReceivedUnknown serial sender (M.Unknown typeCode flags body)

-- | The first element of a list, or @Left label@ if the list is empty.
require :: Text -> [a] -> Either Text a
require label xs = case xs of
        x : _ -> Right x
        []    -> Left label

-- | Allows lazy 'Text' to be the error type of 'E.ErrorT'; 'E.strMsg'
-- packs the message string.
-- NOTE(review): this is an orphan instance, since neither 'E.Error' nor
-- 'Text' is defined in this module.
instance E.Error Text where
        strMsg message = TL.pack message