-- Copyright (C) 2009-2010 John Millikin <jmillikin@gmail.com>
-- 
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- any later version.
-- 
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-- GNU General Public License for more details.
-- 
-- You should have received a copy of the GNU General Public License
-- along with this program.  If not, see <http://www.gnu.org/licenses/>.
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module DBus.Wire.Unmarshal where
import Control.Applicative (Applicative (..))
import Control.Monad (when, unless, liftM)
import Data.Bits ((.&.))
import Data.Int (Int16, Int32, Int64)
import Data.Maybe (fromJust, listToMaybe, fromMaybe)
import Data.Word (Word8, Word32, Word64)

import qualified Control.Monad.State as State
import Control.Monad.Trans.Class (lift)
import qualified Data.Binary.Get as G
import qualified Data.Binary.IEEE754 as IEEE
import qualified Data.ByteString.Lazy as L
import qualified Data.Set as Set
import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy as TL

import qualified DBus.Constants as C
import qualified DBus.Message.Internal as M
import qualified DBus.Types as T
import qualified DBus.Util.MonadError as E
import DBus.Wire.Internal
import DBus.Wire.Unicode (maybeDecodeUtf8)
-- | Decoder state: the message's byte order, the input not yet consumed,
-- and the number of bytes consumed so far (the offset is kept strict).
data UnmarshalState = UnmarshalState
	Endianness   -- byte order; 'consume' carries it over unchanged
	L.ByteString -- remaining, not yet consumed input
	!Word64      -- current offset, used for alignment and error positions

-- | Decoding monad: fails with 'UnmarshalError', threads 'UnmarshalState'.
type Unmarshal = E.ErrorT UnmarshalError (State.State UnmarshalState)
-- | Run a decoder over @bytes@ using byte order @e@, starting at offset 0.
--
-- Returns the decoder's result or the error it threw. Whatever input is
-- left unconsumed at the end is discarded along with the final state.
runUnmarshal :: Unmarshal a -> Endianness -> L.ByteString -> Either UnmarshalError a
runUnmarshal m e bytes =
	let initial = UnmarshalState e bytes 0
	in State.evalState (E.runErrorT m) initial
-- | Decode one value for each type listed in the signature, in order.
unmarshal :: T.Signature -> Unmarshal [T.Variant]
unmarshal sig = mapM unmarshalType (T.signatureTypes sig)

-- | Decode one value of the given D-Bus type and wrap it in a 'T.Variant'.
--
-- Fixed-width numbers go through 'unmarshalGet', which skips alignment
-- padding for the value's width first. Strings, signatures and arrays use
-- the dedicated helpers further down. Structures skip padding to 8 bytes
-- and then decode their members in order. Variants read a signature that
-- must name exactly one type, then decode a value of that type.
unmarshalType :: T.Type -> Unmarshal T.Variant
unmarshalType T.DBusByte = liftM (T.toVariant . L.head) (consume 1)
unmarshalType T.DBusWord16 = unmarshalGet' 2 G.getWord16be G.getWord16le
unmarshalType T.DBusWord32 = unmarshalGet' 4 G.getWord32be G.getWord32le
unmarshalType T.DBusWord64 = unmarshalGet' 8 G.getWord64be G.getWord64le

-- Signed integers are read as unsigned words of the same width and then
-- converted with 'fromIntegral', which wraps modulo the width.
unmarshalType T.DBusInt16 =
	liftM (\w -> T.toVariant (fromIntegral w :: Int16)) $
	unmarshalGet 2 G.getWord16be G.getWord16le

unmarshalType T.DBusInt32 =
	liftM (\w -> T.toVariant (fromIntegral w :: Int32)) $
	unmarshalGet 4 G.getWord32be G.getWord32le

unmarshalType T.DBusInt64 =
	liftM (\w -> T.toVariant (fromIntegral w :: Int64)) $
	unmarshalGet 8 G.getWord64be G.getWord64le

unmarshalType T.DBusDouble = unmarshalGet' 8 IEEE.getFloat64be IEEE.getFloat64le

-- Booleans are stored as 32-bit words; only 0 and 1 are accepted.
unmarshalType T.DBusBoolean = unmarshalWord32 >>= fromMaybeU' "boolean" toBool where
	toBool 0 = Just False
	toBool 1 = Just True
	toBool _ = Nothing

unmarshalType T.DBusString = liftM T.toVariant unmarshalText
unmarshalType T.DBusObjectPath =
	fromMaybeU' "object path" T.mkObjectPath =<< unmarshalText
unmarshalType T.DBusSignature = liftM T.toVariant unmarshalSignature
unmarshalType (T.DBusArray t) = liftM T.toVariant (unmarshalArray t)

-- A dictionary is decoded as an array of two-member (key, value)
-- structures and then converted, failing if the conversion is rejected.
unmarshalType (T.DBusDictionary kt vt) =
	unmarshalArray (T.DBusStructure [kt, vt]) >>=
	fromMaybeU' "dictionary" T.arrayToDictionary

unmarshalType (T.DBusStructure ts) = do
	skipPadding 8
	members <- mapM unmarshalType ts
	return (T.toVariant (T.Structure members))

unmarshalType T.DBusVariant = do
	sig <- unmarshalSignature
	let single s = case T.signatureTypes s of { [only] -> Just only; _ -> Nothing }
	inner <- fromMaybeU "variant signature" single sig
	liftM T.toVariant (unmarshalType inner)
-- | Take exactly @count@ bytes from the remaining input.
--
-- On success the bytes are removed from the input and the offset advances
-- by @count@. If fewer than @count@ bytes remain, 'UnexpectedEOF' is thrown
-- carrying the offset at which this read started.
consume :: Word64 -> Unmarshal L.ByteString
consume count = do
	UnmarshalState endian input offset <- State.get
	let wanted = fromIntegral count
	let (chunk, remaining) = L.splitAt wanted input
	if L.length chunk /= wanted
		then E.throwError (UnexpectedEOF offset)
		else do
			State.put (UnmarshalState endian remaining (offset + count))
			return chunk
-- | Consume the alignment padding that 'padding' computes for the current
-- offset and an alignment of @count@ bytes.
--
-- Every padding byte must be zero; otherwise 'InvalidPadding' is thrown
-- with the offset where the padding began.
skipPadding :: Word8 -> Unmarshal ()
skipPadding count = do
	start <- State.gets (\(UnmarshalState _ _ o) -> o)
	filler <- consume (padding start count)
	when (L.any (/= 0) filler) $
		E.throwError (InvalidPadding start)
-- | Consume one byte that must be a NUL terminator.
--
-- A non-zero byte raises 'MissingTerminator' with that byte's offset. If
-- no byte is left, 'consume' raises 'UnexpectedEOF' instead.
skipTerminator :: Unmarshal ()
skipTerminator = do
	here <- State.gets (\(UnmarshalState _ _ o) -> o)
	nul <- consume 1
	when (L.any (/= 0) nul) $
		E.throwError (MissingTerminator here)
-- | Apply a validating conversion to @x@.
--
-- Returns the converted value. If the conversion yields 'Nothing', throws
-- 'Invalid' with @label@ and the 'show'n input.
fromMaybeU :: Show a => Text -> (a -> Maybe b) -> a -> Unmarshal b
fromMaybeU label f x = maybe invalid return (f x) where
	invalid = E.throwError (Invalid label (TL.pack (show x)))

-- | Like 'fromMaybeU', but wraps the converted value in a 'T.Variant'.
fromMaybeU' :: (Show a, T.Variable b) => Text -> (a -> Maybe b) -> a
           -> Unmarshal T.Variant
fromMaybeU' label f x = liftM T.toVariant (fromMaybeU label f x)
-- | Decode a fixed-width value of @count@ bytes.
--
-- Skips alignment padding for @count@ (via 'skipPadding'), reads @count@
-- bytes, and runs either @be@ or @le@ on them according to the byte order
-- stored in the decoder state.
unmarshalGet :: Word8 -> G.Get a -> G.Get a -> Unmarshal a
unmarshalGet count be le = do
	skipPadding count
	UnmarshalState endian _ _ <- State.get
	raw <- consume (fromIntegral count)
	let decoder = case endian of
		BigEndian -> be
		LittleEndian -> le
	return (G.runGet decoder raw)

-- | Like 'unmarshalGet', but wraps the decoded value in a 'T.Variant'.
unmarshalGet' :: T.Variable a => Word8 -> G.Get a -> G.Get a
              -> Unmarshal T.Variant
unmarshalGet' count be le = liftM T.toVariant (unmarshalGet count be le)
-- | Run @comp@ repeatedly, collecting its results in order, until @test@
-- returns 'True'.
--
-- The test runs before every iteration, so the result is empty when the
-- test succeeds immediately.
untilM :: Monad m => m Bool -> m a -> m [a]
untilM test comp = loop where
	loop = do
		stop <- test
		if stop
			then return []
			else comp >>= \x -> liftM (x :) loop
-- | Everything that can go wrong while decoding a message.
--
-- 'Word64' positions are byte offsets within the chunk being decoded
-- when the error was detected.
data UnmarshalError
	= UnsupportedProtocolVersion Word8
	-- ^ The header named a protocol version other than the supported one.
	| UnexpectedEOF Word64
	-- ^ Input ended before the read starting at this offset completed.
	| Invalid Text Text
	-- ^ A value failed validation: what was being decoded, and the
	-- 'show'n offending input.
	| MissingHeaderField Text
	-- ^ A header field required by the message type is absent.
	| InvalidHeaderField Text T.Variant
	-- ^ A known header field held a value of the wrong type.
	| InvalidPadding Word64
	-- ^ Alignment padding starting at this offset had a non-zero byte.
	| MissingTerminator Word64
	-- ^ The byte at this offset should have been a NUL terminator.
	| ArraySizeMismatch
	-- ^ Array elements ran past the array's declared byte length.
	deriving (Eq)

instance Show UnmarshalError where
	show err = case err of
		UnsupportedProtocolVersion x ->
			"Unsupported protocol version: " ++ show x
		UnexpectedEOF pos ->
			"Unexpected EOF at position " ++ show pos
		Invalid label x ->
			"Invalid " ++ TL.unpack label ++ ": " ++ TL.unpack x
		MissingHeaderField x ->
			"Required field " ++ show x ++ " is missing."
		InvalidHeaderField x got ->
			"Invalid header field " ++ show x ++ ": " ++ show got
		InvalidPadding pos ->
			"Invalid padding at position " ++ show pos
		MissingTerminator pos ->
			"Missing NUL terminator at position " ++ show pos
		ArraySizeMismatch ->
			"Array size mismatch"
-- | Decode an aligned 32-bit unsigned integer in the message's byte order.
unmarshalWord32 :: Unmarshal Word32
unmarshalWord32 = unmarshalGet width G.getWord32be G.getWord32le where
	width = 4
-- | Decode a string: a 32-bit byte length, that many bytes, then a NUL.
--
-- The bytes must decode as UTF-8 ('maybeDecodeUtf8'); otherwise 'Invalid'
-- is thrown with the label "text".
unmarshalText :: Unmarshal Text
unmarshalText = do
	raw <- unmarshalWord32 >>= consume . fromIntegral
	skipTerminator
	fromMaybeU "text" maybeDecodeUtf8 raw
-- | Decode a signature: a one-byte length, that many bytes, then a NUL.
--
-- The bytes are checked as UTF-8 before the terminator is read. After the
-- terminator, the text is parsed with 'T.mkSignature'.
unmarshalSignature :: Unmarshal T.Signature
unmarshalSignature = do
	sizeByte <- consume 1
	raw <- consume (fromIntegral (L.head sizeByte))
	sigText <- fromMaybeU "text" maybeDecodeUtf8 raw
	skipTerminator
	fromMaybeU "signature" T.mkSignature sigText
-- | Decode an array whose elements have type @itemType@.
--
-- The array starts with a 32-bit byte length. For byte arrays, that many
-- bytes are taken in one chunk. Otherwise, padding for the element
-- alignment is skipped, and the length is counted from the end of that
-- padding. Elements are decoded until the offset reaches the declared
-- end. 'ArraySizeMismatch' is thrown if the last element ran past it.
unmarshalArray :: T.Type -> Unmarshal T.Array
unmarshalArray T.DBusByte =
	unmarshalWord32 >>= liftM T.arrayFromBytes . consume . fromIntegral
unmarshalArray itemType = do
	let currentOffset = State.gets (\(UnmarshalState _ _ o) -> o)
	size <- unmarshalWord32
	skipPadding (alignment itemType)
	begin <- currentOffset
	let limit = begin + fromIntegral size
	items <- untilM (liftM (>= limit) currentOffset) (unmarshalType itemType)
	finish <- currentOffset
	when (finish > limit) $
		E.throwError ArraySizeMismatch
	fromMaybeU "array" (T.arrayFromItems itemType) items
-- | Translate the header's flags byte into the set of known flags.
--
-- Bit 0x1 maps to 'M.NoReplyExpected' and bit 0x2 to 'M.NoAutoStart';
-- all other bits are ignored.
decodeFlags :: Word8 -> Set.Set M.Flag
decodeFlags word = Set.fromList [flag | (mask, flag) <- known, word .&. mask /= 0] where
	known =
		[ (0x1, M.NoReplyExpected)
		, (0x2, M.NoAutoStart)
		]
-- | Convert one decoded @(code, value)@ header structure into header fields.
--
-- Codes 1 to 8 map to the corresponding 'M.HeaderField' constructors and
-- yield a one-element list. Unknown codes yield an empty list. A value of
-- the wrong type raises 'InvalidHeaderField'.
decodeField :: Monad m => T.Structure
            -> E.ErrorT UnmarshalError m [M.HeaderField]
decodeField struct = decode code where
	(code, value) = unpackField struct
	decode 1 = decodeField' value M.Path "path"
	decode 2 = decodeField' value M.Interface "interface"
	decode 3 = decodeField' value M.Member "member"
	decode 4 = decodeField' value M.ErrorName "error name"
	decode 5 = decodeField' value M.ReplySerial "reply serial"
	decode 6 = decodeField' value M.Destination "destination"
	decode 7 = decodeField' value M.Sender "sender"
	decode 8 = decodeField' value M.Signature "signature"
	decode _ = return []

-- | Extract a typed value from a header variant and wrap it with @f@.
-- If the variant does not hold the expected type, throws
-- 'InvalidHeaderField' naming @label@.
decodeField' :: (Monad m, T.Variable a) => T.Variant -> (a -> b) -> Text
             -> E.ErrorT UnmarshalError m [b]
decodeField' x f label = maybe
	(E.throwError (InvalidHeaderField label x))
	(\found -> return [f found])
	(T.fromVariant x)
-- | Split a header-field structure into its code byte and inner variant.
--
-- Partial: this assumes the structure has exactly the @(yv)@ shape
-- produced by the header signature. The pattern is lazy, so a mismatch
-- only fails when a component is forced.
unpackField :: T.Structure -> (Word8, T.Variant)
unpackField struct =
	let T.Structure [code, value] = struct
	in (fromJust (T.fromVariant code), fromJust (T.fromVariant value))
-- | Read bytes from a monad until a complete message has been received.
--
-- @getBytes'@ is asked for the exact number of bytes needed next: the
-- 16-byte fixed header, then the header-field array, then the padding
-- that aligns the body to 8 bytes, then the body. Each chunk is decoded
-- with the byte order named by the first header byte.
--
-- Failures are returned in 'Left': a short fixed header, an unsupported
-- protocol version, an unknown endianness byte, malformed header or body
-- data, and missing required header fields.
unmarshalMessage :: Monad m => (Word32 -> m L.ByteString)
                 -> m (Either UnmarshalError M.ReceivedMessage)
unmarshalMessage getBytes' = E.runErrorT $ do
	let getBytes = lift . getBytes'

	let fixedSig = "yyyyuuu"
	fixedBytes <- getBytes 16
	-- Report a short read before inspecting individual bytes: 'L.index'
	-- calls 'error' on an out-of-range index, which would escape the
	-- 'Either' result instead of becoming an 'UnmarshalError'.
	let fixedLength = L.length fixedBytes
	when (fixedLength < 16) $
		E.throwError $ UnexpectedEOF (fromIntegral fixedLength)
	let messageVersion = L.index fixedBytes 3
	when (messageVersion /= C.protocolVersion) $
		E.throwError $ UnsupportedProtocolVersion messageVersion
	let eByte = L.index fixedBytes 0
	endianness <- case decodeEndianness eByte of
		Just x' -> return x'
		Nothing -> E.throwError . Invalid "endianness" . TL.pack . show $ eByte
	let unmarshal' x bytes = case runUnmarshal (unmarshal x) endianness bytes of
		Right x' -> return x'
		Left  e  -> E.throwError e
	-- Indexes follow fixedSig: 0 endianness, 1 type code, 2 flags,
	-- 3 protocol version, 4 body length, 5 serial, 6 field array length.
	fixed <- unmarshal' fixedSig fixedBytes
	let typeCode = fromJust . T.fromVariant $ fixed !! 1
	let flags = decodeFlags . fromJust . T.fromVariant $ fixed !! 2
	let bodyLength = fromJust . T.fromVariant $ fixed !! 4
	let serial = fromJust . T.fromVariant $ fixed !! 5
	let fieldByteCount = fromJust . T.fromVariant $ fixed !! 6
	-- The header is re-decoded as a whole, because the field array's
	-- length word is the last item of the fixed part.
	let headerSig  = "yyyyuua(yv)"
	fieldBytes <- getBytes fieldByteCount
	let headerBytes = L.append fixedBytes fieldBytes
	header <- unmarshal' headerSig headerBytes
	let fieldArray = fromJust . T.fromVariant $ header !! 6
	let fieldStructures = fromJust . T.fromArray $ fieldArray
	fields <- concat `liftM` mapM decodeField fieldStructures
	-- The body begins at the next 8-byte boundary after the header; the
	-- padding bytes in between are read and discarded.
	let bodyPadding = padding (fromIntegral fieldByteCount + 16) 8
	_ <- getBytes . fromIntegral $ bodyPadding
	let bodySig = findBodySignature fields
	bodyBytes <- getBytes bodyLength
	body <- unmarshal' bodySig bodyBytes
	y <- case buildReceivedMessage typeCode fields of
		EitherM (Right x) -> return x
		EitherM (Left x) -> E.throwError $ MissingHeaderField x
	return $ y serial flags body
-- | The body signature named by the first 'M.Signature' header field, or
-- the empty signature when the message has none.
findBodySignature :: [M.HeaderField] -> T.Signature
findBodySignature fields = case [sig | M.Signature sig <- fields] of
	(sig:_) -> sig
	[] -> ""
-- | A minimal error monad: a 'Left' short-circuits the rest of a @do@ block.
--
-- 'buildReceivedMessage' uses it to report the first missing header field.
-- The 'Functor' and 'Applicative' instances are required because 'Monad'
-- has had 'Applicative' as a superclass since GHC 7.10 (base 4.8); without
-- them the 'Monad' instance does not compile on those compilers.
newtype EitherM a b = EitherM (Either a b)

instance Functor (EitherM a) where
	fmap _ (EitherM (Left x)) = EitherM (Left x)
	fmap f (EitherM (Right x)) = EitherM (Right (f x))

instance Applicative (EitherM a) where
	pure = EitherM . Right
	(EitherM (Left x)) <*> _ = EitherM (Left x)
	(EitherM (Right f)) <*> r = fmap f r

instance Monad (EitherM a) where
	return = pure
	(EitherM (Left x)) >>= _ = EitherM (Left x)
	(EitherM (Right x)) >>= k = k x
-- | Build the constructor for a received message from its header fields.
--
-- The type code selects the message kind. Codes 1 (method call),
-- 2 (method return), 3 (error) and 4 (signal) first check their required
-- fields and fail with the label of the first missing one. Any other code
-- yields 'M.ReceivedUnknown'. The returned function still needs the
-- serial, flags and body; only method calls and unknown messages keep
-- the flags.
buildReceivedMessage :: Word8 -> [M.HeaderField] -> EitherM Text
                        (M.Serial -> (Set.Set M.Flag) -> [T.Variant]
                         -> M.ReceivedMessage)
buildReceivedMessage typeCode fields = build typeCode where
	-- Optional fields shared by every message kind: the first occurrence
	-- wins, and an absent field becomes 'Nothing'.
	sender = listToMaybe [x | M.Sender x <- fields]
	dest = listToMaybe [x | M.Destination x <- fields]

	build 1 = do
		path <- require "path" [x | M.Path x <- fields]
		member <- require "member name" [x | M.Member x <- fields]
		let iface = listToMaybe [x | M.Interface x <- fields]
		return $ \serial flags body -> M.ReceivedMethodCall serial sender
			(M.MethodCall path member iface dest flags body)
	build 2 = do
		replySerial <- require "reply serial" [x | M.ReplySerial x <- fields]
		return $ \serial _ body -> M.ReceivedMethodReturn serial sender
			(M.MethodReturn replySerial dest body)
	build 3 = do
		name <- require "error name" [x | M.ErrorName x <- fields]
		replySerial <- require "reply serial" [x | M.ReplySerial x <- fields]
		return $ \serial _ body -> M.ReceivedError serial sender
			(M.Error name replySerial dest body)
	build 4 = do
		path <- require "path" [x | M.Path x <- fields]
		member <- require "member name" [x | M.Member x <- fields]
		iface <- require "interface" [x | M.Interface x <- fields]
		return $ \serial _ body -> M.ReceivedSignal serial sender
			(M.Signal path member iface dest body)
	build code = return $ \serial flags body ->
		M.ReceivedUnknown serial sender (M.Unknown code flags body)
-- | The first element of @found@, or a failure carrying @label@ (the name
-- of the missing header field) when the list is empty.
require :: Text -> [a] -> EitherM Text a
require label found = case found of
	(x:_) -> return x
	[] -> EitherM (Left label)