-- Copyright (C) 2009-2010 John Millikin <jmillikin@gmail.com>
-- 
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- any later version.
-- 
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-- GNU General Public License for more details.
-- 
-- You should have received a copy of the GNU General Public License
-- along with this program.  If not, see <http://www.gnu.org/licenses/>.
{-# LANGUAGE TypeFamilies #-}
module DBus.Wire.Marshal where
import Control.Applicative (Applicative (..))
import Control.Monad (when)
import Data.Bits ((.|.))
import Data.Int (Int64)
import Data.Maybe (fromJust)
import Data.Word (Word8, Word32, Word64)

import qualified Data.Binary.Builder as B
import qualified Data.Binary.IEEE754 as IEEE
import Data.Binary.Put (runPut)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as L
import qualified Data.Set as Set
import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy as TL

import qualified DBus.Constants as C
import qualified DBus.Message.Internal as M
import qualified DBus.Types as T
import qualified DBus.Types.Internal as T
import DBus.Wire.Internal
import DBus.Wire.Unicode (maybeEncodeUtf8)
-- | Marshaling state: the output written so far, and the number of bytes
-- it contains. The byte count is what 'pad' uses to decide how much
-- alignment padding to insert.
--
-- NOTE(review): 'B.Builder' wraps a function, so GHC most likely ignores
-- the UNPACK pragma on the first field (with a warning) — confirm.
data MarshalState = MarshalState {-# UNPACK #-} !B.Builder {-# UNPACK #-} !Word64

-- | Outcome of one marshaling step: either an error ('MarshalRL'), or a
-- result value together with the updated state ('MarshalRR').
data MarshalR a = MarshalRL MarshalError | MarshalRR a {-# UNPACK #-} !MarshalState

-- | A marshaling action that yields no result.
type Marshal = MarshalM ()
-- | Marshaling monad: reads the target 'Endianness', threads a
-- 'MarshalState', and stops at the first 'MarshalError'.
newtype MarshalM a = MarshalM { unMarshalM :: Endianness -> MarshalState -> MarshalR a }

-- Since base 4.8 (GHC 7.10), Functor and Applicative are superclasses of
-- Monad, so these instances must exist for the module to compile.

-- | Apply a function to the result of a successful step; errors pass
-- through unchanged.
instance Functor MarshalM where
	{-# INLINE fmap #-}
	fmap f m = MarshalM $ \e s -> case unMarshalM m e s of
		MarshalRL err -> MarshalRL err
		MarshalRR a s' -> MarshalRR (f a) s'

-- | Sequence two steps left to right, threading the state and stopping at
-- the first error.
instance Applicative MarshalM where
	{-# INLINE pure #-}
	pure a = MarshalM $ \_ s -> MarshalRR a s
	
	{-# INLINE (<*>) #-}
	mf <*> mx = MarshalM $ \e s -> case unMarshalM mf e s of
		MarshalRL err -> MarshalRL err
		MarshalRR f s' -> case unMarshalM mx e s' of
			MarshalRL err -> MarshalRL err
			MarshalRR x s'' -> MarshalRR (f x) s''
	
	{-# INLINE (*>) #-}
	m *> k = MarshalM $ \e s -> case unMarshalM m e s of
		MarshalRL err -> MarshalRL err
		MarshalRR _ s' -> unMarshalM k e s'

-- | Bind threads the state from one step to the next; an error from the
-- first step short-circuits the rest.
instance Monad MarshalM where
	{-# INLINE return #-}
	return = pure
	
	{-# INLINE (>>=) #-}
	m >>= k = MarshalM $ \e s -> case unMarshalM m e s of
		MarshalRL err -> MarshalRL err
		MarshalRR a s' -> unMarshalM (k a) e s'
	
	{-# INLINE (>>) #-}
	(>>) = (*>)

-- | Abort marshaling with the given error, ignoring the endianness and
-- discarding the current state.
throwError :: MarshalError -> MarshalM a
throwError err = MarshalM (const (const (MarshalRL err)))

{-# INLINE getState #-}
-- | Return the current marshaling state, leaving it unchanged.
getState :: MarshalM MarshalState
getState = MarshalM (\_ current -> MarshalRR current current)

{-# INLINE putState #-}
-- | Replace the marshaling state with the given one.
putState :: MarshalState -> MarshalM ()
putState new = MarshalM (const (const (MarshalRR () new)))
-- | Run a marshaler with the given byte order, starting from an empty
-- output and a byte count of zero. Returns the first error raised, or
-- every byte that was written.
runMarshal :: Marshal -> Endianness -> Either MarshalError L.ByteString
runMarshal m e = finish (unMarshalM m e initial) where
	initial = MarshalState B.empty 0
	finish (MarshalRL err) = Left err
	finish (MarshalRR _ (MarshalState out _)) = Right (B.toLazyByteString out)
-- | Marshal one value. Fixed-width values are aligned to their own size,
-- and structures are aligned to 8 bytes before their members.
--
-- Signed integers are written as the unsigned word of the same width, and
-- booleans as a 32-bit 0 or 1. A variant is written as its signature
-- followed by the wrapped value. A variant whose signature cannot be
-- computed fails with 'InvalidVariantSignature'.
marshal :: T.Variant -> Marshal
marshal (T.VarBoxWord8 x) = marshalWord8 x
marshal (T.VarBoxWord16 x) = marshalBuilder 2 B.putWord16be B.putWord16le x
marshal (T.VarBoxWord32 x) = marshalWord32 x
marshal (T.VarBoxWord64 x) = marshalBuilder 8 B.putWord64be B.putWord64le x
marshal (T.VarBoxInt16 x) = marshalBuilder 2 B.putWord16be B.putWord16le (fromIntegral x)
marshal (T.VarBoxInt32 x) = marshalBuilder 4 B.putWord32be B.putWord32le (fromIntegral x)
marshal (T.VarBoxInt64 x) = marshalBuilder 8 B.putWord64be B.putWord64le (fromIntegral x)
marshal (T.VarBoxDouble x) = marshalDouble x
marshal (T.VarBoxBool x) = marshalWord32 (if x then 1 else 0)
marshal (T.VarBoxString x) = marshalText x
marshal (T.VarBoxObjectPath x) = marshalText (T.strObjectPath x)
marshal (T.VarBoxSignature x) = marshalSignature x
marshal (T.VarBoxArray x) = marshalArray x
marshal (T.VarBoxDictionary x) = marshalArray (T.dictionaryToArray x)
marshal (T.VarBoxStructure (T.Structure vs)) = pad 8 >> mapM_ marshal vs
marshal (T.VarBoxVariant x) = do
	-- Only needed for the error message, and only evaluated on failure.
	let textSig = T.typeCode (T.variantType x)
	sig <- maybe (throwError (InvalidVariantSignature textSig)) return (T.variantSignature x)
	marshalSignature sig
	marshal x
-- | Append a strict byte string verbatim and advance the byte count by its
-- length. No alignment padding is inserted.
appendS :: BS.ByteString -> Marshal
appendS bytes = MarshalM step where
	len = fromIntegral (BS.length bytes)
	step _ (MarshalState out n) =
		MarshalRR () (MarshalState (B.append out (B.fromByteString bytes)) (n + len))
-- | Append a lazy byte string verbatim and advance the byte count by its
-- length. No alignment padding is inserted.
appendL :: L.ByteString -> Marshal
appendL bytes = MarshalM step where
	len = fromIntegral (L.length bytes)
	step _ (MarshalState out n) =
		MarshalRR () (MarshalState (B.append out (B.fromLazyByteString bytes)) (n + len))
-- | Append as many zero bytes as 'padding' reports for the current byte
-- count and the requested alignment. Presumably this is enough to reach
-- the next multiple of the alignment; confirm against 'padding'.
pad :: Word8 -> Marshal
pad count = do
	MarshalState _ existing <- getState
	appendS (BS.replicate (fromIntegral (padding existing count)) 0)
-- | Pad to @size@ bytes, then append @x@ using the big-endian (@be@) or
-- little-endian (@le@) encoder selected by the marshaler's byte order.
-- The byte count advances by @size@, which callers must match to the
-- width the encoder actually produces.
marshalBuilder :: Word8 -> (a -> B.Builder) -> (a -> B.Builder) -> a -> Marshal
marshalBuilder size be le x = pad size >> MarshalM emit where
	emit e (MarshalState out n) =
		MarshalRR () (MarshalState (B.append out (encode e)) (n + fromIntegral size))
	encode BigEndian = be x
	encode LittleEndian = le x
-- | Reasons marshaling can fail.
data MarshalError
	= MessageTooLong Word64
	-- ^ the complete message is longer than the maximum message length
	| ArrayTooLong Word64
	-- ^ an array's item bytes exceed the maximum array length
	| InvalidBodySignature Text
	-- ^ the body's item types do not form a valid signature
	| InvalidVariantSignature Text
	-- ^ a variant's contents have no valid signature
	| InvalidText Text
	-- ^ the text could not be encoded, or its encoding contains a NUL byte
	deriving (Eq)

-- | Human-readable rendering of each error; sizes are in bytes.
instance Show MarshalError where
	show err = case err of
		MessageTooLong n -> "Message too long (" ++ show n ++ " bytes)."
		ArrayTooLong n -> "Array too long (" ++ show n ++ " bytes)."
		InvalidBodySignature sig -> "Invalid body signature: " ++ show sig
		InvalidVariantSignature sig -> "Invalid variant signature: " ++ show sig
		InvalidText txt -> "Text cannot be marshaled: " ++ show txt
-- | Marshal a 32-bit unsigned integer, padded to a 4-byte boundary.
marshalWord32 :: Word32 -> Marshal
marshalWord32 w = marshalBuilder 4 B.putWord32be B.putWord32le w
{-# INLINE marshalWord8 #-}
-- | Append a single byte. No padding is inserted before it.
marshalWord8 :: Word8 -> Marshal
marshalWord8 byte = MarshalM step where
	step _ (MarshalState out n) =
		MarshalRR () (MarshalState (out `B.append` B.singleton byte) (n + 1))
-- | Marshal a double as an 8-byte IEEE 754 value in the marshaler's byte
-- order, padded to an 8-byte boundary.
marshalDouble :: Double -> Marshal
marshalDouble x = pad 8 >> MarshalM (\e -> unMarshalM (appendL (encode e)) e) where
	encode BigEndian = runPut (IEEE.putFloat64be x)
	encode LittleEndian = runPut (IEEE.putFloat64le x)
-- | Marshal text as a 32-bit byte length, the encoded bytes, and a
-- terminating zero byte. Fails with 'InvalidText' when 'maybeEncodeUtf8'
-- cannot encode the text, or when the encoding contains a zero byte.
marshalText :: Text -> Marshal
marshalText x = case maybeEncodeUtf8 x of
	Nothing -> throwError (InvalidText x)
	Just bytes
		| L.any (== 0) bytes -> throwError (InvalidText x)
		| otherwise -> do
			marshalWord32 (fromIntegral (L.length bytes))
			appendL bytes
			marshalWord8 0
-- | Marshal a signature as a one-byte length, its raw bytes, and a
-- terminating zero byte. No alignment padding is inserted.
marshalSignature :: T.Signature -> Marshal
marshalSignature x = marshalWord8 len >> appendS raw >> marshalWord8 0 where
	raw = T.bytesSignature x
	len = fromIntegral (BS.length raw)
-- | Marshal an array as a 32-bit length, zero padding up to the items'
-- alignment, then the item bytes. The length counts only the item bytes,
-- not the padding. Fails with 'ArrayTooLong' when the items exceed
-- 'C.arrayMaximumLength'.
marshalArray :: T.Array -> Marshal
marshalArray x = getArrayBytes (T.arrayType x) x >>= \(padLen, body) -> do
	let bodyLen = L.length body
	when (bodyLen > fromIntegral C.arrayMaximumLength) $
		throwError (ArrayTooLong (fromIntegral bodyLen))
	marshalWord32 (fromIntegral bodyLen)
	appendL (L.replicate padLen 0)
	appendL body
-- | Marshal the items of an array without keeping them in the output.
--
-- Returns the number of zero bytes needed between the array's 32-bit
-- length field and its first item, together with the encoded items. The
-- marshaling state is restored before returning, so the caller writes the
-- length, padding and item bytes itself (see 'marshalArray').
--
-- A byte array whose contents 'T.arrayToBytes' can expose is copied
-- directly with no padding. Any other array, including a byte array
-- without a raw-bytes view, is marshaled item by item rather than crashing
-- on a failed pattern match.
getArrayBytes :: T.Type -> T.Array -> MarshalM (Int64, L.ByteString)
getArrayBytes T.DBusByte x
	| Just bytes <- T.arrayToBytes x = return (0, bytes)
getArrayBytes itemType x = do
	let vs = T.arrayItems x
	s <- getState
	-- Tentatively write the length field and the item alignment padding,
	-- to learn the offset at which the first item will start.
	(MarshalState _ afterLength) <- marshalWord32 0 >> getState
	(MarshalState _ afterPadding) <- pad (alignment itemType) >> getState
	
	-- Marshal the items into a fresh builder, but keep the byte count at
	-- the real offset so padding inside the array is computed correctly.
	putState $ MarshalState B.empty afterPadding
	(MarshalState itemBuilder _) <- mapM_ marshal vs >> getState
	
	let itemBytes = B.toLazyByteString itemBuilder
	    paddingSize = fromIntegral $ afterPadding - afterLength
	
	-- Discard the tentative output; the caller writes it for real.
	putState s
	return (paddingSize, itemBytes)
-- | Pack a set of message flags into one byte by OR-ing together each
-- flag's bit value.
encodeFlags :: Set.Set M.Flag -> Word8
encodeFlags = foldr ((.|.) . flagBit) 0 . Set.toList where
	flagBit M.NoReplyExpected = 0x1
	flagBit M.NoAutoStart = 0x2
-- | Encode a header field as a structure of its one-byte field code and
-- its value.
encodeField :: M.HeaderField -> T.Structure
encodeField field = case field of
	M.Path x        -> encodeField' 1 x
	M.Interface x   -> encodeField' 2 x
	M.Member x      -> encodeField' 3 x
	M.ErrorName x   -> encodeField' 4 x
	M.ReplySerial x -> encodeField' 5 x
	M.Destination x -> encodeField' 6 x
	M.Sender x      -> encodeField' 7 x
	M.Signature x   -> encodeField' 8 x

-- | Build a header field structure from a field code and a value.
--
-- The value goes through 'T.toVariant' twice, presumably so that the
-- structure's second member is itself a variant. That matches the
-- @[DBusByte, DBusVariant]@ field type used by 'marshalHeader'.
encodeField' :: T.Variable a => Word8 -> a -> T.Structure
encodeField' code x = T.Structure [codeVar, valueVar] where
	codeVar = T.toVariant code
	valueVar = T.toVariant (T.toVariant x)
-- | Convert a 'M.Message' into a 'L.ByteString'. Although unusual, it is
-- possible for marshaling to fail -- if this occurs, an appropriate error
-- will be returned instead.
--
-- The body is marshaled first, from the initial state, so that its length
-- is known when the header is written. The marshaler then rewinds to that
-- state and writes the endianness marker and the header, pads to 8 bytes,
-- and appends the body. Finally it checks the total size. Computing the
-- body's alignment from offset 0 is presumably valid because the body
-- always starts on an 8-byte boundary.
marshalMessage :: M.Message a => Endianness -> M.Serial -> a
               -> Either MarshalError L.ByteString
marshalMessage e serial msg = runMarshal build e where
	body = M.messageBody msg
	build = do
		bodySig <- checkBodySig body
		start <- getState
		bodyBytes <- renderBody start
		marshalEndianness e
		marshalHeader msg serial bodySig (fromIntegral (L.length bodyBytes))
		pad 8
		appendL bodyBytes
		checkMaximumSize
	-- Marshal the body, capture its bytes, then restore the given state.
	renderBody start = do
		mapM_ marshal body
		MarshalState out _ <- getState
		putState start
		return (B.toLazyByteString out)
-- | Build the signature describing a message body from its values' types.
-- Fails with 'InvalidBodySignature', carrying the concatenated type codes,
-- when those codes do not form a valid signature.
checkBodySig :: [T.Variant] -> MarshalM T.Signature
checkBodySig vs = maybe rejected return (T.mkBytesSignature rawSig) where
	types = map T.variantType vs
	rawSig = BS.concat (map T.typeCodeB types)
	rejected = throwError (InvalidBodySignature (TL.concat (map T.typeCode types)))
-- | Marshal the message header in this order:
--
-- 1. type code, flags and protocol version (one byte each);
-- 2. body length and serial (32 bits each);
-- 3. the array of header fields, with the body signature field prepended
--    to the message's own fields.
marshalHeader :: M.Message a => a -> M.Serial -> T.Signature -> Word32
              -> Marshal
marshalHeader msg serial bodySig bodyLength = do
	let fields = M.Signature bodySig : M.messageHeaderFields msg
	marshalWord8 . M.messageTypeCode $ msg
	marshalWord8 . encodeFlags . M.messageFlags $ msg
	marshalWord8 C.protocolVersion
	marshalWord32 bodyLength
	marshalWord32 . M.serialValue $ serial
	-- Every field is encoded by 'encodeField' as a (byte, variant)
	-- structure, so this conversion is expected to always succeed.
	-- Fail loudly with the broken invariant rather than via 'fromJust'.
	let fieldType = T.DBusStructure [T.DBusByte, T.DBusVariant]
	case T.toArray fieldType (map encodeField fields) of
		Just fieldArray -> marshalArray fieldArray
		Nothing -> error "DBus.Wire.Marshal.marshalHeader: encoded header fields do not match the (byte, variant) field type"
-- | Marshal the value 'encodeEndianness' produces for the given byte order.
marshalEndianness :: Endianness -> Marshal
marshalEndianness e = marshal (T.toVariant (encodeEndianness e))
-- | Fail with 'MessageTooLong' if the number of bytes written so far
-- exceeds 'C.messageMaximumLength'.
checkMaximumSize :: Marshal
checkMaximumSize = getState >>= \(MarshalState _ written) ->
	when (written > fromIntegral C.messageMaximumLength) $
		throwError (MessageTooLong (fromIntegral written))