module Network.DBus.Message (
  Message(..),
  MessageType(..),
  Flag(..),
  dbusProtocolVersion,
  endiannessValue,
  readMessage,
  writeMessage,
  deserializeMessage,
  serializeMessage,
  methodCall
) where

import Control.Monad (liftM3, when)
import Data.Bits ((.|.), (.&.))
import Data.Char (ord)
import Data.List (foldl')
import Data.Maybe (fromJust)
import Data.Typeable (cast)
import Data.Word
import System.IO (Handle)

import qualified Control.Monad.State as S
import Data.Binary.Get
import Data.Binary.Put
import qualified Data.ByteString.Lazy as BS

import Network.DBus.Type
import Network.DBus.Value

-- | The kind of a message. Declaration order matters: the derived 'Enum'
-- index plus one is the type code written into the message header.
data MessageType
    = MethodCall
    | MethodReturn
    | Error
    | Signal
    deriving (Eq, Show, Enum)

-- | Marker byte that opens every message and announces its byte order:
-- ASCII @'l'@ for little-endian, ASCII @'B'@ for big-endian.
endiannessValue :: Endianness -> Word8
endiannessValue LittleEndian = fromIntegral (ord 'l')
endiannessValue BigEndian    = fromIntegral (ord 'B')

-- | Flags that may be set in the header's flags byte.
data Flag
    = NoReplyExpected  -- ^ bit 0x1
    | NoAutoStart      -- ^ bit 0x2
    deriving (Show, Eq)

-- | The bit a single flag occupies in the header's flags byte.
flagValue :: Flag -> Word8
flagValue NoReplyExpected = 0x1
flagValue NoAutoStart     = 0x2

-- | Combine a list of flags into the header flags byte by OR-ing their bits.
-- The empty list gives 0.
flagsValue :: [Flag] -> Word8
flagsValue flags = foldl' (\acc flag -> acc .|. flagValue flag) 0 flags

-- | Decode the header flags byte into the flags this module knows about.
--
-- Known flags are returned in declaration order (e.g. @3@ gives
-- @[NoReplyExpected, NoAutoStart]@). Bits for flags this module does not
-- model, such as 0x4 sent by newer peers, are ignored rather than treated
-- as an error. The D-Bus specification requires receivers to ignore
-- unknown flags.
decodeFlags :: Word8 -> [Flag]
decodeFlags n = [ flag | flag <- [NoReplyExpected, NoAutoStart]
                       , n .&. flagValue flag /= 0 ]

-- | Protocol version byte written into the header of every serialized
-- message.
dbusProtocolVersion :: Word8
dbusProtocolVersion = 0x01

-- | A complete D-Bus message. The optional header fields are kept as
-- 'Maybe' values. The body signature is not stored, because
-- 'encodeFields' derives it from 'mBody'.
data Message = Message
    { mType        :: MessageType       -- ^ kind of message
    , mFlags       :: [Flag]            -- ^ header flags
    , mSerial      :: Word32            -- ^ serial number in the fixed header
    , mPath        :: Maybe ObjectPath  -- ^ header field 1
    , mInterface   :: Maybe DString     -- ^ header field 2
    , mMember      :: Maybe DString     -- ^ header field 3
    , mErrorName   :: Maybe DString     -- ^ header field 4
    , mReplySerial :: Maybe Word32      -- ^ header field 5
    , mDestination :: Maybe DString     -- ^ header field 6
    , mSender      :: Maybe DString     -- ^ header field 7
    , mBody        :: [Variant]         -- ^ body values, in order
    }
    deriving (Show, Eq)

-- | Byte order of the machine we are running on. It is detected by writing
-- the 16-bit value 1 in host order and checking which byte holds the 1.
nativeEndianness :: Endianness
nativeEndianness
    | BS.unpack probe == [0, 1] = BigEndian
    | otherwise                 = LittleEndian
  where
    probe = runPut (putWord16host 1)


-- | Extract the known header fields from the decoded (code, value) pairs.
--
-- Codes 1 to 8 map to path, interface, member, error name, reply serial,
-- destination, sender and body signature. Missing codes give 'Nothing', and
-- codes outside 1-8 are ignored. A field whose value has the wrong type is
-- reported with a descriptive error naming the field code, instead of the
-- opaque failure of a partial 'fromJust'.
decodeFields :: [(Word8, Variant)] ->
    (Maybe ObjectPath, Maybe DString, Maybe DString, Maybe DString,
     Maybe Word32, Maybe DString, Maybe DString, Maybe Signature)
decodeFields fs = ( decodeAt 1 fs, decodeAt 2 fs, decodeAt 3 fs, decodeAt 4 fs
                  , decodeAt 5 fs, decodeAt 6 fs, decodeAt 7 fs, decodeAt 8 fs )
  where
    -- Neither helper mentions 'fs', so both stay polymorphic in their
    -- result type and can be used at every field type above.
    decodeAt code = fmap (decode code) . lookup code
    decode code (Variant x) =
        case cast x of
            Just v  -> v
            Nothing -> error $ "Network.DBus.Message.decodeFields: header field "
                            ++ show code ++ " has an unexpected type"

-- | Parse the 16-byte fixed prefix of a message. The prefix holds the
-- endianness marker, the message type, the flags, the protocol version, and
-- three 32-bit words (body length, serial, header-field array length). The
-- 32-bit words are read in the byte order given by the marker.
--
-- An unknown endianness marker or message type fails the parse. This
-- happens here, instead of leaving a 'toEnum' error to surface lazily
-- later.
parseInit :: Get (Endianness, MessageType, [Flag], Word8, Word32, Word32, Word32)
parseInit = do
    e <- getWord8 >>= \b ->
        case toEnum . fromIntegral $ b of
            'B' -> return BigEndian
            'l' -> return LittleEndian
            c -> fail $ "bad endianness " ++ show c
    -- Wire type codes are 1-based, in the declaration order of MessageType.
    t <- getWord8 >>= \b ->
        case lookup b (zip [1 ..] [MethodCall ..]) of
            Just mt -> return mt
            Nothing -> fail $ "unknown message type " ++ show b
    f <- fmap decodeFlags getWord8
    v <- getWord8
    let getWord32 = case e of BigEndian    -> getWord32be
                              LittleEndian -> getWord32le
    (bl, s, fl) <- liftM3 (,,) getWord32 getWord32 getWord32
    return (e, t, f, v, bl, s, fl)

-- | The optional header fields of a message, in header-field-code order
-- (1 to 8): path, interface, member, error name, reply serial,
-- destination, sender and body signature.
type HeaderFieldValues =
    (Maybe ObjectPath, Maybe DString, Maybe DString, Maybe DString,
     Maybe Word32, Maybe DString, Maybe DString, Maybe Signature)

-- | Decode the header-field array. The deserializer runs over the 16-byte
-- fixed prefix followed by the array bytes, and skips the first 12 bytes so
-- that it starts at the array-length word. Including the prefix presumably
-- keeps alignment relative to the start of the message.
parseHeaderFields :: Endianness -> BS.ByteString -> BS.ByteString
                  -> HeaderFieldValues
parseHeaderFields endianness initBuf fieldsBuf =
    decodeFields $ runDeserializer endianness
        (do S.lift (skip 12)
            deserializer :: Deserializer [(Word8, Variant)]) $
        initBuf `BS.append` fieldsBuf

-- | Number of zero bytes that follow a header-field array of the given
-- length, so that the body starts on an 8-byte boundary. The fixed prefix
-- before the array is 16 bytes, a multiple of 8, so only the array length
-- matters.
headerPaddingLength :: Integral a => a -> a
headerPaddingLength fieldsLen = (8 - fieldsLen `mod` 8) `mod` 8

-- | Reject a protocol version other than 'dbusProtocolVersion'.
checkVersion :: Word8 -> Either String ()
checkVersion v
    | v == dbusProtocolVersion = Right ()
    | otherwise = Left $ "unsupported D-Bus protocol version " ++ show v

-- | Header padding must consist only of zero bytes.
checkPadding :: BS.ByteString -> Either String ()
checkPadding padding
    | BS.all (== 0) padding = Right ()
    | otherwise             = Left "non-null bytes in padding"

-- | Decode the body bytes using the signature header field. A message
-- without a signature must have an empty body.
decodeBody :: Endianness -> Maybe Signature -> BS.ByteString
           -> Either String [Variant]
decodeBody _ Nothing bodyBuf
    | BS.null bodyBuf = Right []
    | otherwise       = Left "message body present but no signature header field"
decodeBody endianness (Just ts) bodyBuf =
    Right $ runDeserializer endianness (deserializeAs ts) bodyBuf

-- | Build a 'Message' from the fixed-header values, the decoded header
-- fields and the decoded body. The signature field is dropped because
-- 'Message' has no slot for it.
assembleMessage :: MessageType -> [Flag] -> Word32 -> HeaderFieldValues
                -> [Variant] -> Message
assembleMessage type_ flags serial
                (path, iface, member, err, rs, dest, sender, _sig) body =
    Message {
        mType = type_,
        mFlags = flags,
        mSerial = serial,
        mPath = path,
        mInterface = iface,
        mMember = member,
        mErrorName = err,
        mReplySerial = rs,
        mDestination = dest,
        mSender = sender,
        mBody = body }

-- | Parse one message from the front of a lazy 'BS.ByteString'. Returns the
-- message and the bytes that follow it. The whole body (@bodyLength@
-- bytes) is always consumed, so the remainder starts at the next message.
-- Malformed input, including a protocol version other than
-- 'dbusProtocolVersion', makes the parse fail.
deserializeMessage :: BS.ByteString -> (Message, BS.ByteString)
deserializeMessage = runGet $ do
    -- Peek at the fixed prefix: parseHeaderFields needs those bytes again.
    initBuf <- lookAhead $ getLazyByteString 16
    (endianness, type_, flags, version, bodyLength, serial, fieldsLength)
            <- parseInit
    either fail return $ checkVersion version

    fieldsBuf <- getLazyByteString . fromIntegral $ fieldsLength
    let fields@(_, _, _, _, _, _, _, sig) =
            parseHeaderFields endianness initBuf fieldsBuf

    padding <- getLazyByteString . headerPaddingLength $ BS.length fieldsBuf
    either fail return $ checkPadding padding

    bodyBuf <- getLazyByteString . fromIntegral $ bodyLength
    body <- either fail return $ decodeBody endianness sig bodyBuf

    rest <- getRemainingLazyByteString
    return (assembleMessage type_ flags serial fields body, rest)

-- | Read exactly one message from a handle. Failures are raised with
-- 'fail' in 'IO'. Failures include:
--
-- * end of file before the full message arrives;
-- * an unsupported protocol version;
-- * non-zero padding;
-- * a body without a signature.
readMessage :: Handle -> IO Message
readMessage handle = do
    initBuf <- readExactly 16
    let (endianness, type_, flags, version, bodyLength, serial, fieldsLength)
            = runGet parseInit initBuf
    either fail return $ checkVersion version

    fieldsBuf <- readExactly . fromIntegral $ fieldsLength
    let fields@(_, _, _, _, _, _, _, sig) =
            parseHeaderFields endianness initBuf fieldsBuf

    padding <- readExactly . fromIntegral . headerPaddingLength $
                   BS.length fieldsBuf
    either fail return $ checkPadding padding

    -- Always read the body so the next read starts at the next message.
    bodyBuf <- readExactly . fromIntegral $ bodyLength
    body <- either fail return $ decodeBody endianness sig bodyBuf

    return $ assembleMessage type_ flags serial fields body
  where
    -- 'BS.hGet' returns fewer bytes than requested at end of file. Treat
    -- that as an error rather than parsing a truncated message.
    readExactly n = do
        buf <- BS.hGet handle n
        when (BS.length buf /= fromIntegral n) $
            fail "unexpected end of input while reading a message"
        return buf

-- | Header fields to serialize for a message, as (field code, value) pairs
-- in code order. Absent optional fields are left out. A signature field
-- (code 8), built from the types of the body values, is added only when
-- the body is non-empty.
encodeFields :: Message -> [(Word8, Variant)]
encodeFields msg = concat
    [ optional 1 (mPath msg)
    , optional 2 (mInterface msg)
    , optional 3 (mMember msg)
    , optional 4 (mErrorName msg)
    , optional 5 (mReplySerial msg)
    , optional 6 (mDestination msg)
    , optional 7 (mSender msg)
    , signatureOf (mBody msg)
    ]
  where
    optional code = maybe [] (\value -> [(code, Variant value)])
    signatureOf []     = []
    signatureOf values =
        [(8, Variant . Signature $ map (\(Variant v) -> dtype v) values)]

-- | Encode a message in the host's native byte order. The output is four
-- header bytes (endianness marker, type code, flags, protocol version),
-- then the body length and the serial as 32-bit words. These are followed
-- by the padded header-field array and the body.
serializeMessage :: Message -> BS.ByteString
serializeMessage m = runPut $ do
    mapM_ putWord8
        [ fromIntegral (fromEnum (endiannessValue endianness))
        , fromIntegral (fromEnum (mType m) + 1)
        , flagsValue (mFlags m)
        , dbusProtocolVersion
        ]
    putWord32 (fromIntegral (BS.length bodyBytes))
    putWord32 (mSerial m)
    putLazyByteString fieldBytes
    putLazyByteString bodyBytes
  where
    endianness = nativeEndianness
    putWord32 = case endianness of
        BigEndian    -> putWord32be
        LittleEndian -> putWord32le
    -- The advanceBy offsets presumably tell the serializer where these
    -- bytes sit within the message, so alignment is computed from the
    -- message start. TODO confirm against the Serializer definition.
    fieldBytes = runSerializer endianness $ do
        advanceBy 12
        serializer (encodeFields m)
        padTo 8
    bodyBytes = runSerializer endianness $ do
        advanceBy (12 + fromIntegral (BS.length fieldBytes))
        mapM_ (\(Variant v) -> serializer v) (mBody m)

-- | Serialize a message with 'serializeMessage' and write the bytes to the
-- given handle.
writeMessage :: Handle -> Message -> IO ()
writeMessage handle m = BS.hPutStr handle (serializeMessage m)

-- XXX: reorder the arguments to make currying (partial application) more convenient?

-- | Build a method-call message for the given interface, member,
-- destination, object path and body values. The message has no flags,
-- serial 1, and no error name, reply serial or sender.
methodCall :: DString -> DString -> DString -> ObjectPath -> [Variant] ->
              Message
methodCall interface member destination path body =
    Message
        { mType        = MethodCall
        , mFlags       = []
        , mSerial      = 1
        , mPath        = Just path
        , mInterface   = Just interface
        , mMember      = Just member
        , mErrorName   = Nothing
        , mReplySerial = Nothing
        , mDestination = Just destination
        , mSender      = Nothing
        , mBody        = body
        }

-- vim: sts=2 sw=2 et