-- |
-- Maintainer: Henning Guenther
--
-- A module containing functions to parse the XML specified by the
-- protocol into the internal message types.
module Network.AdHoc.ParserStrict
        (parseMessageNoValidate
        ,parseMessage
        ,parseInnerMessage) where

import Codec.Binary.Base64
import Control.Monad
import qualified Data.ByteString as BS
import Data.ByteString (pack)
import Data.Char
import Data.Time.Clock
import Data.Time.Calendar
import Data.List as List
import Data.Word

import Text.XML.HaXml.Escape
import Text.XML.HaXml.Types
import Text.XML.HaXml.Posn (Posn)

import Text.ParserCombinators.Parsec
import Text.ParserCombinators.Parsec.XML
import Text.ParserCombinators.Parsec.Error (messageString,errorMessages)

import Network.AdHoc.Channel
import Network.AdHoc.Encryption
import Network.AdHoc.Message
import Network.AdHoc.Signature
import Network.AdHoc.UserID
import Network.AdHoc.XMLRenderer(escaper)

-- | Parses a message while skipping signature verification: the validation
--   function handed to 'parseMessage' ignores its arguments and always
--   answers 'CertificateMissing'. The payload of that status is an 'error'
--   thunk, so it must never be forced. Use for testing and debugging only!
parseMessageNoValidate :: Document Posn -> Either String ExternalMessage
parseMessageNoValidate = parseMessage noValidation
        where
        -- Hash string, signature and user are all ignored.
        noValidation _ _ _ = CertificateMissing (error "Try not to validate messages generated by parseMessageNoValidate, STUPID!")

-- | Given a validation function for signatures, this function parses an XML
--   'Document' into an 'ExternalMessage'. The root element must be named
--   @chat-message@; only its element children are handed to the parser.
parseMessage :: (String -> Signature -> UserID -> SignatureStatus) -- ^ A validation function for signatures.
        -> Document Posn                 -- ^ The XML 'Document' to be parsed.
        -> Either String ExternalMessage -- ^ 'Left' @err@ on failure, where @err@ describes the error. 'Right' @msg@ on success.
parseMessage check (Document _ _ root _) = case xmlUnEscape escaper root of
        Elem name attrs conts
                | name == "chat-message" ->
                        case parse (parseMessage' check attrs) "" (filter isElement conts) of
                                Left err  -> Left (show err)
                                Right msg -> Right msg
                | otherwise -> Left "root element must be \"chat-message\""
        where
        -- Keeps element nodes, drops every other kind of content.
        isElement (CElem _ _) = True
        isElement _           = False

-- | Parses the document carried inside an obscure message. The root element
--   must be @chat-message@; its first element child must be either
--   @obscure@ (one more layer) or @message@ (the final message).
parseInnerMessage :: Document Posn -- ^ The XML 'Document' to be parsed.
        -> Either String (Either (UserID,RSAEncrypted String) (UTCTime,String,[Attachment])) -- ^ 'Left' @err@ in case of a parsing failure;
        -- 'Right' @msg@ on success. @msg@ either is one more obscured message or
        -- the final message to be flooded.
parseInnerMessage (Document _ _ root _)
        | rootName == "chat-message" = case parse inner "" elemChildren of
                Left err  -> Left (show err)
                Right res -> Right res
        | otherwise = Left "root element must be \"chat-message\""
        where
        Elem rootName _ rootConts = xmlUnEscape stdXmlEscaper root
        -- Only element nodes of the root are fed to the parser.
        elemChildren = [ c | c@(CElem _ _) <- rootConts ]
        -- Reads one element and parses its content according to its name.
        inner = do
                Elem kind _ kids <- element
                recurse (payload kind) kids
        -- One more obscured layer: receiver plus base64-encoded text.
        payload "obscure" = do
                recv <- parseUserID "receiver"
                cipher <- stringElement "text" >>= base64dec "text"
                return (Left (recv,RSAEncrypted $ pack cipher))
        -- The final message: timestamp, text and unencrypted attachments.
        -- messageid, channel and channelid are accepted but ignored.
        payload "message" = do
                time <- parseTimestamp
                optional (stringElement "messageid")
                optional (stringElement "channel")
                optional parseChannelID
                body <- namedElement "text" >>= recurse text
                attach <- many parseUnencryptedAttachment
                return (Right (time,body,attach))
        payload other = fail $ "content in obscure message must be \"obscure\" or \"message\", but not \""++other++"\""


-- | Parses a @timestamp@ element and interprets its string content with
--   'parseDate'. A malformed date makes the parser fail with the collected
--   Parsec error messages.
parseTimestamp :: XMLParser UTCTime
parseTimestamp = do
        raw <- stringElement "timestamp"
        either (fail . describe) return (parse parseDate "" raw)
        where
        describe err = "error in timestamp: "++unlines (map messageString (errorMessages err))

-- | Parses the element called @name@ whose content has the form
--   @user\@host@ and combines both parts with @constr@. Both parts must be
--   non-empty and consist of alphanumeric characters, @_@, @.@ or @-@ only.
parseID :: String -> (String -> String -> a) -> XMLParser a
parseID name constr = stringElement name >>= build . span (/= '@')
        where
        legal c = isAlphaNum c || c `elem` "_.-"
        build (user, rest)
                | null user            = fail "no name"
                | null host            = fail "no hostname"
                | not (all legal user) = fail $ "name "++show user++" contains illegal chars"
                | not (all legal host) = fail $ "hostname "++show host++" contains illegal chars"
                | otherwise            = return $ constr user host
                where
                -- Everything after the first '@'; empty both when there is
                -- no '@' at all and when nothing follows it.
                host = drop 1 rest

-- | Parses the element called @name@ into a 'UserID' (see 'parseID' for the
--   accepted syntax).
parseUserID :: String -> XMLParser UserID
parseUserID name = name `parseID` UserID

-- | Parses a @channelid@ element into a 'ChannelID' (see 'parseID' for the
--   accepted syntax).
parseChannelID :: XMLParser ChannelID
parseChannelID = flip parseID ChannelID "channelid"

-- | Parses an @attachment@ element consisting of @filename@,
--   @applicationtype@ and base64-encoded @data@. Without an @iv@ attribute
--   the result is a plain 'Attachment'; with one, file name and application
--   type are base64-decoded as well and an 'EncryptedAttachment' is returned.
parseAttachment :: XMLParser (Either Attachment EncryptedAttachment)
parseAttachment = do
        (attrs,conts) <- namedElementWithAttrs "attachment"
        recurseElements (fields (getIV attrs)) conts
        where
        -- Reads the three child elements; the payload is decoded before the
        -- initialization vector is looked at.
        fields mbIV = do
                rawName <- stringElement "filename"
                rawType <- stringElement "applicationtype"
                payload <- stringElement "data" >>= base64dec "data"
                case mbIV of
                        Nothing -> return $ Left $ Attachment rawName rawType (pack payload)
                        Just iv -> encrypted iv rawName rawType payload
        -- The vector must pack into exactly one block.
        encrypted iv rawName rawType payload = case pack64 iv of
                [riv] -> do
                        encName <- base64dec "filename" rawName
                        encType <- base64dec "apptype" rawType
                        return $ Right $ EncryptedAttachment
                                (Encrypted riv (pack encName))
                                (Encrypted riv (pack encType))
                                (Encrypted riv (pack payload))
                _     -> fail "initialization vector is corrupt"

-- | Like 'parseAttachment', but fails (via 'pzero') when the attachment
--   turns out to be encrypted.
parseUnencryptedAttachment :: XMLParser Attachment
parseUnencryptedAttachment = (do
        att <- parseAttachment
        case att of
                Left plain -> return plain
                Right _    -> pzero) <?> "unencrypted attachment"

-- | Like 'parseAttachment', but fails (via 'pzero') when the attachment
--   turns out to be unencrypted.
parseEncryptedAttachment :: XMLParser EncryptedAttachment
parseEncryptedAttachment = (do
        att <- parseAttachment
        case att of
                Left _       -> pzero
                Right cipher -> return cipher) <?> "encrypted attachment"

-- | Looks up the @iv@ attribute and base64-decodes its shown value.
--   'Nothing' if the attribute is absent or does not decode.
getIV :: [Attribute] -> Maybe [Word8]
getIV attrs = do
        value <- List.lookup "iv" attrs
        decode (show value)

-- | Parses one protocol element into an 'ExternalMessage', dispatching on
--   the element name.
--
--   @check@ validates a signature against the hash string of the input;
--   @attrs@ are the attributes supplied by the caller, from which @ttl@,
--   @signtype@, @signature@ and @delay@ are read. The @closed@ flag of a
--   @channel@ message is read from the parsed element's own attributes.
--
--   Numeric values (@ttl@, @version@, @delay@, @hops@) must be non-empty
--   strings of digits; anything else makes the parser fail. (An empty string
--   used to slip through the digit check and crash in 'read' once the value
--   was forced.)
parseMessage' :: (String -> Signature -> UserID -> SignatureStatus) -> [Attribute] -> XMLParser ExternalMessage
parseMessage' check attrs = do
        -- The hash string is taken from the input before the element is consumed.
        hashstr <- fmap getHashString $ getInput
        Elem name el_attrs conts <- element
        -- Time-to-live from the "ttl" attribute; defaults to 1 when absent.
        let getTTL = case fmap show $ List.lookup "ttl" attrs of
                Just val -> if not (null val) && all isDigit val
                        then return (read val)
                        else fail $ "invalid ttl value \""++val++"\"(must be a number)"
                Nothing -> return 1
        -- Signature and its validation status for the given user; 'Nothing' if
        -- "signtype" or "signature" is missing or the signature does not decode.
        let signature' user = do
                sgntyp <- fmap (\tp -> case show tp of
                        "MD5" -> MD5
                        str -> SignUnknown str) (List.lookup "signtype" attrs)
                sgn <- fmap BS.pack $ List.lookup "signature" attrs >>= decode . show
                let rsgn = Signature sgntyp sgn
                return (rsgn,check hashstr rsgn user)
        -- Common frame of routed messages: user element, receivers (as parsed
        -- by recv_gen), message id, then the body built from the receivers.
        let routed name gen recv_gen = do
                ttl <- getTTL
                user <- parseUserID name
                recv <- recv_gen
                msgid <- stringElement "messageid"
                res <- gen recv
                return $ Routed ttl user msgid res (signature' user)
        -- Targeted message without receiver elements.
        let target gen = do
                res <- routed "sender" (const gen) (return ())
                return $ Target res
        -- Targeted message with any number of receivers.
        let targetMany gen = do
                res <- routed "sender" gen (many $ parseUserID "receiver")
                return $ Target res
        -- Targeted message with exactly one receiver.
        let targetOne gen = do
                res <- routed "sender" gen (parseUserID "receiver")
                return $ Target res
        -- Flooded message (no receiver elements).
        let flood gen = do
                res <- routed "sender" (const gen) (return ())
                return $ Flood res
        -- Shared body of join and leave: channel name plus channel id.
        let jl gen = recurseElements (flood $ do
                cname <- stringElement "channel"
                cid <- parseChannelID
                return (gen (mkChannelName cname) cid))
        case name of
                "ack" -> recurseElements (do
                        sender <- parseUserID "sender"
                        msgid <- stringElement "messageid"
                        return $ Ack sender msgid) conts
                "hello" -> recurseElements (do
                        senders <- many (parseUserID "sender")
                        optional $ stringElement "messageid"
                        vers <- stringElement "version" >>= \str -> if (not (null str) && all isDigit str)
                                then return (read str)
                                else fail $ "invalid version \""++str++"\"(must be a number)"
                        greeting <- option Nothing (stringElement "greeting" >>= return.Just)
                        return $ Hello senders vers greeting) conts
                -- A nack wraps another message, which must itself be targeted.
                "nack" -> recurseElements (target $ do
                        (sub_attrs,sub_conts) <- namedElementWithAttrs "message"
                        submsg <- recurseElements (parseMessage' check sub_attrs) sub_conts
                        case submsg of
                                Target rt -> return (Nack rt)
                                _ -> fail $ "nack must contain getcertificate-, certificate, getkey-, key-, message- or obscure-message"
                        ) conts
                "channel" -> recurseElements (flood $ do
                        cname <- stringElement "channel"
                        cid <- parseChannelID
                        descr <- stringElement "description"
                        -- Only "true" and "1" mark a channel as closed.
                        let closed = case fmap show $ List.lookup "closed" el_attrs of
                                Just "true" -> True
                                Just "1" -> True
                                _ -> False
                        members <- many (parseUserID "member")
                        return $ Channel (mkChannelName cname) cid descr members closed) conts
                "join" -> jl Join conts
                "leave" -> jl Leave conts
                "message" -> recurseElements (targetMany $ \recv -> do
                        -- Delay from the "delay" attribute; defaults to 0 when absent.
                        delay <- case fmap show $ List.lookup "delay" attrs of
                                Nothing -> return 0
                                Just str -> if not (null str) && all isDigit str
                                        then return $ read str
                                        else fail $ "invalid delay attribute \""++str++"\"(must be a number)"
                        time <- parseTimestamp
                        -- Channel name and id fall back to "anonymous" when missing.
                        cname <- option "anonymous" $ stringElement "channel"
                        cid <- option (ChannelID "anonymous" "anonymous") parseChannelID
                        (text_attrs,text_conts) <- namedElementWithAttrs "text"
                        text <- recurse (text <|> return "") text_conts -- Allow empty text node
                        -- An "iv" attribute on the text element marks the message as
                        -- encrypted, unless the vector packs to nothing.
                        content <- case getIV text_attrs of
                                Nothing -> do
                                        attach <- many parseUnencryptedAttachment
                                        return $ UnencryptedMessage text attach
                                Just iv -> case pack64 iv of
                                        [riv] -> do
                                                attach <- many parseEncryptedAttachment
                                                rtext <- case decode text of
                                                        Nothing -> fail "text element contains invalid bas64 data"
                                                        Just r -> return r
                                                return $ EncryptedMessage
                                                        (Encrypted riv (pack rtext))
                                                        attach
                                        [] -> do
                                                attach <- many parseUnencryptedAttachment
                                                return $ UnencryptedMessage text attach
                                        _ -> fail $ "invalid initialization vector: "++show iv
                        return $ Message recv (mkChannelName cname) cid content time delay) conts >>= anonymousCheck
                "routing" -> recurseElements (do
--                        msgid <- stringElement "messageid"
                        dest <- many $ namedElement "destination" >>= recurseElements (do
                                user <- parseUserID "user"
                                hops <- stringElement "hops" >>= \str -> if not (null str) && all isDigit str
                                        then return (read str)
                                        else fail $ "invalid hops value \""++str++"\"(must be a number)"
                                return (user,hops))
                        return $ Routing dest) conts
                -- Obscure messages carry no sender and no signature.
                "obscure" -> recurseElements (do
                        ttl <- getTTL
                        recv <- parseUserID "receiver"
                        msgid <- stringElement "messageid"
                        text <- fmap pack $ stringElement "text" >>= base64dec "text"
                        return $ Obscure (Routed ttl recv msgid (RSAEncrypted text) ())) conts
                "getcertificate" -> recurseElements (targetOne $ \for -> return (GetCertificate for)) conts
                "certificate" -> recurseElements (targetMany $ \recv -> do
                        (for,cert) <- namedElement "certificate" >>= recurseElements (do
                                cert_user <- parseUserID "user"
                                cert_data <- stringElement "data" >>= base64dec "data"
                                return (cert_user,cert_data))
                        return $ Certificate recv for (BS.pack cert)) conts
                "getkey" -> recurseElements (targetOne $ \recv -> do
                        cname <- stringElement "channel"
                        cid <- parseChannelID
                        return $ GetKey recv (mkChannelName cname) cid) conts
                "key" -> recurseElements (targetOne $ \recv -> do
                        cname <- stringElement "channel"
                        cid <- parseChannelID
                        -- Unrecognised cipher names are kept as 'CipherUnknown'.
                        cipher_type <- stringElement "cipher" >>= (\str -> return $ case str of
                                "DES-CBC" -> CipherDES_CBC
                                "NONE" -> CipherNone
                                _ -> CipherUnknown str)
                        key <- fmap pack $ stringElement "key" >>= base64dec "key"
                        return $ Key recv (mkChannelName cname) cid cipher_type (RSAEncrypted key)) conts
                _ -> fail $ "unknown message type \""++name++"\""

-- | Turns an unencrypted, targeted 'Message' whose channel name equals
--   'anonymous' into a flooded 'Anonymous' message, dropping receivers and
--   channel id. Every other message is returned unchanged.
anonymousCheck :: ExternalMessage -> XMLParser ExternalMessage
anonymousCheck (Target (Routed ttl user msgid (Message _ cname _ (UnencryptedMessage body attach) time delay) sig))
        | cname == anonymous = return $ Flood (Routed ttl user msgid (Anonymous body attach time delay) sig)
-- Also reached when the guard above does not hold.
anonymousCheck msg = return msg

-- | Parses a timestamp of the form
--   @[-]YYYY-MM-DDThh:mm:ss[.fraction][Z|(+|-)hh:mm]@ into a 'UTCTime'.
--
--   * Month, day, hour, minute and second are range-checked; the day is only
--     checked against 1..31, independent of the month.
--   * A fractional-seconds part is accepted but discarded.
--   * A missing timezone designator is treated like an offset of zero.
--   * A timezone offset is subtracted from the parsed local time. The result
--     is normalised, so an offset that crosses midnight moves the date
--     instead of producing a time of day that lies outside of the day.
parseDate :: Parser UTCTime
parseDate = do
        yearFactor <- option 1 (char '-' >> return (-1))
        year <- count 4 digit >>= return.read
        char '-'
        month <- count 2 digit >>= return.read
        when (month > 12 || month == 0) (fail "invalid month")
        char '-'
        day <- count 2 digit >>= return.read
        when (day > 31 || day == 0) (fail "invalid day")
        char 'T'
        hour <- count 2 digit >>= return.read
        when (hour > 23) (fail "invalid hour")
        char ':'
        minute <- count 2 digit >>= return.read
        when (minute > 59) (fail "invalid minute")
        char ':'
        second <- count 2 digit >>= return.read
        when (second > 59) (fail "invalid seconds")
        -- Fractional seconds are consumed but not represented in the result.
        optional (char '.' >> many1 digit)
        (tzHour, tzMinute) <- option (0, 0) parseTimezone
        let localSeconds  = second + minute * 60 + hour * 3600
            offsetSeconds = tzMinute * 60 + tzHour * 3600
            localTime     = UTCTime (fromGregorian (yearFactor * year) month day) (fromInteger localSeconds)
        -- Shifting with 'addUTCTime' (rather than folding the offset into the
        -- seconds of the day) keeps the time of day inside its day and
        -- adjusts the date when the offset crosses midnight.
        return $ addUTCTime (fromInteger (negate offsetSeconds)) localTime

-- | Parses a timezone designator: either @Z@ (zero offset) or a signed
--   @hh:mm@ offset. Returns hours and minutes, both negated for a @-@
--   offset. Hours above 14 and minutes above 59 are rejected.
parseTimezone :: (Read n, Num n, Ord n) => Parser (n, n)
parseTimezone = anyChar >>= designator
        where
        designator 'Z' = return (0,0)
        designator '+' = offset
        designator '-' = do
                (h, m) <- offset
                return (negate h, negate m)
        designator _   = fail "invalid timezone"
        -- The unsigned hh:mm part following the sign.
        offset = do
                hour <- twoDigits
                when (hour > 14) (fail "invalid timezone")
                char ':'
                minute <- twoDigits
                when (minute > 59) (fail "invalid timezone")
                return (hour, minute)
        twoDigits = fmap read (count 2 digit)

-- | Base64-decodes @str@ inside a parser. On invalid input the parser fails
--   with a message naming the element @name@ the data came from.
base64dec :: String -> String -> GenParser a s [Word8]
base64dec name str = maybe invalid return (decode str)
        where
        invalid = fail $ "'"++name++"' element contains invalid base64 data"