-- |
-- Module:     Network.Smtp.Tools
-- Copyright:  (c) 2011 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez <es@ertes.de>
--
-- Helper functions and types.

{-# LANGUAGE OverloadedStrings, ScopedTypeVariables #-}

module Network.Smtp.Tools
    ( enumHandleTimeout,
      formatMsgs,
      smtpResponseLine,
      smtpResponse,
      smtpResponses,
      stringToExtension )
    where

import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import qualified Data.Set as S
import qualified Data.Vector as V
import Control.ContStuff as Cont
import Data.ByteString (ByteString)
import Data.Enumerator as E
import Data.Enumerator.List as EL
import Data.Enumerator.NetLines
import Data.List as L
import Data.Maybe
import Data.Vector (Vector)
import Network.Smtp.Types


-- | Render the messages of an 'SmtpResponse' as a single 'String' for
-- output, separating consecutive messages by exactly one space.  An
-- empty 'Vector' renders as the empty string.

formatMsgs :: Vector ByteString -> String
formatMsgs msgs = L.intercalate " " [ BC.unpack msg | msg <- V.toList msgs ]


-- | Read the three digit SMTP response code at the start of a line.
-- Returns 'Nothing' if the line is shorter than three bytes or if any
-- of its first three bytes is not an ASCII digit (@0@ to @9@).

readRespCode :: ByteString -> Maybe Integer
readRespCode str = do
    guard $ B.length str >= 3
    let digits = L.map (subtract 48 . fromIntegral . B.index str) [0, 1, 2]
    -- Both bounds matter: bytes below '0' (e.g. a space) would otherwise
    -- produce negative "digits" and be accepted.
    guard $ L.all (\d -> d >= 0 && d <= 9) digits
    return $ L.foldl' (\acc d -> 10*acc + d) 0 digits


-- | Inspect the separator byte that follows the three digit code of an
-- SMTP response line: @-@ means more lines of the same response follow
-- ('True'), a space means this is the last line ('False').  Any other
-- byte, or a line shorter than four bytes, gives 'Nothing'.

readRespMore :: ByteString -> Maybe Bool
readRespMore str
    | B.length str < 4 = Nothing
    | otherwise =
        case BC.index str 3 of
          '-' -> Just True
          ' ' -> Just False
          _   -> Nothing


-- | Read the next SMTP response line from a 'ByteString' lines stream
-- (i.e. a 'ByteString' stream converted by 'netLines').  On EOF the
-- 'MaybeT' computation is aborted, so 'evalMaybeT' yields 'Nothing'.
-- Otherwise the result is @Left line@ if the line is not a well-formed
-- SMTP response line, or @Right (code, more, msg)@, where @more@ tells
-- whether further lines of the same response follow and @msg@ is the
-- text after the separator (empty for a bare three digit line).

smtpResponseLine ::
    Monad m =>
    MaybeT r (Iteratee ByteString m) (Either ByteString (Integer, Bool, ByteString))
smtpResponseLine = do
    line <- liftF EL.head
    return $ case parseRespLine line of
      Just parsed -> Right parsed
      Nothing     -> Left line

  where
    -- Split a raw line into code, continuation flag and message text.
    -- 'readRespCode' already rejects lines shorter than three bytes.
    parseRespLine :: ByteString -> Maybe (Integer, Bool, ByteString)
    parseRespLine ln = do
        code <- readRespCode ln
        if B.length ln == 3
          then Just (code, False, B.empty)
          else do
              more <- readRespMore ln
              Just (code, more, B.drop 4 ln)


-- | Read the next SMTP response from a 'netLines'-split 'ByteString'
-- stream.  Throws an error on protocol errors: a malformed line, or a
-- continuation line whose code differs from the first line's code.
-- Returns at most the given number of response messages.  Continuation
-- lines beyond that limit are still consumed from the stream, but their
-- messages are discarded.  If the stream ends (even in the middle of a
-- multiline response), the 'MaybeT' computation is aborted.

smtpResponse :: forall m r. Monad m =>
                Int -> MaybeT r (Iteratee ByteString m) SmtpResponse
smtpResponse maxMsgs =
    collectResp Nothing V.empty

    where
    -- Abort the whole computation with an exception in the 'Iteratee'.
    smtpError :: MaybeT r (Iteratee ByteString m) a
    smtpError = lift $ throwError (userError "Invalid SMTP response")

    -- Append a message unless the limit has been reached.  Checking the
    -- length first avoids copying the vector for every line past the
    -- limit, and avoids a 'V.take' slice that keeps the dropped message
    -- alive in its underlying array.  A zero or negative limit keeps
    -- the vector empty.
    addMsg :: Vector ByteString -> ByteString -> Vector ByteString
    addMsg msgs msg
        | V.length msgs >= maxMsgs = msgs
        | otherwise                = V.snoc msgs msg

    -- Accumulate lines of one response.  The first argument is the
    -- code of the first line, once known.
    collectResp ::
        Maybe Integer -> Vector ByteString ->
        MaybeT r (Iteratee ByteString m) SmtpResponse
    collectResp mCode msgs' = do
        mResp <- smtpResponseLine
        (code, more, msg) <- either (const smtpError) return mResp

        case mCode of
          Just code' -> unless (code == code') smtpError
          Nothing    -> return ()

        let msgs = addMsg msgs' msg
        if more
          then msgs `seq` collectResp (Just code) msgs
          else msgs `seq` return (SmtpResponse code msgs)


-- | Convert a stream of 'netLines'-split 'ByteString' lines to a
-- stream of SMTP responses.  Each response keeps at most the given
-- number of messages (see 'smtpResponse').  In case of a protocol
-- error the enumeration is aborted and an error is thrown.
--
-- When the input stream ends, the inner iteratee is handed back as a
-- 'Continue' step instead of being fed 'EOF' here, which is the usual
-- 'Enumeratee' convention.  Finalize it with 'joinI'.  A response cut
-- off by the end of input is discarded.

smtpResponses :: forall b m. Monad m => Int -> Enumeratee ByteString SmtpResponse m b
smtpResponses maxMsgs =
    loop

    where
    loop :: Enumeratee ByteString SmtpResponse m b
    loop (Continue k) = do
        mResp <- evalMaybeT $ smtpResponse maxMsgs
        case mResp of
          Just resp -> k (Chunks [resp]) >>== loop
          -- Outer EOF.  Feeding EOF ourselves and looping would spin
          -- forever on an inner iteratee that keeps returning 'Continue'.
          -- Returning the step lets the caller ('joinI') finalize it.
          Nothing   -> return (Continue k)
    loop step = return step


-- | Convert an extension string (e.g. a line of an EHLO reply) to an
-- 'Extension' value, if the extension is known.  The string is split
-- on whitespace.  Currently only @AUTH@ is recognized, and every
-- method name is mapped through 'authMethod'.
-- NOTE(review): 'authMethod' recognizes no method yet, so @AUTH ...@
-- always yields an 'AuthExt' with an empty set.

stringToExtension :: ByteString -> Maybe Extension
stringToExtension = classify . BC.words
    where
    classify :: [ByteString] -> Maybe Extension
    classify ("AUTH" : methods) = Just (authExtension methods)
    classify _                  = Nothing

    authExtension :: [ByteString] -> Extension
    authExtension = AuthExt . S.fromList . catMaybes . L.map authMethod

    authMethod :: ByteString -> Maybe AuthMethod
    authMethod _ = Nothing