module Network.SMTP.ClientSession (
        smtpClientSession,
        SMTPState(..)
        {-, suite-}
    ) where

import Text.ParserCombinators.Parsec.Rfc2821 (
        SmtpReply(..),
        SmtpCode(..),
        SuccessCode(..),
        Category(..),
        reply
    )
import Text.ParserCombinators.Parsec.Rfc2822 (
        Message(..),
        Field(..),
        NameAddr(..)
    )
import Data.Maybe
import qualified Data.Set as Set
import Prelude hiding (fail, return)
import Data.List
import System.Time
import System.Locale


{-import Test.Framework
import Test.Framework.Providers.HUnit
import Test.HUnit-}

{-suite = testGroup "MetaBarter.SMTP.SMTP" [
        testCase "dotAtomAllowed"   test_dotAtomAllowed,
        testCase "quoted_string"    test_quoted_string
    ]-}

-- | Complete state of a client session.  The session logic is pure: the
-- caller transmits whatever accumulates in 'smtpOutQueue' and feeds server
-- lines back through 'smtpReceive'.
data SMTPState = SMTPState
    { smtpOutQueue :: [String]
      -- ^ Lines the caller must transmit, stored newest first, so the caller
      -- has to reverse the list before sending.  The lines carry no
      -- end-of-line characters; the caller appends \"\\r\\n\" to each one
      -- (RFC2821 requires CRLF, not a bare \"\\n\") and must clear the queue
      -- in the state it hands back to 'smtpReceive'.
    , smtpReceive  :: String -> SMTPState -> SMTPState
      -- ^ When nothing is waiting to be sent, the caller reads one line from
      -- the SMTP server, strips its end-of-line characters, and passes it
      -- here together with the current state.
    , smtpSuccess  :: Bool
      -- ^ 'True' once the session has finished successfully and no work
      -- remains.
    , smtpFailure  :: Maybe String
      -- ^ @Just err@ when a protocol error has occurred; the caller must then
      -- terminate the session.
    , smtpSent     :: Int
      -- ^ How many emails have been sent successfully so far.
    }

-- | Queue @txt@ for transmission, then continue with @k@.  The line is put on
-- the front of 'smtpOutQueue', so the queue is kept newest first.
send :: String -> (SMTPState -> SMTPState) -> SMTPState -> SMTPState
send txt k st = k (st { smtpOutQueue = queued })
    where
        queued = txt : smtpOutQueue st

-- | Line handler that ignores the line and leaves the state unchanged.  It is
-- the initial handler, and 'receive' reinstalls it as soon as a line has been
-- delivered.
nullReceive :: String -> SMTPState -> SMTPState
nullReceive = const id

-- | Wait for the next line from the server and pass it to @cont@.
--
-- The handler stored in 'smtpReceive' fires only once: it resets
-- 'smtpReceive' to 'nullReceive' before calling @cont@.  If @cont@ expects
-- another line, it must call 'receive' again.
receive :: (String -> SMTPState -> SMTPState) -> SMTPState -> SMTPState
receive cont st = st { smtpReceive = handler }
    where
        handler line current = cont line (current { smtpReceive = nullReceive })

-- | Record @errorText@ in 'smtpFailure'.  No continuation is taken, so no
-- later step of the session runs.
fail :: String -> SMTPState -> SMTPState
fail errorText = \st -> st { smtpFailure = Just errorText }

-- | Final step of a session: mark it as successfully completed.
success :: SMTPState -> SMTPState
success = \st -> st { smtpSuccess = True }

-- | Continue with @k@ if the reply @got@ has the expected success code and
-- category (the third digit is ignored).  Otherwise fail with a message that
-- quotes the reply and @descr@, the human-readable expectation.
check :: SuccessCode -> Category -> String -> SmtpReply -> (SMTPState -> SMTPState) -> SMTPState -> SMTPState
check code cat descr got k
    | accepted  = k
    | otherwise = fail ("SMTP error: got " ++ cleanUp (show got) ++ " when I expected " ++ descr)
    where
        accepted = case got of
            Reply (Code gotCode gotCat _) _ -> gotCode == code && gotCat == cat
            _                               -> False

-- | Make multi-line text fit on one line for an error message: carriage
-- returns are dropped, trailing newlines are trimmed, and each remaining
-- newline becomes a @/@.
cleanUp :: String -> String
cleanUp = map slash . dropWhileEnd (== '\n') . filter (/= '\r')
    where
        slash '\n' = '/'
        slash c    = c
 
-- | Parse a value with 'reads'.  Succeeds only when there is exactly one
-- parse and it consumes the whole string.
maybeRead :: Read a => String -> Maybe a
maybeRead s
    | [(x, "")] <- reads s = Just x
    | otherwise            = Nothing

-- | Run @step@ for each element in order, threading the continuation
-- through, so that @for [a, b] step k = step a (step b k)@.  It is simply a
-- right fold with @k@ as the final continuation.
for :: [a] -> (a -> (SMTPState -> SMTPState) -> SMTPState -> SMTPState) -> (SMTPState -> SMTPState) -> SMTPState -> SMTPState
for xs step k = foldr step k xs

-- | Add one to the count of sent messages, then continue with @k@.
incSent :: (SMTPState -> SMTPState) -> SMTPState -> SMTPState
incSent k st = k (st { smtpSent = sent })
    where
        sent = smtpSent st + 1

-- | Receive one SMTP reply, which may span several lines, e.g.
--
-- > 250-worked
-- > 250-like
-- > 250 a charm
--
-- A line whose fourth character is @-@ is a continuation line, and the first
-- line that is not one ends the reply.  The final line must be a three-digit
-- reply code, optionally followed by a space and text; RFC 2821 section 4.2
-- makes the text optional, so a bare @250@ is accepted.  The code must have
-- its first digit in 1-5 and its second digit in 0-5.  Anything else ends the
-- session via 'fail' instead of reaching 'reply'.  NOTE(review): 'reply'
-- presumably maps the first two digits onto the 'SuccessCode' / 'Category'
-- enumerations, so out-of-range digits would otherwise be unsafe; confirm
-- against the Rfc2821 package.  The continuation receives the reply with the
-- code and separator removed from every line.
receiveReply :: (SmtpReply -> SMTPState -> SMTPState) -> SMTPState -> SMTPState
receiveReply return =
        collect [] finish
    where
        -- Accumulate lines, most recent first, until a non-continuation line.
        collect :: [String] -> ([String] -> SMTPState -> SMTPState) -> SMTPState -> SMTPState
        collect msgs k =
            receive $ \line ->
            case line of
                (_:_:_:'-':_) -> collect (line:msgs) k
                _             -> k (line:msgs)

        -- Validate the final line and build the reply from all lines.
        finish :: [String] -> SMTPState -> SMTPState
        finish [] = fail "malformed SMTP reply: (no lines)"  -- collect never passes []
        finish revMsgs@(final:_) =
            case splitAt 3 final of
                (codeText, sep)
                    | validCode codeText
                    , null sep || " " `isPrefixOf` sep
                    , Just code <- maybeRead codeText ->
                        return $ mkReply code (reverse $ map (drop 4) revMsgs)
                _ ->
                    fail ("malformed SMTP reply: "++final)

        -- RFC 2821 section 4.2: first digit 1-5, second digit 0-5, third 0-9.
        validCode :: String -> Bool
        validCode [x, y, z] = x `elem` "12345" && y `elem` "012345" && z `elem` "0123456789"
        validCode _         = False

        mkReply :: Int -> [String] -> SmtpReply
        mkReply code msgs = reply (code `div` 100) ((code `div` 10) `mod` 10) (code `mod` 10) msgs

-- | Receive a reply and continue with @k@ only if it is a 25x (mail system
-- success) code; otherwise fail the session.
checkMailOK :: (SMTPState -> SMTPState) -> SMTPState -> SMTPState
checkMailOK k =
    receiveReply (flip (check Success MailSystem "\"mail system OK\" (code 25x)") k)

-- | Receive a reply and continue with @k@ only if it is a 22x (connection
-- success) code, such as the greeting or the answer to QUIT; otherwise fail.
checkConnectionOK :: (SMTPState -> SMTPState) -> SMTPState -> SMTPState
checkConnectionOK k =
    receiveReply (flip (check Success Connection "\"connection OK\" (code 22x)") k)

-- | Receive the reply to DATA and continue with @return@ only if it is a 35x
-- (intermediate, "start mail input") code; otherwise fail the session.  The
-- failure text used to say "connection OK", which misdescribed what was
-- expected here.
checkDataOK :: (SMTPState -> SMTPState) -> SMTPState -> SMTPState
checkDataOK return =
    receiveReply $ \reply ->
    check IntermediateSuccess MailSystem "\"start mail input\" (code 35x)" reply $
    return

-- | Pure state machine for an SMTP client session; the caller performs all
-- I\/O.  The returned state is already waiting for the server greeting.
--
-- Message bodies may use either @\"\\n\"@ or @\"\\r\\n\"@ as the end-of-line
-- marker.  Line endings are stripped from every line placed in
-- 'smtpOutQueue', so the caller appends @\"\\r\\n\"@ when transmitting.
smtpClientSession :: String     -- ^ Domain name used in the EHLO command
                  -> [Message]  -- ^ Messages to send, in order
                  -> SMTPState
smtpClientSession domain messages = talk domain messages initial
    where
        initial = SMTPState
            { smtpOutQueue = []
            , smtpReceive  = nullReceive
            , smtpSuccess  = False
            , smtpFailure  = Nothing
            , smtpSent     = 0
            }

-- | Drive a complete SMTP session in continuation-passing style: every step
-- takes the rest of the session as its last argument.
--
-- Expects the 22x greeting, sends EHLO, delivers each message in turn, then
-- sends QUIT and marks the session successful once the server answers with a
-- 22x reply.  An unexpected reply, or a message with no From: address or no
-- recipients, records an 'smtpFailure' and ends the session.
talk :: String -> [Message] -> SMTPState -> SMTPState
talk domain messages =
        checkConnectionOK $
        send ("EHLO "++domain) $
        checkMailOK $
        sendMessages messages $
        send "QUIT" $
        checkConnectionOK $
        success
    where
        -- Deliver every message in order, then continue with @return@.
        sendMessages :: [Message] -> (SMTPState -> SMTPState) -> SMTPState -> SMTPState
        sendMessages msgs return = for msgs sendMessage return

        -- One MAIL / RCPT / DATA transaction.  The envelope sender is the
        -- first From: address; the envelope recipients are every To:, Cc:
        -- and Bcc: address.
        sendMessage :: Message -> (SMTPState -> SMTPState) -> SMTPState -> SMTPState
        sendMessage (Message fields body) return =
            case (froms, tos) of
                ([], _) ->
                    fail "email contains no From: field"
                (_, []) ->
                    fail "email contains no To: field"
                (sender:_, _) ->
                    send ((("MAIL FROM:"++) . angle_addr sender) "") $
                    checkMailOK $
                    for tos (\to k ->
                            send ((("RCPT TO:"++) . angle_addr to) "") $
                            checkMailOK $
                            k
                        ) $
                    send "DATA" $
                    checkDataOK $
                    for wireLines (send . dotStuff) $
                    send "." $
                    checkMailOK $
                    incSent
                    return
            where
                addrOf (NameAddr _ addr) = addr

                froms = [addrOf a | From addrs <- fields, a <- addrs]

                tos = [addrOf a | field <- fields, a <- recipients field]

                recipients (To addrs)  = addrs
                recipients (Cc addrs)  = addrs
                recipients (Bcc addrs) = addrs
                recipients _           = []

                -- RFC 2822 section 3.6.3: other recipients must not see the
                -- Bcc: header, so it is dropped from the transmitted copy.
                -- Bcc: recipients still receive the message via RCPT TO.
                wireLines = formatMessage (Message (filter (not . isBcc) fields) body)

                isBcc (Bcc _) = True
                isBcc _       = False

                -- RFC 2821 section 4.5.2 (transparency): a line starting with
                -- "." gets one extra "." prepended.  The server strips it
                -- again, and the data cannot be mistaken for the end marker.
                dotStuff line@('.':_) = '.' : line
                dotStuff line         = line

-- | The characters allowed in an RFC 2822 @atext@: ASCII letters, digits and
-- the listed special characters.
atext_alphabet :: Set.Set Char
atext_alphabet = Set.fromList $
    "abcdefghijklmnopqrstuvwxyz"++
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"++
    "0123456789"++
    "!#$%&'*+-/=?^_`{|}~"

-- | True when the text is a non-empty atom, i.e. every character belongs to
-- 'atext_alphabet'.
atomAllowed :: String -> Bool
atomAllowed txt = not (null txt) && all (`Set.member` atext_alphabet) txt

-- | True when the text is a dot-atom: one or more atoms separated by single
-- dots, with no leading, trailing or doubled dot.
dotAtomAllowed :: String -> Bool
dotAtomAllowed = all atomAllowed . splitDots
    where
        -- Split on every '.', keeping empty segments so misplaced dots fail.
        splitDots s = case break (== '.') s of
            (segment, [])       -> [segment]
            (segment, _ : more) -> segment : splitDots more

{-
test_dotAtomAllowed = do
    assertEqual "" False (dotAtomAllowed "")
    assertEqual "a" True (dotAtomAllowed "a")
    assertEqual "a<" False (dotAtomAllowed "a<")
    assertEqual ".a" False (dotAtomAllowed ".a")
    assertEqual "abc.012$" True (dotAtomAllowed "abc.012$")
    assertEqual "01.%^.xyzzy" True (dotAtomAllowed "01.%^.xyzzy")
    assertEqual "01..xyzzy" False (dotAtomAllowed "01..xyzzy")
    assertEqual "01.xyzzy." False (dotAtomAllowed "01.xyzzy.")
-}

-- | Render an address as an addr-spec, i.e. local part, at sign, domain.
--
-- The address is split at its last at sign.  A domain name never contains
-- one, but a quoted local part may, so splitting at the first one would
-- misplace the boundary.  The local part is written as a dot-atom when
-- possible and as a quoted string otherwise.  The domain is written as a
-- dot-atom when possible and as a domain literal otherwise.  An address with
-- no at sign is treated as a local part with an empty domain, as before.
addr_spec :: String -> ShowS
addr_spec addr =
    dotatom_or_quoted local . ("@"++) . dotatom_or_domain_literal domain
    where
        (local, domain) =
            case break (== '@') (reverse addr) of
                (revDomain, '@':revLocal) -> (reverse revLocal, reverse revDomain)
                _                         -> (addr, "")

-- | Render an address wrapped in angle brackets, as used in MAIL FROM,
-- RCPT TO and header address fields.
angle_addr :: String -> ShowS
angle_addr addr = showChar '<' . addr_spec addr . showChar '>'

-- | Render a message identifier; it uses the same angle-bracket syntax as
-- 'angle_addr'.
msg_id :: String -> ShowS
msg_id = angle_addr

-- | Emit the text unchanged when it is a valid atom, otherwise as a quoted
-- string.
atom_or_quoted :: String -> ShowS
atom_or_quoted text
    | atomAllowed text = showString text
    | otherwise        = quoted_string text

-- | Render a phrase word: a bare atom when possible, else a quoted string.
phrase :: String -> ShowS
phrase = atom_or_quoted

-- | Render the display name that precedes an angle address; it is written
-- like a phrase.
display_name :: String -> ShowS
display_name = phrase

-- | Emit the text unchanged when it is a valid dot-atom, otherwise as a
-- quoted string.
dotatom_or_quoted :: String -> ShowS
dotatom_or_quoted text
    | dotAtomAllowed text = showString text
    | otherwise           = quoted_string text

-- | Wrap the text in double quotes, escaping every backslash and double
-- quote with a backslash.
quoted_string :: String -> ShowS
quoted_string text = showChar '"' . showString (concatMap escape text) . showChar '"'
    where
        escape c
            | c == '\\' || c == '"' = ['\\', c]
            | otherwise             = [c]

{-
test_quoted_string = do
    assertEqual "" (quoted_string "" "") "\"\""
    assertEqual "Hello" (quoted_string "Hello" "") "\"Hello\""
    assertEqual "Say, \"Hello\"" (quoted_string "Say, \"Hello\"" "") "\"Say, \\\"Hello\\\"\""
    assertEqual "Backslash\\" (quoted_string "Backslash\\" "") "\"Backslash\\\\\""
-}

-- | Emit a domain unchanged when it is a valid dot-atom, otherwise as a
-- bracketed domain literal.
dotatom_or_domain_literal :: String -> ShowS
dotatom_or_domain_literal text
    | dotAtomAllowed text = showString text
    | otherwise           = domain_literal text

-- | Wrap the text in square brackets, escaping backslashes and square
-- brackets with a backslash.
domain_literal :: String -> ShowS
domain_literal text = showChar '[' . showString (concatMap escape text) . showChar ']'
    where
        escape c
            | c `elem` "\\[]" = ['\\', c]
            | otherwise       = [c]

-- | Render an address with its optional display name, e.g.
-- @Name \<local\@domain\>@, or only the angle address when there is no name.
name_addr :: NameAddr -> ShowS
name_addr (NameAddr mName addr) = prefix . angle_addr addr
    where
        prefix = maybe id (\name -> display_name name . showChar ' ') mName

-- | Render addresses separated by a comma and a newline; the newline lets the
-- header be folded onto continuation lines.
address_list :: [NameAddr] -> ShowS
address_list = foldr (.) id . intersperse (showString ",\n") . map name_addr

-- | Turn a message into text lines: every header field (multi-line fields
-- get their continuation lines indented), then an empty separator line, then
-- the body split into lines with trailing carriage returns removed.
formatMessage :: Message -> [String]
formatMessage (Message fields body) = headerLines ++ "" : bodyLines
    where
        headerLines = fields >>= indent . formatField
        bodyLines   = map (dropWhileEnd (== '\r')) (lines body)

-- | Split text into lines, indenting every line after the first by eight
-- spaces so that it reads as a folded header continuation.
indent :: String -> [String]
indent text = zipWith (++) ("" : repeat "        ") (lines text)

-- | Render one header field as text.  List-valued fields separate their
-- items with \",\\n\" so that the field can span several lines.
--
-- Fixes: @Sender@ and @Resent-Sender@ used to be written out as
-- @From:@ / @Resent-From:@.  Keywords used to be flattened word by word, so a
-- multi-word keyword turned into several keywords.
formatField :: Field -> String
formatField (OptionalField name value) = name ++ ": " ++ value
formatField (From addrs)        = (("From: "++) . address_list addrs) ""
formatField (Sender addr)       = (("Sender: "++) . name_addr addr) ""
formatField (ReturnPath txt)    = "Return-Path: "++txt
formatField (ReplyTo addrs)     = (("Reply-To: "++) . address_list addrs) ""
formatField (To addrs)          = (("To: "++) . address_list addrs) ""
formatField (Cc addrs)          = (("Cc: "++) . address_list addrs) ""
formatField (Bcc addrs)         = (("Bcc: "++) . address_list addrs) ""
formatField (MessageID mid)     = (("Message-ID: "++) . msg_id mid) ""
formatField (InReplyTo mids)    = (("In-Reply-To: "++) . foldr (.) id (map msg_id mids)) ""
formatField (References mids)   = (("References: "++) . foldr (.) id (map msg_id mids)) ""
formatField (Subject txt)       = "Subject: "++txt
formatField (Comments txt)      = "Comments: "++txt
formatField (Keywords kws)      = (("Keywords: "++) . foldr (.) id (intersperse (",\n"++) $ map keyword kws)) ""
    where
        -- Each inner list holds the words of one keyword phrase (presumably
        -- as produced by the Rfc2822 parser; confirm).  Keep the words of a
        -- phrase together, separated by spaces, and put commas only between
        -- phrases.
        keyword = foldr (.) id . intersperse (" "++) . map phrase
formatField (Date date)         = "Date: "++formatDate date
formatField (ResentDate date)   = "Resent-Date: "++formatDate date
formatField (ResentFrom addrs)  = (("Resent-From: "++) . address_list addrs) ""
formatField (ResentSender addr) = (("Resent-Sender: "++) . name_addr addr) ""
formatField (ResentTo addrs)    = (("Resent-To: "++) . address_list addrs) ""
formatField (ResentCc addrs)    = (("Resent-Cc: "++) . address_list addrs) ""
formatField (ResentBcc addrs)   = (("Resent-Bcc: "++) . address_list addrs) ""
formatField (ResentMessageID mid) = (("Resent-Message-ID: "++) . msg_id mid) ""
formatField (ResentReplyTo addrs) = (("Resent-Reply-To: "++) . address_list addrs) ""
formatField (Received (ps, date)) = (("Received: "++) . pairs ps . (";"++) . (formatDate date++)) ""
formatField (ObsReceived ps)    = (("Received: "++) . pairs ps) ""

-- | Render a calendar time as weekday, day, month, year, time and a numeric
-- zone offset (taken from 'ctTZ').
formatDate :: CalendarTime -> String
formatDate ct = concat [stamp, zone]
    where
        stamp = formatCalendarTime defaultTimeLocale "%a, %d %b %Y %H:%M:%S " ct
        zone  = formatTimeZone (ctTZ ct)

-- | Render a UTC offset given in seconds as @+hhmm@ or @-hhmm@.  Hours and
-- minutes each keep their last two decimal digits, zero-padded.
formatTimeZone :: Int -> String
formatTimeZone offset = sign : pad2 hours ++ pad2 minutes
    where
        sign      = if offset < 0 then '-' else '+'
        magnitude = if offset < 0 then negate offset else offset
        hours     = magnitude `div` 3600
        minutes   = (magnitude `div` 60) `mod` 60
        pad2 n    = let digits = show (n `mod` 100)
                    in replicate (2 - length digits) '0' ++ digits

-- | Render name/value pairs as @name value@, one pair per line.
pairs :: [(String, String)] -> ShowS
pairs = foldr (.) id . intersperse (showChar '\n') . map pair
    where
        pair (name, value) = showString name . showChar ' ' . showString value