{-
  Copyright (C) 2009 John Millikin <jmillikin@gmail.com>
  
  This program is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  any later version.
  
  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
  
  You should have received a copy of the GNU General Public License
  along with this program.  If not, see <http://www.gnu.org/licenses/>.
-}

{-# LANGUAGE OverloadedStrings #-}

{-# LANGUAGE TypeFamilies #-}
module DBus.Wire.Unmarshal where
import Control.Applicative (Applicative (..))
import Control.Monad (when, unless, liftM)
import Data.Bits ((.&.))
import Data.Int (Int16, Int32, Int64)
import Data.Maybe (fromJust, listToMaybe, fromMaybe)
import Data.Word (Word8, Word32, Word64)

import qualified Control.Monad.State as State
import Control.Monad.Trans.Class (lift)
import qualified Data.Binary.Get as G
import qualified Data.Binary.IEEE754 as IEEE
import qualified Data.ByteString.Lazy as L
import qualified Data.Set as Set
import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy as TL

import qualified DBus.Constants as C
import qualified DBus.Message.Internal as M
import qualified DBus.Types as T
import qualified DBus.Util.MonadError as E
import DBus.Wire.Internal
import DBus.Wire.Unicode (maybeDecodeUtf8)

-- | Decoder state: the byte order in effect, the bytes not yet consumed,
-- and the number of bytes consumed so far. The counter starts at 0 (see
-- 'runUnmarshal'); it drives alignment in 'skipPadding' and is the
-- position reported in errors.
data UnmarshalState = UnmarshalState Endianness L.ByteString !Word64

-- | The decoding monad: 'UnmarshalError' failures layered over a 'State'
-- holding the 'UnmarshalState'.
type Unmarshal = E.ErrorT UnmarshalError (State.State UnmarshalState)

-- | Run a decoder over @bytes@ using byte order @e@, starting at offset 0,
-- returning either the decoded value or the error the decoder raised.
runUnmarshal :: Unmarshal a -> Endianness -> L.ByteString -> Either UnmarshalError a
runUnmarshal m e bytes =
        let initialState = UnmarshalState e bytes 0
        in  State.evalState (E.runErrorT m) initialState

-- | Decode one value for each top-level type in the signature, in order.
unmarshal :: T.Signature -> Unmarshal [T.Variant]
unmarshal sig = mapM unmarshalType (T.signatureTypes sig)

-- | Decode a single value of type @ty@ from the current position and wrap
-- it in a 'T.Variant'.
--
-- Fixed-width values go through 'unmarshalGet', which aligns to the value's
-- width before reading. Structures align to 8 bytes. Strings, signatures
-- and containers delegate to their dedicated readers. Values that decode
-- but fail validation (e.g. a boolean other than 0 or 1) raise 'Invalid'.
unmarshalType :: T.Type -> Unmarshal T.Variant
unmarshalType ty = case ty of
        T.DBusByte -> (T.toVariant . L.head) `fmap` consume 1
        T.DBusWord16 -> unmarshalGet' 2 G.getWord16be G.getWord16le
        T.DBusWord32 -> unmarshalGet' 4 G.getWord32be G.getWord32le
        T.DBusWord64 -> unmarshalGet' 8 G.getWord64be G.getWord64le

        -- Signed integers are read as unsigned words of the same width and
        -- reinterpreted; 'fromIntegral' wraps to the two's-complement value.
        T.DBusInt16 -> (\w -> T.toVariant (fromIntegral w :: Int16))
                `fmap` unmarshalGet 2 G.getWord16be G.getWord16le
        T.DBusInt32 -> (\w -> T.toVariant (fromIntegral w :: Int32))
                `fmap` unmarshalGet 4 G.getWord32be G.getWord32le
        T.DBusInt64 -> (\w -> T.toVariant (fromIntegral w :: Int64))
                `fmap` unmarshalGet 8 G.getWord64be G.getWord64le

        T.DBusDouble -> unmarshalGet' 8 IEEE.getFloat64be IEEE.getFloat64le

        T.DBusBoolean -> unmarshalWord32 >>= fromMaybeU' "boolean" asBool

        T.DBusString -> T.toVariant `fmap` unmarshalText

        T.DBusObjectPath ->
                unmarshalText >>= fromMaybeU' "object path" T.mkObjectPath

        T.DBusSignature -> T.toVariant `fmap` unmarshalSignature

        T.DBusArray itemType -> T.toVariant `fmap` unmarshalArray itemType

        -- A dictionary is encoded as an array of two-field structures.
        T.DBusDictionary kt vt ->
                unmarshalArray (T.DBusStructure [kt, vt])
                        >>= fromMaybeU' "dictionary" T.arrayToDictionary

        T.DBusStructure fieldTypes -> do
                skipPadding 8
                fieldValues <- mapM unmarshalType fieldTypes
                return (T.toVariant (T.Structure fieldValues))

        -- A variant carries its own signature, which must name exactly one
        -- complete type.
        T.DBusVariant -> do
                sig <- unmarshalSignature
                innerType <- fromMaybeU "variant signature" singleType sig
                T.toVariant `fmap` unmarshalType innerType
    where
        asBool w = case w of
                0 -> Just False
                1 -> Just True
                _ -> Nothing
        singleType sig = case T.signatureTypes sig of
                [t] -> Just t
                _   -> Nothing


-- | Take exactly @count@ bytes from the input and advance the offset.
-- Raises 'UnexpectedEOF' (at the offset before the read) when fewer bytes
-- remain; in that case the state is left unchanged.
consume :: Word64 -> Unmarshal L.ByteString
consume count = do
        UnmarshalState e bytes offset <- State.get
        let wanted = fromIntegral count
            (chunk, rest) = L.splitAt wanted bytes
        if L.length chunk /= wanted
                then E.throwError (UnexpectedEOF offset)
                else do
                        State.put (UnmarshalState e rest (offset + count))
                        return chunk

-- | Consume the padding bytes that 'padding' computes for the current
-- offset and an alignment of @count@. Every padding byte must be zero;
-- otherwise 'InvalidPadding' is raised at the offset where the padding began.
skipPadding :: Word8 -> Unmarshal ()
skipPadding count = do
        UnmarshalState _ _ start <- State.get
        pad <- consume (padding start count)
        when (L.any (/= 0) pad) $
                E.throwError (InvalidPadding start)

-- | Consume one byte and require it to be NUL; otherwise raise
-- 'MissingTerminator' at the offset where that byte was expected.
skipTerminator :: Unmarshal ()
skipTerminator = do
        UnmarshalState _ _ here <- State.get
        nul <- consume 1
        when (L.any (/= 0) nul) $
                E.throwError (MissingTerminator here)

-- | Apply a validating conversion. On 'Nothing', raise 'Invalid' with
-- @label@ and the 'show'n input.
fromMaybeU :: Show a => Text -> (a -> Maybe b) -> a -> Unmarshal b
fromMaybeU label f x =
        maybe (E.throwError (Invalid label (TL.pack (show x)))) return (f x)

-- | Like 'fromMaybeU', but wrap the successful result in a 'T.Variant'.
fromMaybeU' :: (Show a, T.Variable b) => Text -> (a -> Maybe b) -> a
           -> Unmarshal T.Variant
fromMaybeU' label f x = T.toVariant `fmap` fromMaybeU label f x

-- | Align to @count@ bytes, then read @count@ bytes and decode them with
-- @be@ or @le@ according to the stream's byte order.
unmarshalGet :: Word8 -> G.Get a -> G.Get a -> Unmarshal a
unmarshalGet count be le = do
        skipPadding count
        UnmarshalState endian _ _ <- State.get
        G.runGet (decoderFor endian) `fmap` consume (fromIntegral count)
    where
        decoderFor BigEndian    = be
        decoderFor LittleEndian = le

-- | Like 'unmarshalGet', but wrap the decoded value in a 'T.Variant'.
unmarshalGet' :: T.Variable a => Word8 -> G.Get a -> G.Get a
              -> Unmarshal T.Variant
unmarshalGet' count be le = do
        value <- unmarshalGet count be le
        return (T.toVariant value)

-- | Repeatedly run @comp@, collecting its results, until @test@ yields
-- 'True'. The test runs before every iteration, so @comp@ may run zero
-- times.
untilM :: Monad m => m Bool -> m a -> m [a]
untilM test comp = loop
    where
        loop = test >>= \finished ->
                if finished
                        then return []
                        else comp >>= \item -> liftM (item :) loop

-- | Reasons decoding can fail. Constructors carrying a 'Word64' record the
-- decoder's offset (bytes consumed from the buffer given to 'runUnmarshal')
-- at which the failing read or check began. 'unmarshalMessage' decodes the
-- body as a separate buffer, so body offsets are relative to the body start.
data UnmarshalError
        = UnsupportedProtocolVersion Word8
          -- ^ The fixed header's version byte differs from
          -- 'C.protocolVersion'.
        | UnexpectedEOF Word64
          -- ^ Fewer bytes remained than a read required.
        | Invalid Text Text
          -- ^ A value failed validation: a label naming what was being
          -- decoded, and the 'show'n offending input.
        | MissingHeaderField Text
          -- ^ A header field required by the message type is absent.
        | InvalidHeaderField Text T.Variant
          -- ^ A known header field held a value of the wrong type.
        | InvalidPadding Word64
          -- ^ Alignment padding contained a non-zero byte.
        | MissingTerminator Word64
          -- ^ The byte after a string or signature was not NUL.
        | ArraySizeMismatch
          -- ^ An array's elements ran past its declared byte length.
        deriving (Eq)

-- | Human-readable descriptions of decoding failures.
instance Show UnmarshalError where
        show err = case err of
                UnsupportedProtocolVersion v ->
                        "Unsupported protocol version: " ++ show v
                UnexpectedEOF pos ->
                        "Unexpected EOF at position " ++ show pos
                Invalid label value ->
                        "Invalid " ++ TL.unpack label ++ ": " ++ TL.unpack value
                MissingHeaderField name ->
                        "Required field " ++ show name ++ " is missing."
                InvalidHeaderField name got ->
                        "Invalid header field " ++ show name ++ ": " ++ show got
                InvalidPadding pos ->
                        "Invalid padding at position " ++ show pos
                MissingTerminator pos ->
                        "Missing NUL terminator at position " ++ show pos
                ArraySizeMismatch ->
                        "Array size mismatch"

-- | Read a 4-byte-aligned unsigned 32-bit integer in the stream's byte order.
unmarshalWord32 :: Unmarshal Word32
unmarshalWord32 = unmarshalGet width G.getWord32be G.getWord32le
    where
        width = 4

-- | Read a D-Bus STRING: a 'Word32' byte count, that many bytes of UTF-8
-- text, then a single NUL terminator.
--
-- The specification requires the string data itself to contain no NUL
-- bytes. An embedded NUL is therefore rejected as @Invalid "text"@, exactly
-- like malformed UTF-8, rather than being decoded into the result.
unmarshalText :: Unmarshal Text
unmarshalText = do
        byteCount <- unmarshalWord32
        bytes <- consume . fromIntegral $ byteCount
        skipTerminator
        fromMaybeU "text" decodeChecked bytes
    where
        decodeChecked bs
                | L.elem 0 bs = Nothing
                | otherwise   = maybeDecodeUtf8 bs

-- | Read a D-Bus SIGNATURE: a one-byte length, that many bytes of text,
-- then a NUL terminator. The text is decoded before the terminator is
-- checked, and finally validated with 'T.mkSignature'.
unmarshalSignature :: Unmarshal T.Signature
unmarshalSignature = do
        lengthByte <- consume 1
        raw <- consume (fromIntegral (L.head lengthByte))
        sigText <- fromMaybeU "text" maybeDecodeUtf8 raw
        skipTerminator
        fromMaybeU "signature" T.mkSignature sigText

-- | Read a D-Bus ARRAY whose elements have type @itemType@.
--
-- Byte arrays are read as one raw chunk of the declared length. For other
-- element types, the 'Word32' length counts bytes starting after the
-- alignment padding that precedes the first element. Elements are decoded
-- until that many bytes have been consumed; overrunning the declared length
-- raises 'ArraySizeMismatch'.
unmarshalArray :: T.Type -> Unmarshal T.Array
unmarshalArray T.DBusByte = do
        len <- unmarshalWord32
        raw <- consume (fromIntegral len)
        return (T.arrayFromBytes raw)

unmarshalArray itemType = do
        byteCount <- unmarshalWord32
        skipPadding (alignment itemType)
        start <- currentOffset
        let limit = start + fromIntegral byteCount
        items <- untilM ((>= limit) `fmap` currentOffset) (unmarshalType itemType)
        stop <- currentOffset
        if stop > limit
                then E.throwError ArraySizeMismatch
                else fromMaybeU "array" (T.arrayFromItems itemType) items
    where
        currentOffset = do
                UnmarshalState _ _ offset <- State.get
                return offset

-- | Turn the header flags byte into the set of recognised 'M.Flag's.
-- Bit 0x1 maps to 'M.NoReplyExpected' and bit 0x2 to 'M.NoAutoStart';
-- all other bits are ignored.
decodeFlags :: Word8 -> Set.Set M.Flag
decodeFlags word = Set.fromList [flag | (mask, flag) <- known, word .&. mask /= 0]
    where
        known = [(0x1, M.NoReplyExpected), (0x2, M.NoAutoStart)]

-- | Convert one @(code, variant)@ header-field structure into zero or one
-- 'M.HeaderField'. Codes 1-8 map to the known fields; a known code whose
-- value has the wrong type raises 'InvalidHeaderField'. Unknown codes are
-- skipped.
decodeField :: Monad m => T.Structure
            -> E.ErrorT UnmarshalError m [M.HeaderField]
decodeField struct = byCode fieldCode
    where
        (fieldCode, value) = unpackField struct
        byCode 1 = decodeField' value M.Path "path"
        byCode 2 = decodeField' value M.Interface "interface"
        byCode 3 = decodeField' value M.Member "member"
        byCode 4 = decodeField' value M.ErrorName "error name"
        byCode 5 = decodeField' value M.ReplySerial "reply serial"
        byCode 6 = decodeField' value M.Destination "destination"
        byCode 7 = decodeField' value M.Sender "sender"
        byCode 8 = decodeField' value M.Signature "signature"
        byCode _ = return []

-- | Extract a typed value from a header-field variant and wrap it with
-- @f@. A variant of the wrong type raises 'InvalidHeaderField' carrying
-- @label@ and the offending variant.
decodeField' :: (Monad m, T.Variable a) => T.Variant -> (a -> b) -> Text
             -> E.ErrorT UnmarshalError m [b]
decodeField' x f label =
        maybe (E.throwError (InvalidHeaderField label x))
              (\value -> return [f value])
              (T.fromVariant x)

-- | Split a header-field structure into its code byte and its boxed value.
--
-- Partial: this assumes the structure came from decoding the @(yv)@ part
-- of a header signature, so it has exactly two fields of those types. The
-- pattern is lazy, matching the original where-bound pattern, so a
-- malformed structure fails only when a component is demanded.
unpackField :: T.Structure -> (Word8, T.Variant)
unpackField ~(T.Structure [code, boxed]) =
        (fromJust (T.fromVariant code), fromJust (T.fromVariant boxed))

-- | Read bytes from a monad until a complete message has been received.
--
-- @getBytes'@ is asked for a specific byte count at each step, in this
-- order:
--
-- 1. the 16-byte fixed header;
-- 2. the header-field array;
-- 3. the padding that aligns the body to 8 bytes;
-- 4. the body itself.
--
-- Decoding failures are returned as 'Left' instead of being thrown.
--
-- NOTE(review): @getBytes'@ is assumed to return exactly the requested
-- count, or fewer only at end of input. A short fixed header is reported
-- as 'UnexpectedEOF' at the number of bytes actually received.
unmarshalMessage :: Monad m => (Word32 -> m L.ByteString)
                 -> m (Either UnmarshalError M.ReceivedMessage)
unmarshalMessage getBytes' = E.runErrorT $ do
        let getBytes = lift . getBytes'
        
        let fixedSig = "yyyyuuu"
        fixedBytes <- getBytes 16

        -- 'L.index' below is partial. A short read (for example a closed
        -- connection returning an empty string) must become an error value,
        -- not an exception.
        let fixedLength = L.length fixedBytes
        when (fixedLength < 16) $
                E.throwError $ UnexpectedEOF (fromIntegral fixedLength)

        let messageVersion = L.index fixedBytes 3
        when (messageVersion /= C.protocolVersion) $
                E.throwError $ UnsupportedProtocolVersion messageVersion

        let eByte = L.index fixedBytes 0
        endianness <- case decodeEndianness eByte of
                Just x' -> return x'
                Nothing -> E.throwError . Invalid "endianness" . TL.pack . show $ eByte

        let unmarshal' x bytes = case runUnmarshal (unmarshal x) endianness bytes of
                Right x' -> return x'
                Left  e  -> E.throwError e
        fixed <- unmarshal' fixedSig fixedBytes
        -- Fixed header fields (yyyyuuu): endianness, message type, flags,
        -- protocol version, body length, serial, header-field array length.
        let typeCode = fromJust . T.fromVariant $ fixed !! 1
        let flags = decodeFlags . fromJust . T.fromVariant $ fixed !! 2
        let bodyLength = fromJust . T.fromVariant $ fixed !! 4
        let serial = fromJust . T.fromVariant $ fixed !! 5

        let fieldByteCount = fromJust . T.fromVariant $ fixed !! 6

        let headerSig  = "yyyyuua(yv)"
        fieldBytes <- getBytes fieldByteCount
        let headerBytes = L.append fixedBytes fieldBytes
        header <- unmarshal' headerSig headerBytes

        let fieldArray = fromJust . T.fromVariant $ header !! 6
        let fieldStructures = fromJust . T.fromArray $ fieldArray
        fields <- concat `liftM` mapM decodeField fieldStructures

        -- The body is decoded as a separate buffer starting on an 8-byte
        -- boundary. The skipped bytes must be zero, as 'skipPadding'
        -- requires for padding inside a buffer.
        let headerLength = fromIntegral fieldByteCount + 16
        let bodyPadding = padding headerLength 8
        paddingBytes <- getBytes . fromIntegral $ bodyPadding
        unless (L.all (== 0) paddingBytes) $
                E.throwError $ InvalidPadding headerLength

        let bodySig = findBodySignature fields

        bodyBytes <- getBytes bodyLength
        body <- unmarshal' bodySig bodyBytes

        y <- case buildReceivedMessage typeCode fields of
                EitherM (Right x) -> return x
                EitherM (Left x) -> E.throwError $ MissingHeaderField x

        return $ y serial flags body


-- | The signature named by the first 'M.Signature' header field, or the
-- empty signature when the message has none.
findBodySignature :: [M.HeaderField] -> T.Signature
findBodySignature fields = case [sig | M.Signature sig <- fields] of
        sig : _ -> sig
        []      -> ""

-- | A minimal 'Either' monad in which '>>=' short-circuits on the first
-- 'Left'.
--
-- 'Functor' and 'Applicative' instances are defined alongside 'Monad'
-- because GHC 7.10 and later reject a 'Monad' instance without them.
newtype EitherM a b = EitherM (Either a b)

instance Functor (EitherM a) where
        fmap _ (EitherM (Left x)) = EitherM (Left x)
        fmap f (EitherM (Right x)) = EitherM (Right (f x))

instance Applicative (EitherM a) where
        pure = EitherM . Right
        (EitherM (Left x)) <*> _ = EitherM (Left x)
        (EitherM (Right f)) <*> r = fmap f r

instance Monad (EitherM a) where
        -- Canonical definition; kept explicit for compilers whose 'Monad'
        -- class has no default for 'return'.
        return = pure
        (EitherM (Left x)) >>= _ = EitherM (Left x)
        (EitherM (Right x)) >>= k = k x

-- | Choose how to assemble a 'M.ReceivedMessage' from the message-type code
-- and the decoded header fields.
--
-- Codes 1-4 (method call, method return, error, signal) require certain
-- header fields and fail with the label of the first one missing. Any other
-- code yields an 'M.Unknown' message. The returned function is completed
-- with the serial, flags and body.
buildReceivedMessage :: Word8 -> [M.HeaderField] -> EitherM Text
                        (M.Serial -> (Set.Set M.Flag) -> [T.Variant]
                         -> M.ReceivedMessage)
buildReceivedMessage typeCode fields = case typeCode of
        1 -> do
                path <- require "path" paths
                member <- require "member name" members
                return $ \serial flags body ->
                        M.ReceivedMethodCall serial sender
                                (M.MethodCall path member optIface dest flags body)
        2 -> do
                replySerial <- require "reply serial" replySerials
                return $ \serial _ body ->
                        M.ReceivedMethodReturn serial sender
                                (M.MethodReturn replySerial dest body)
        3 -> do
                name <- require "error name" errorNames
                replySerial <- require "reply serial" replySerials
                return $ \serial _ body ->
                        M.ReceivedError serial sender
                                (M.Error name replySerial dest body)
        4 -> do
                path <- require "path" paths
                member <- require "member name" members
                iface <- require "interface" interfaces
                return $ \serial _ body ->
                        M.ReceivedSignal serial sender
                                (M.Signal path member iface dest body)
        _ -> return $ \serial flags body ->
                M.ReceivedUnknown serial sender (M.Unknown typeCode flags body)
    where
        paths        = [x | M.Path x <- fields]
        members      = [x | M.Member x <- fields]
        interfaces   = [x | M.Interface x <- fields]
        errorNames   = [x | M.ErrorName x <- fields]
        replySerials = [x | M.ReplySerial x <- fields]
        optIface     = listToMaybe interfaces
        dest         = listToMaybe [x | M.Destination x <- fields]
        sender       = listToMaybe [x | M.Sender x <- fields]

-- | The first element of the list, or a 'Left' carrying @label@ when the
-- list is empty.
require :: Text -> [a] -> EitherM Text a
require label xs = EitherM (maybe (Left label) Right (listToMaybe xs))