-- | A pure SMTP client state machine.
--
-- Data structures for representing SMTP status codes and email messages are
-- re-exported here from /Text.ParserCombinators.Parsec.Rfc2821/ and
-- /Text.ParserCombinators.Parsec.Rfc2822/ in the /hsemail/ package.

module Network.SMTP.ClientSession (
        smtpClientSession,
        SMTPState(..),
        SmtpReply(..),
        SmtpCode(..),
        SuccessCode(..),
        Category(..),
        Message(..),
        Field(..),
        NameAddr(..)
        {-, suite-}
    ) where

import Text.ParserCombinators.Parsec.Rfc2821 (
        SmtpReply(..),
        SmtpCode(..),
        SuccessCode(..),
        Category(..),
        reply
    )
import Text.ParserCombinators.Parsec.Rfc2822 (
        Message(..),
        Field(..),
        NameAddr(..)
    )
import Data.Maybe
import qualified Data.Set as Set
import Prelude hiding (fail)
import Data.List
import System.Time
import System.Locale


{-import Test.Framework
import Test.Framework.Providers.HUnit
import Test.HUnit-}

{-suite = testGroup "MetaBarter.SMTP.SMTP" [
        testCase "dotAtomAllowed"   test_dotAtomAllowed,
        testCase "quoted_string"    test_quoted_string
    ]-}

-- | The complete state of an SMTP client session.  The caller drives the
-- session by repeating steps 1-4 below until 'smtpSuccess' is True or
-- 'smtpFailure' is set.
data SMTPState = SMTPState
    { -- | Step 1. Lines the caller must send to the SMTP server.  They carry
      -- no end-of-line characters: append \"\\r\\n\" to each one (RFC2821
      -- requires both characters, so a bare \"\\n\" is not enough).
      smtpOutQueue :: [String]
      -- | Step 2. True once the SMTP session has completed successfully and
      -- nothing remains to be done.
    , smtpSuccess  :: Bool
      -- | Step 3. @Just err@ once a protocol error has occurred; the caller
      -- must then terminate the session.
    , smtpFailure  :: Maybe String
      -- | Step 4. Wait for a line from the SMTP server, strip its \"\\r\\n\"
      -- end-of-line characters, and pass the stripped line to this function
      -- to obtain the next state.  Then go back to step 1.
    , smtpReceive  :: String -> SMTPState -> SMTPState
      -- | One status per message sent so far, in the same order as the input
      -- messages.  /Nothing/ means success; /Just x/ is the failure reply
      -- returned by the SMTP server.  If the session does not complete
      -- successfully this list is likely to be shorter than the input list;
      -- when 'smtpSuccess' is True it is guaranteed to have the same length.
    , smtpStatuses :: [Maybe SmtpReply]
    }

-- | Queue one output line, then continue with @cont@.  Lines are consed onto
-- the front of 'smtpOutQueue' (so the queue is built in reverse order).
send :: String -> (SMTPState -> SMTPState) -> SMTPState -> SMTPState
send txt cont = cont . enqueue
  where
    enqueue st = st { smtpOutQueue = txt : smtpOutQueue st }

-- | An 'smtpReceive' callback that ignores the line and returns the state
-- unchanged.
nullReceive :: String -> SMTPState -> SMTPState
nullReceive _ = id

-- | Suspend the session until the caller supplies the server's next line.
-- The queued output is handed to the caller in sending order.  When the line
-- arrives, the queue is emptied and 'smtpReceive' is reset to 'nullReceive'
-- before @cont@ processes the line.
receive :: (String -> SMTPState -> SMTPState) -> SMTPState -> SMTPState
receive cont current =
    current { smtpOutQueue = pending, smtpReceive = onLine }
  where
    -- 'send' assembles the queue backwards, so put it back in order here.
    pending = reverse (smtpOutQueue current)
    onLine line next =
        cont line (next { smtpOutQueue = [], smtpReceive = nullReceive })

-- | Abort the session: record @errorText@ as the failure and discard any
-- output that has not yet been handed to the caller.
fail :: String -> SMTPState -> SMTPState
fail errorText = \st -> st { smtpFailure = Just errorText, smtpOutQueue = [] }

-- | Finish the session successfully: set 'smtpSuccess' and discard any
-- output that has not yet been handed to the caller.
success :: SMTPState -> SMTPState
success = \st -> st { smtpSuccess = True, smtpOutQueue = [] }

-- | Does the reply's code have the given success class (first digit) and
-- category (second digit)?  The third digit of the code is ignored.
--
-- Matches the reply directly instead of using a catch-all named
-- @otherwise@, which shadowed "Prelude".'otherwise', and avoids shadowing
-- the imported 'reply' function.
equals :: SuccessCode -> Category -> SmtpReply -> Bool
equals code cat (Reply (Code gotCode gotCat _) _) =
    gotCode == code && gotCat == cat

-- | Continue with @cont@ if @got@ satisfies @isOK@.  Otherwise fail the
-- session with a one-line message that quotes the reply and @descr@, the
-- description of what was expected.
check :: (SmtpReply -> Bool) -> String -> SmtpReply -> (SMTPState -> SMTPState) -> SMTPState -> SMTPState
check isOK descr got cont
    | isOK got  = cont
    | otherwise = fail ("SMTP error: got " ++ cleanUp (show got) ++ " when I expected " ++ descr)

-- | Squish a (possibly multi-line) SMTP reply description onto one line.
-- Carriage returns are removed, trailing newlines are dropped, and each
-- remaining newline becomes a @/@.
cleanUp :: String -> String
cleanUp = map slash . dropWhileEnd (== '\n') . filter (/= '\r')
  where
    slash '\n' = '/'
    slash c    = c

-- | Parse a value with 'reads'.  Succeeds only when there is exactly one
-- parse and it consumes the whole string.
maybeRead :: Read a => String -> Maybe a
maybeRead input =
    case reads input of
        [(value, rest)] | null rest -> Just value
        _                           -> Nothing

-- | Run @body@ once per list element, in order, in continuation-passing style.
-- The body receives the element and the continuation for the remaining
-- elements.  After the last element, control passes to @cont@.
for :: [a]
    -> (a -> (SMTPState -> SMTPState) -> SMTPState -> SMTPState)
    -> (SMTPState -> SMTPState)
    -> SMTPState -> SMTPState
for xs body cont = foldr body cont xs

-- | Replace 'smtpStatuses' with the given list, then continue with @cont@.
putStatuses :: [Maybe SmtpReply] -> (SMTPState -> SMTPState) -> SMTPState -> SMTPState
putStatuses statuses cont = cont . store
  where
    store st = st { smtpStatuses = statuses }

-- | Receive one SMTP reply, which may span several lines, e.g.
--
-- > 250-worked
-- > 250-like
-- > 250 a charm
--
-- A line whose fourth character is @-@ is a continuation line; the first
-- other line ends the reply.  RFC 2821 section 4.2 asks clients to accept a
-- final line that is just the code (with or without a trailing space), so a
-- three-character final line is accepted.  The reply code comes from the
-- final line.  The text after the fourth character of every line becomes one
-- element of the reply's message list.
--
-- The session fails with \"malformed SMTP reply: ...\" in three cases:
--
-- * a line is shorter than three characters;
-- * the final line has something other than a space after the code;
-- * the code is not three decimal digits with the first two in 0-5.
--
-- Larger first or second digits are rejected because 'reply' converts them
-- with 'toEnum', which would crash the client.  NOTE(review): this assumes
-- hsemail's 'SuccessCode' and 'Category' each have six constructors.
receiveReply :: (SmtpReply -> SMTPState -> SMTPState) -> SMTPState -> SMTPState
receiveReply cont =
        rec [] $ \revMsgs ->
        case revMsgs of
            final:_ ->
                case parseFinal final of
                    Just (x, y, z) ->
                        cont $ reply x y z (reverse $ map (drop 4) revMsgs)
                    Nothing ->
                        fail ("malformed SMTP reply: "++final)
            [] ->
                -- Unreachable: rec always passes at least one line.
                fail "malformed SMTP reply: (no lines)"
    where
        -- Accumulate lines (newest first) until a non-continuation line arrives.
        rec :: [String] -> ([String] -> SMTPState -> SMTPState) -> SMTPState -> SMTPState
        rec msgs k =
            receive $ \line ->
            if length line < 3 then
                fail ("malformed SMTP reply: "++line)
            else case drop 3 line of
                '-':_ -> rec (line:msgs) k
                _     -> k (line:msgs)

        -- Split the final line's code into its three digits.  The code must
        -- be followed by nothing or by a space.
        parseFinal :: String -> Maybe (Int, Int, Int)
        parseFinal (d1:d2:d3:rest)
            | separatorOK rest = do
                x <- digitUpTo 5 d1
                y <- digitUpTo 5 d2
                z <- digitUpTo 9 d3
                return (x, y, z)
        parseFinal _ = Nothing

        separatorOK :: String -> Bool
        separatorOK ""      = True
        separatorOK (' ':_) = True
        separatorOK _       = False

        -- The value of a decimal digit character, provided it is at most hi.
        digitUpTo :: Int -> Char -> Maybe Int
        digitUpTo hi ch
            | ch >= '0' && ch <= '9' && n <= hi = Just n
            | otherwise                         = Nothing
          where
            n = fromEnum ch - fromEnum '0'

-- | Is this a \"mail system OK\" reply (code 25x)?
equalsMailOK :: SmtpReply -> Bool
equalsMailOK r = equals Success MailSystem r

-- | Wait for a reply.  Continue with @cont@ if it is a \"mail system OK\"
-- (25x) reply; otherwise fail the session.
checkMailOK :: (SMTPState -> SMTPState) -> SMTPState -> SMTPState
checkMailOK cont =
    receiveReply (\r -> check equalsMailOK "\"mail system OK\" (code 25x)" r cont)

-- | Wait for a reply.  Continue with @cont@ if it is a \"connection OK\"
-- (22x) reply, such as the server greeting or the answer to QUIT; otherwise
-- fail the session.
checkConnectionOK :: (SMTPState -> SMTPState) -> SMTPState -> SMTPState
checkConnectionOK cont =
    receiveReply (\r -> check connectionOK "\"connection OK\" (code 22x)" r cont)
  where
    connectionOK = equals Success Connection

-- | Is this the positive intermediate reply that permits sending the message
-- text after DATA (354 in RFC 2821)?
equalsDataOK :: SmtpReply -> Bool
equalsDataOK r = equals IntermediateSuccess MailSystem r

-- | Construct a pure state machine for an SMTP client session; the caller
-- handles all I\/O.  The returned state is already waiting for the server's
-- greeting line.  The message body may use either \"\\n\" or \"\\r\\n\" as an
-- end-of-line marker.
smtpClientSession :: String     -- ^ Domain name used in EHLO command
                  -> [Message]  -- ^ List of messages to send
                  -> SMTPState
smtpClientSession domain messages = talk domain messages fresh
  where
    -- Nothing queued, nothing decided and no statuses yet.
    fresh = SMTPState
        { smtpOutQueue = []
        , smtpSuccess  = False
        , smtpFailure  = Nothing
        , smtpReceive  = nullReceive
        , smtpStatuses = []
        }

-- | The whole SMTP conversation in continuation-passing style: every step
-- takes the step that follows it as its final argument.
--
-- The sequence is:
--
-- 1. expect a 22x greeting;
-- 2. send @EHLO domain@ and expect 25x;
-- 3. send each message in turn;
-- 4. send @QUIT@, expect 22x and mark the session successful.
talk :: String -> [Message] -> SMTPState -> SMTPState
talk domain messages =
        checkConnectionOK $
        send ("EHLO "++domain) $
        checkMailOK $
        sendMessages messages [] $
        send "QUIT" $
        checkConnectionOK $
        success
    where
        -- Send the messages one after another.  After each message the
        -- statuses gathered so far are stored in the state, so they remain
        -- visible even if a later step fails the session.
        sendMessages :: [Message] -> [Maybe SmtpReply] -> (SMTPState -> SMTPState) -> SMTPState -> SMTPState
        sendMessages [] _ cont = cont
        sendMessages (message:rest) statuses cont =
            sendMessage message $ \status ->
            let statuses' = status:statuses in   -- collate statuses backwards
            putStatuses (reverse statuses') $    -- reverse to store in correct order
            sendMessages rest statuses' cont

        -- Run one mail transaction: MAIL FROM, RCPT TO per recipient, DATA,
        -- the message text, then ".".  The outcome goes to cont: Nothing on
        -- success, or Just the first unexpected reply.  A message with no
        -- sender or no recipients fails the whole session.
        sendMessage :: Message -> (Maybe SmtpReply -> SMTPState -> SMTPState) -> SMTPState -> SMTPState
        sendMessage (Message fields body) cont =
            case (froms, tos) of
                ([], _) ->
                    fail "email contains no From: field"
                (_, []) ->
                    fail "email contains no To:, Cc: or Bcc: field"
                (sender:_, _) ->
                    send ((("MAIL FROM:"++) . angle_addr sender) "") $
                    expect equalsMailOK $
                    for tos (\to next ->
                            send ((("RCPT TO:"++) . angle_addr to) "") $
                            expect equalsMailOK next
                        ) $
                    send "DATA" $
                    expect equalsDataOK $
                    for (formatMessage (Message (filter (not . isBcc) fields) body))
                        (send . dotStuff) $
                    send "." $
                    expect equalsMailOK $
                    cont Nothing  -- success
            where
                addrOf (NameAddr _ addr) = addr

                froms = map addrOf $
                    concatMap (\f ->
                        case f of
                            From from -> from
                            _         -> []) fields

                tos = map addrOf $
                    concatMap (\f ->
                        case f of
                            To to  -> to
                            Cc to  -> to
                            Bcc to -> to
                            _      -> []) fields

                -- Wait for a reply.  Go on to next if it satisfies isOK;
                -- otherwise give up on this message.
                expect isOK next =
                    receiveReply $ \answer ->
                    if isOK answer then next else abandon answer

                -- Under RFC 2821 a rejected RCPT TO (or DATA) does not end the
                -- mail transaction, so the next MAIL FROM would be refused.
                -- Reset the transaction before reporting the failure status.
                abandon answer =
                    send "RSET" $
                    checkMailOK $
                    cont (Just answer)

                -- RFC 2821 section 4.5.2: a text line starting with '.' gets an
                -- extra '.', which the server strips again.  This stops a lone
                -- "." ending the data early and stops leading dots being lost.
                dotStuff line@('.':_) = '.':line
                dotStuff line         = line

                -- Bcc recipients receive the message via RCPT TO, but the
                -- header listing them must not reach the other recipients.
                isBcc (Bcc _) = True
                isBcc _       = False

-- | The characters allowed in an RFC 2822 atom (@atext@): ASCII letters,
-- digits and the listed punctuation.
atext_alphabet :: Set.Set Char
atext_alphabet =
    Set.fromList (['a'..'z'] ++ ['A'..'Z'] ++ ['0'..'9'] ++ "!#$%&'*+-/=?^_`{|}~")

-- | True when the string is a non-empty run of 'atext_alphabet' characters,
-- i.e. it can be written as an atom without quoting.
atomAllowed :: String -> Bool
atomAllowed txt = not (null txt) && all (`Set.member` atext_alphabet) txt

-- | True when the string is one or more valid atoms joined by single dots
-- (a dot-atom).  For example @abc.012$@ qualifies, but @\"\"@, @.a@,
-- @a..b@ and @a.@ do not.
dotAtomAllowed :: String -> Bool
dotAtomAllowed str =
    case break (== '.') str of
        (atom, [])       -> atomAllowed atom
        (atom, _ : more) -> atomAllowed atom && dotAtomAllowed more

{-
test_dotAtomAllowed = do
    assertEqual "" False (dotAtomAllowed "")
    assertEqual "a" True (dotAtomAllowed "a")
    assertEqual "a<" False (dotAtomAllowed "a<")
    assertEqual ".a" False (dotAtomAllowed ".a")
    assertEqual "abc.012$" True (dotAtomAllowed "abc.012$")
    assertEqual "01.%^.xyzzy" True (dotAtomAllowed "01.%^.xyzzy")
    assertEqual "01..xyzzy" False (dotAtomAllowed "01..xyzzy")
    assertEqual "01.xyzzy." False (dotAtomAllowed "01.xyzzy.")
-}

-- | Render an address as @local\@domain@.
--
-- * The local part is written as a dot-atom when allowed, otherwise as a
--   quoted string.
-- * The domain is written as a dot-atom when allowed, otherwise as a domain
--   literal.
--
-- The address is split at its /last/ \'\@\'.  A quoted local part may
-- contain \'\@\' (RFC 2822 section 3.4.1) but host names do not, so
-- splitting at the first one mangled such addresses.  An address with no
-- \'\@\' is treated as a local part with an empty domain.
addr_spec :: String -> ShowS
addr_spec addr =
    let (local, domain) =
            case break (== '@') (reverse addr) of
                (revDomain, _:revLocal) -> (reverse revLocal, reverse revDomain)
                (_, [])                 -> (addr, "")
    in dotatom_or_quoted local . ("@"++) . dotatom_or_domain_literal domain

-- | Render an address as an addr-spec enclosed in angle brackets.
angle_addr :: String -> ShowS
angle_addr addr = showChar '<' . addr_spec addr . showChar '>'

-- | Render a message identifier the same way as 'angle_addr'.
-- NOTE(review): if parsed Message-IDs already include their angle brackets,
-- the brackets will be doubled; confirm against the hsemail parser.
msg_id :: String -> ShowS
msg_id = angle_addr

-- | Write the text as-is if it is a valid atom, otherwise as a quoted string.
atom_or_quoted :: String -> ShowS
atom_or_quoted text
    | atomAllowed text = showString text
    | otherwise        = quoted_string text

-- | Render a word of a phrase: an atom when possible, else a quoted string.
phrase :: String -> ShowS
phrase = atom_or_quoted

-- | Render the display name in front of an angle address.
display_name :: String -> ShowS
display_name = phrase

-- | Write the text as-is if it is a valid dot-atom, otherwise as a quoted
-- string.
dotatom_or_quoted :: String -> ShowS
dotatom_or_quoted text
    | dotAtomAllowed text = showString text
    | otherwise           = quoted_string text

-- | Wrap the text in double quotes, escaping each backslash and double quote
-- with a preceding backslash.
quoted_string :: String -> ShowS
quoted_string text = showChar '"' . showString (concatMap escape text) . showChar '"'
  where
    escape c
        | c == '\\' || c == '"' = ['\\', c]
        | otherwise             = [c]

{-
test_quoted_string = do
    assertEqual "" (quoted_string "" "") "\"\""
    assertEqual "Hello" (quoted_string "Hello" "") "\"Hello\""
    assertEqual "Say, \"Hello\"" (quoted_string "Say, \"Hello\"" "") "\"Say, \\\"Hello\\\"\""
    assertEqual "Backslash\\" (quoted_string "Backslash\\" "") "\"Backslash\\\\\""
-}

-- | Write a domain as-is if it is a valid dot-atom, otherwise as a
-- bracketed domain literal.
dotatom_or_domain_literal :: String -> ShowS
dotatom_or_domain_literal text
    | dotAtomAllowed text = showString text
    | otherwise           = domain_literal text

-- | Wrap the text in square brackets, escaping each backslash, @[@ and @]@
-- with a preceding backslash.
domain_literal :: String -> ShowS
domain_literal text = showChar '[' . showString (concatMap escape text) . showChar ']'
  where
    escape c
        | c `elem` "\\[]" = ['\\', c]
        | otherwise       = [c]

-- | Render a mailbox as @Display Name \<addr\>@, or just @\<addr\>@ when
-- there is no display name.
name_addr :: NameAddr -> ShowS
name_addr (NameAddr mName addr) = label . angle_addr addr
  where
    label = maybe id (\name -> display_name name . showChar ' ') mName

-- | Render mailboxes separated by a comma and a newline.  The newline lets
-- the caller fold each address onto its own header continuation line.
address_list :: [NameAddr] -> ShowS
address_list []           = id
address_list (first:more) =
    name_addr first . foldr (\a acc -> showString ",\n" . name_addr a . acc) id more

-- | Convert a message into lines of text ready for the DATA phase.  The
-- output is the header fields (long ones folded), an empty separator line,
-- then the body lines with any trailing carriage returns removed.  Removing
-- them means the body may use either \"\\n\" or \"\\r\\n\" line endings.
formatMessage :: Message -> [String]
formatMessage (Message fields body) = headerLines ++ "" : bodyLines
  where
    headerLines = concatMap (indent . formatField) fields
    bodyLines   = map (dropWhileEnd (== '\r')) (lines body)

-- | Split text into lines and prefix every line after the first with eight
-- spaces, so embedded newlines become folded continuation lines.
indent :: String -> [String]
indent text = zipWith (++) ("" : repeat "        ") (lines text)

-- | Render one header field as a (possibly multi-line) string.  Fields with
-- several values put a newline between the values, so a caller can fold
-- them onto continuation lines.
--
-- Fixes: 'Sender' and 'ResentSender' were previously written as \"From:\" and
-- \"Resent-From:\", which duplicated those headers.  Keyword phrases were
-- flattened into individual words, which split multi-word keywords apart.
formatField :: Field -> String
formatField (OptionalField name value) = name ++ ": " ++ value
formatField (From addrs)        = (("From: "++) . address_list addrs) ""
formatField (Sender addr)       = (("Sender: "++) . name_addr addr) ""
formatField (ReturnPath txt)    = "Return-Path: "++txt
formatField (ReplyTo addrs)     = (("Reply-To: "++) . address_list addrs) ""
formatField (To addrs)          = (("To: "++) . address_list addrs) ""
formatField (Cc addrs)          = (("Cc: "++) . address_list addrs) ""
formatField (Bcc addrs)         = (("Bcc: "++) . address_list addrs) ""
formatField (MessageID mid)     = (("Message-ID: "++) . msg_id mid) ""
formatField (InReplyTo mids)    = (("In-Reply-To: "++) . foldr (.) id (map msg_id mids)) ""
formatField (References mids)   = (("References: "++) . foldr (.) id (map msg_id mids)) ""
formatField (Subject txt)       = "Subject: "++txt
formatField (Comments txt)      = "Comments: "++txt
formatField (Keywords kws)      = (("Keywords: "++) . foldr (.) id (intersperse (",\n"++) $ map keyword kws)) ""
    where
        -- Each inner list is one keyword phrase: its words are separated by
        -- spaces, and phrases are separated by commas.
        keyword = foldr (.) id . intersperse (" "++) . map phrase
formatField (Date date)         = "Date: "++formatDate date
formatField (ResentDate date)   = "Resent-Date: "++formatDate date
formatField (ResentFrom addrs)  = (("Resent-From: "++) . address_list addrs) ""
formatField (ResentSender addr) = (("Resent-Sender: "++) . name_addr addr) ""
formatField (ResentTo addrs)    = (("Resent-To: "++) . address_list addrs) ""
formatField (ResentCc addrs)    = (("Resent-Cc: "++) . address_list addrs) ""
formatField (ResentBcc addrs)   = (("Resent-Bcc: "++) . address_list addrs) ""
formatField (ResentMessageID mid) = (("Resent-Message-ID: "++) . msg_id mid) ""
formatField (ResentReplyTo addrs) = (("Resent-Reply-To: "++) . address_list addrs) ""
formatField (Received (ps, date)) = (("Received: "++) . pairs ps . (";"++) . (formatDate date++)) ""
formatField (ObsReceived ps)    = (("Received: "++) . pairs ps) ""

-- | Format a calendar time as an RFC 2822 date-time, for example
-- @Tue, 01 Jul 2003 10:52:37 +0200@.
formatDate :: CalendarTime -> String
formatDate ct = localPart ++ formatTimeZone (ctTZ ct)
  where
    -- Day name, date and time, ending with the space before the zone.
    localPart = formatCalendarTime defaultTimeLocale "%a, %d %b %Y %H:%M:%S " ct

-- | Render a UTC offset given in seconds as a sign followed by @hhmm@,
-- e.g. @19800@ becomes @+0530@ and @-12600@ becomes @-0330@.
formatTimeZone :: Int -> String
formatTimeZone offset = sign : twoDigits hours ++ twoDigits minutes
  where
    sign      = if offset < 0 then '-' else '+'
    magnitude = abs offset
    hours     = magnitude `div` 3600
    minutes   = (magnitude `div` 60) `mod` 60
    -- Zero-padded last two decimal digits.
    twoDigits v =
        let m = v `mod` 100
        in if m < 10 then '0' : show m else show m

-- | Render name-value pairs as @name value@, one pair per line.
pairs :: [(String, String)] -> ShowS
pairs = foldr (.) id . intersperse (showChar '\n') . map pair
  where
    pair (name, value) = showString name . showChar ' ' . showString value