{-
  Copyright (C) 2009 John Millikin <jmillikin@gmail.com>
  
  This program is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  any later version.
  
  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
  
  You should have received a copy of the GNU General Public License
  along with this program.  If not, see <http://www.gnu.org/licenses/>.
-}

{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module DBus.Wire.Marshal where
import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy as TL

import qualified Control.Monad.State as State
import qualified Control.Monad.Error as E
import qualified Data.ByteString.Lazy as L
import qualified Data.Binary.Builder as B

import Data.Binary.Put (runPut)
import qualified Data.Binary.IEEE754 as IEEE

import DBus.Wire.Unicode (maybeEncodeUtf8)

import Data.Text.Lazy.Encoding (encodeUtf8)

import qualified DBus.Constants as C

import qualified DBus.Message.Internal as M

import Data.Bits ((.|.))
import qualified Data.Set as Set

import DBus.Wire.Internal
import Control.Monad (when)
import Data.Maybe (fromJust)
import Data.Word (Word8, Word32, Word64)
import Data.Int (Int16, Int32, Int64)

import qualified DBus.Types as T

-- | Marshaler state: the output byte order, the bytes written so far, and a
-- running count of bytes written (used by 'pad' to compute alignment).
data MarshalState = MarshalState Endianness B.Builder !Word64

-- | The marshaling monad: accumulates output in 'MarshalState' and can
-- short-circuit with a 'MarshalError'.
--
-- 'Functor' and 'Applicative' are derived alongside 'Monad' because GHC 7.10
-- and later require 'Applicative' as a superclass of 'Monad'; without them
-- this declaration does not compile on modern compilers.
newtype MarshalM a = MarshalM (E.ErrorT MarshalError (State.State MarshalState) a)
        deriving (Functor, Applicative, Monad, E.MonadError MarshalError, State.MonadState MarshalState)

-- | A marshaling action run only for its effect on the output.
type Marshal = MarshalM ()

-- | Run a marshaler with the given byte order, starting from an empty output
-- at offset 0. Returns the produced bytes, or the first error thrown.
runMarshal :: Marshal -> Endianness -> Either MarshalError L.ByteString
runMarshal (MarshalM action) e = finish (State.runState (E.runErrorT action) start) where
        start = MarshalState e B.empty 0
        finish (Left err, _) = Left err
        finish (Right _, MarshalState _ out _) = Right (B.toLazyByteString out)

-- | Append the wire representation of a single 'T.Variant' to the output.
-- The equation of 'marshalType' is chosen from the variant's type tag; most
-- equations pad to the value's alignment before writing its bytes.
marshal :: T.Variant -> Marshal
marshal v = marshalType (T.variantType v) where
        -- The variant's payload, at whatever type each equation needs.
        -- 'fromJust' relies on the payload matching the type tag that
        -- selected the equation.
        x :: T.Variable a => a
        x = fromJust . T.fromVariant $ v
        marshalType :: T.Type -> Marshal
        -- Bytes are written as-is, with no padding.
        marshalType T.DBusByte   = append $ L.singleton x
        marshalType T.DBusWord16 = marshalBuilder 2 B.putWord16be B.putWord16le x
        marshalType T.DBusWord32 = marshalBuilder 4 B.putWord32be B.putWord32le x
        marshalType T.DBusWord64 = marshalBuilder 8 B.putWord64be B.putWord64le x
        -- Signed integers are converted with 'fromIntegral' to the unsigned
        -- word of the same width and written through the word builders.
        marshalType T.DBusInt16  = marshalBuilder 2 B.putWord16be B.putWord16le $ fromIntegral (x :: Int16)
        marshalType T.DBusInt32  = marshalBuilder 4 B.putWord32be B.putWord32le $ fromIntegral (x :: Int32)
        marshalType T.DBusInt64  = marshalBuilder 8 B.putWord64be B.putWord64le $ fromIntegral (x :: Int64)

        -- Doubles are padded to 8 and encoded as IEEE-754 in the state's
        -- byte order.
        marshalType T.DBusDouble = do
                pad 8
                (MarshalState e _ _) <- State.get
                let put = case e of
                        BigEndian -> IEEE.putFloat64be
                        LittleEndian -> IEEE.putFloat64le
                let bytes = runPut $ put x
                append bytes

        -- Booleans are written as a 32-bit 1 or 0.
        marshalType T.DBusBoolean = marshalWord32 $ if x then 1 else 0

        marshalType T.DBusString = marshalText x
        -- Object paths use the same encoding as strings.
        marshalType T.DBusObjectPath = marshalText . T.strObjectPath $ x

        marshalType T.DBusSignature = marshalSignature x

        marshalType (T.DBusArray _) = marshalArray x

        -- Dictionaries are marshaled as the array produced by
        -- 'T.dictionaryToArray'.
        marshalType (T.DBusDictionary _ _) = marshalArray (T.dictionaryToArray x)

        -- Structures are padded to 8, then each field is marshaled in order.
        marshalType (T.DBusStructure _) = do
                let T.Structure vs = x
                pad 8
                mapM_ marshal vs

        -- A variant is written as the signature of its contained value's
        -- type, followed by the contained value itself.
        marshalType T.DBusVariant = do
                let rawSig = T.typeCode . T.variantType $ x
                sig <- case T.mkSignature rawSig of
                        Just x' -> return x'
                        Nothing -> E.throwError $ InvalidVariantSignature rawSig
                marshalSignature sig
                marshal x


-- | Append raw bytes to the output and advance the byte count by their
-- length. No padding is inserted.
append :: L.ByteString -> Marshal
append bytes = State.modify grow where
        grow (MarshalState e out written) = MarshalState e
                (out `B.append` B.fromLazyByteString bytes)
                (written + fromIntegral (L.length bytes))

-- | Write zero bytes until the current offset satisfies the given alignment,
-- as computed by 'padding'.
pad :: Word8 -> Marshal
pad boundary = do
        MarshalState _ _ offset <- State.get
        let zeroCount = fromIntegral (padding offset boundary)
        append (L.replicate zeroCount 0)

-- | Pad to @size@, then append a fixed-width value using the big-endian or
-- little-endian builder, chosen by the state's byte order. The byte count
-- is advanced by @size@.
marshalBuilder :: Word8 -> (a -> B.Builder) -> (a -> B.Builder) -> a -> Marshal
marshalBuilder size be le x = do
        pad size
        State.modify $ \(MarshalState e out written) ->
                let encode = case e of
                        BigEndian -> be
                        LittleEndian -> le
                in MarshalState e (B.append out (encode x)) (written + fromIntegral size)

-- | Reasons marshaling can fail.
data MarshalError
        -- | The whole message exceeded the maximum length (size in bytes).
        = MessageTooLong Word64
        -- | An array's encoded items exceeded the maximum array length
        -- (size in bytes).
        | ArrayTooLong Word64
        -- | The concatenated type codes of the body are not a valid signature.
        | InvalidBodySignature Text
        -- | The type code of a variant's contents is not a valid signature.
        | InvalidVariantSignature Text
        -- | Text rejected by 'maybeEncodeUtf8', or whose encoding contains a
        -- NUL byte.
        | InvalidText Text
        deriving (Eq)

-- | Human-readable descriptions of marshaling failures.
instance Show MarshalError where
        show err = case err of
                MessageTooLong n -> "Message too long (" ++ show n ++ " bytes)."
                ArrayTooLong n -> "Array too long (" ++ show n ++ " bytes)."
                InvalidBodySignature sig -> "Invalid body signature: " ++ show sig
                InvalidVariantSignature sig -> "Invalid variant signature: " ++ show sig
                InvalidText txt -> "Text cannot be marshaled: " ++ show txt

-- Empty instance: the class's default methods are used. Presumably needed so
-- that 'E.ErrorT' over 'MarshalError' (see 'MarshalM') is a monad.
instance E.Error MarshalError

-- | Write a 32-bit unsigned integer, padded to a 4-byte boundary, in the
-- state's byte order.
marshalWord32 :: Word32 -> Marshal
marshalWord32 w = marshalBuilder 4 B.putWord32be B.putWord32le w

-- | Write a string as a 32-bit byte length, the UTF-8 bytes, and a
-- terminating NUL. Throws 'InvalidText' if 'maybeEncodeUtf8' rejects the
-- text, or if the encoded bytes contain a NUL.
marshalText :: Text -> Marshal
marshalText txt = do
        encoded <- maybe (E.throwError (InvalidText txt)) return (maybeEncodeUtf8 txt)
        when (L.elem 0 encoded) (E.throwError (InvalidText txt))
        let byteLen = L.length encoded
        marshalWord32 (fromIntegral byteLen)
        append encoded
        append (L.singleton 0)

-- | Write a signature as a one-byte length, its UTF-8 bytes, and a
-- terminating NUL. No padding is inserted.
marshalSignature :: T.Signature -> Marshal
marshalSignature sig = append (L.cons lenByte (L.snoc encoded 0)) where
        encoded = encodeUtf8 (T.strSignature sig)
        -- NOTE(review): truncates if the encoding exceeds 255 bytes;
        -- presumably 'T.Signature' guarantees it does not.
        lenByte = fromIntegral (L.length encoded)

-- | Write an array as a 32-bit length of its item bytes, the item-alignment
-- padding, and the item bytes. Throws 'ArrayTooLong' if the item bytes
-- exceed 'C.arrayMaximumLength'.
marshalArray :: T.Array -> Marshal
marshalArray arr = do
        (padLen, itemBytes) <- getArrayBytes (T.arrayType arr) arr
        let itemLen = L.length itemBytes
            tooLong = itemLen > fromIntegral C.arrayMaximumLength
        if tooLong
                then E.throwError (ArrayTooLong (fromIntegral itemLen))
                else do
                        marshalWord32 (fromIntegral itemLen)
                        append (L.replicate padLen 0)
                        append itemBytes

-- | Encode the items of an array without writing them to the output.
-- Returns the number of padding bytes that must follow the array's 32-bit
-- length prefix, together with the encoded item bytes. The marshaler state
-- is restored before returning.
getArrayBytes :: T.Type -> T.Array -> MarshalM (Int64, L.ByteString)
-- Byte arrays are copied directly and need no padding.
getArrayBytes T.DBusByte x = return (0, bytes) where
        -- NOTE(review): lazy pattern binding; fails at runtime if
        -- 'T.arrayToBytes' returns Nothing for a byte-typed array.
        Just bytes = T.arrayToBytes x

getArrayBytes itemType x = do
        let vs = T.arrayItems x
        -- Save the state so the simulated writes below can be undone.
        s <- State.get
        -- Simulate the 32-bit length prefix to learn the offset after it.
        (MarshalState _ _ afterLength) <- marshalWord32 0 >> State.get
        -- Simulate the padding required before the first item.
        (MarshalState e _ afterPadding) <- pad (alignment itemType) >> State.get
        
        -- Encode items into a fresh builder, keeping the true offset so each
        -- item's padding is computed correctly.
        State.put $ MarshalState e B.empty afterPadding
        (MarshalState _ itemBuilder _) <- mapM_ marshal vs >> State.get
        
        let itemBytes = B.toLazyByteString itemBuilder
            paddingSize = fromIntegral $ afterPadding - afterLength
        
        -- Undo every simulated write; the caller emits the real bytes.
        State.put s
        return (paddingSize, itemBytes)

-- | Combine a set of message flags into the header's flag byte by OR-ing
-- together the bit assigned to each flag.
encodeFlags :: Set.Set M.Flag -> Word8
encodeFlags flags = foldr (.|.) 0 [flagBit f | f <- Set.toList flags] where
        flagBit M.NoReplyExpected = 0x1
        flagBit M.NoAutoStart     = 0x2

-- | Encode a header field as a (code, variant) structure, using field codes
-- 1 (path) through 8 (signature).
encodeField :: M.HeaderField -> T.Structure
encodeField field = case field of
        M.Path value        -> encodeField' 1 value
        M.Interface value   -> encodeField' 2 value
        M.Member value      -> encodeField' 3 value
        M.ErrorName value   -> encodeField' 4 value
        M.ReplySerial value -> encodeField' 5 value
        M.Destination value -> encodeField' 6 value
        M.Sender value      -> encodeField' 7 value
        M.Signature value   -> encodeField' 8 value

-- | Build one header field entry: a structure holding the field's code byte
-- and its value wrapped in a variant.
encodeField' :: T.Variable a => Word8 -> a -> T.Structure
encodeField' code value = T.Structure [codeVariant, valueVariant] where
        codeVariant = T.toVariant code
        valueVariant = T.toVariant (T.toVariant value)

-- | Convert a 'M.Message' into a 'L.ByteString', using the given byte order
-- and serial. Although unusual, marshaling can fail (for example, when the
-- message exceeds the maximum message length); if this occurs, an
-- appropriate 'MarshalError' is returned instead.

marshalMessage :: M.Message a => Endianness -> M.Serial -> a
               -> Either MarshalError L.ByteString
marshalMessage e serial msg = runMarshal marshaler e where
        body = M.messageBody msg
        marshaler = do
                -- Reject bodies whose type codes do not form a valid
                -- signature; the signature is also written into the header.
                sig <- checkBodySig body
                -- Encode the body first, from the initial state (empty
                -- output, offset 0), so its length is known for the header.
                -- Alignment computed from offset 0 stays valid because the
                -- body is later placed on an 8-byte boundary.
                empty <- State.get
                mapM_ marshal body
                (MarshalState _ bodyBytesB _) <- State.get
                -- Discard the body output and write the real message from
                -- the initial state.
                State.put empty
                marshalEndianness e
                let bodyBytes = B.toLazyByteString bodyBytesB
                marshalHeader msg serial sig
                        $ fromIntegral . L.length $ bodyBytes
                -- The header is padded to 8 bytes before the body.
                pad 8
                append bodyBytes
                checkMaximumSize

-- | Build the body signature by concatenating each value's type code.
-- Throws 'InvalidBodySignature' if the result is not a valid signature.
checkBodySig :: [T.Variant] -> MarshalM T.Signature
checkBodySig vs = maybe failure return (T.mkSignature sigStr) where
        sigStr = TL.concat [T.typeCode (T.variantType v) | v <- vs]
        failure = E.throwError (InvalidBodySignature sigStr)

-- | Write the fixed header (type code, flags, protocol version, body length,
-- serial), followed by the array of header fields. A 'M.Signature' field
-- for the body is always placed before the message's own header fields.
marshalHeader :: M.Message a => a -> M.Serial -> T.Signature -> Word32
              -> Marshal
marshalHeader msg serial bodySig bodyLength = do
        mapM_ marshal
                [ T.toVariant (M.messageTypeCode msg)
                , T.toVariant (encodeFlags (M.messageFlags msg))
                , T.toVariant C.protocolVersion
                ]
        marshalWord32 bodyLength
        marshal (T.toVariant serial)
        marshal (T.toVariant fieldArray)
    where
        headerFields = M.Signature bodySig : M.messageHeaderFields msg
        fieldType = T.DBusStructure [T.DBusByte, T.DBusVariant]
        -- Presumably 'T.toArray' always succeeds here, since 'encodeField'
        -- builds (byte, variant) structures matching 'fieldType'.
        fieldArray = fromJust (T.toArray fieldType (map encodeField headerFields))

-- | Write the marker identifying the message's byte order, as produced by
-- 'encodeEndianness'.
marshalEndianness :: Endianness -> Marshal
marshalEndianness e = marshal (T.toVariant (encodeEndianness e))

-- | Throw 'MessageTooLong' if the bytes written so far exceed
-- 'C.messageMaximumLength'.
checkMaximumSize :: Marshal
checkMaximumSize = do
        MarshalState _ _ written <- State.get
        let limit = fromIntegral C.messageMaximumLength
        when (written > limit) $ E.throwError (MessageTooLong written)