-- |
-- Module:     Network.Smtp.Tools
-- Copyright:  (c) 2010 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez <es@ertes.de>
-- Stability:  experimental
--
-- Helper functions and types.

{-# LANGUAGE ScopedTypeVariables #-}

module Network.Smtp.Tools
    ( enumHandleTimeout,
      formatMsgs,
      netLine,
      netLines,
      responseLines,
      smtpResponseLine,
      smtpResponse,
      smtpResponses,
      stringToExtension )
    where

import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import qualified Data.Vector as V
import Control.ContStuff as Cont
import Data.Enumerator as E
import Data.Enumerator.Binary as EB
import Data.Enumerator.List as EL
import Data.ByteString (ByteString)
import Data.List as L
import Data.Vector (Vector)
import Data.Word
import Network.Smtp.Types
import System.IO
import System.IO.Error as IOErr


-- | Enumerate chunks read from a handle.  The first argument is the
-- buffer size passed to each read, the second is the time in
-- milliseconds to wait for input to become available.  If no input
-- arrives within that time, a user error is thrown via 'throwError'.
-- Other I/O errors are rethrown the same way, except for EOF, which
-- hands control back with the iteratee still waiting for input.

enumHandleTimeout :: forall b m. MonadIO m =>
                     Int -> Int -> Handle -> Enumerator ByteString m b
enumHandleTimeout bufSize timeout h = feed
    where
    feed :: Enumerator ByteString m b
    feed (Continue k) =
        liftIO (IOErr.try (hWaitForInput h timeout)) >>= waited

        where
        -- Outcome of waiting for the handle to become readable.
        waited (Left err)
            | isEOFError err = continue k
            | otherwise      = throwError err
        waited (Right False) = throwError $ userError "Handle timed out"
        waited (Right True)  =
            liftIO (IOErr.try (B.hGetNonBlocking h bufSize))
                >>= either throwError deliver

        -- An empty read passes nothing on; the iteratee keeps waiting.
        deliver chunk
            | B.null chunk = continue k
            | otherwise    = k (Chunks [chunk]) >>== feed
    feed step = returnI step


-- | Render the messages of an 'SmtpResponse' as a single 'String' for
-- output, with the individual messages separated by one space each.

formatMsgs :: Vector ByteString -> String
formatMsgs msgs = BC.unpack joined
    where
    joined = BC.intercalate (BC.singleton ' ') (V.toList msgs)


-- | Safely read a line, keeping at most the given number of bytes.
-- Line break bytes (CR and LF) in front of the line are skipped first.
-- If the line is longer than the limit, the excess data is dropped in
-- constant space.  Returns 'Nothing' if EOF is reached before a line
-- break or the length limit.

netLine :: forall m r. Monad m => Int -> MaybeT r (Iteratee ByteString m) ByteString
netLine n = do
    lift (EB.dropWhile isBreak)
    collect n

    where
    -- Line feed (10) and carriage return (13) both count as line breaks.
    isBreak :: Word8 -> Bool
    isBreak w = w == 10 || w == 13

    -- Gather bytes up to the next line break; once the budget is used
    -- up, skip the remainder of the line instead of accumulating it.
    collect :: Int -> MaybeT r (Iteratee ByteString m) ByteString
    collect 0 = lift (EB.dropWhile (not . isBreak)) >> return B.empty
    collect remaining =
        liftF EB.head >>= \byte ->
            if isBreak byte
              then return B.empty
              else B.cons byte <$> collect (remaining - 1)


-- | Turn a stream of bytes into a stream of lines, each limited to the
-- given maximum length.  Lines exceeding the limit are silently
-- truncated in constant space.  When 'netLine' reports the end of the
-- input, EOF is passed on to the inner iteratee.

netLines :: forall b m. Monad m => Int -> Enumeratee ByteString ByteString m b
netLines maxLen = step
    where
    step :: Enumeratee ByteString ByteString m b
    step (Continue k) =
        evalMaybeT (netLine maxLen) >>= \mLine ->
            k (maybe EOF (\line -> Chunks [line]) mLine) >>== step
    step done = return done


-- | Read a three digit SMTP response code from the start of a line.
-- Returns 'Nothing' unless the line has at least three bytes and the
-- first three of them are all ASCII digits.

readRespCode :: ByteString -> Maybe Integer
readRespCode str
    | B.length prefix == 3 && B.all isAsciiDigit prefix =
        Just (B.foldl' addDigit 0 prefix)
    | otherwise = Nothing

    where
    prefix :: ByteString
    prefix = B.take 3 str

    -- Both bounds are checked on the raw byte.  Checking only an upper
    -- bound after subtracting 48 would let bytes below '0' through as
    -- negative digits.
    isAsciiDigit :: Word8 -> Bool
    isAsciiDigit w = w >= 48 && w <= 57

    addDigit :: Integer -> Word8 -> Integer
    addDigit acc w = 10 * acc + fromIntegral (w - 48)


-- | Determine whether the given SMTP response line announces further
-- lines.  The decision is made on the fourth byte of the line.
-- Returns 'Nothing' if the line is shorter than four bytes or that
-- byte is neither a dash nor a space.

readRespMore :: ByteString -> Maybe Bool
readRespMore str
    | B.length str < 4 = Nothing
    | otherwise =
        case B.index str 3 of
          45 -> Just True   -- '-'
          32 -> Just False  -- ' '
          _  -> Nothing


-- | Composition of the 'Enumeratee's needed to convert a raw
-- 'ByteString' stream to an 'SmtpResponse' stream: 'netLines' on the
-- outside, 'smtpResponses' on the inside.  The first parameter is the
-- maximum line length, the second is the response line limit.

responseLines :: Monad m =>
                 Int -> Int -> Iteratee SmtpResponse m b -> Iteratee ByteString m b
responseLines maxLine maxMsgs c =
    joinI (netLines maxLine $$ fromLines)

    where
    -- The consumer, adapted to be fed with individual lines.
    fromLines = joinI (smtpResponses maxMsgs $$ c)


-- | Read the next SMTP response line from a 'ByteString' stream of
-- lines (i.e. a 'ByteString' stream converted by 'netLines').  Returns
-- 'Nothing' on EOF.  Returns @Left line@ if the next line is not a
-- proper SMTP response, otherwise @Right (code, more, msg)@.  A line
-- consisting of just the three digit code counts as a final line with
-- an empty message.

smtpResponseLine ::
    Monad m =>
    MaybeT r (Iteratee ByteString m) (Either ByteString (Integer, Bool, ByteString))
smtpResponseLine = fmap classify (liftF EL.head)
    where
    classify :: ByteString -> Either ByteString (Integer, Bool, ByteString)
    classify line = maybe (Left line) Right (parse line)

    parse :: ByteString -> Maybe (Integer, Bool, ByteString)
    parse line
        | B.length line < 3  = Nothing
        | B.length line == 3 =
            fmap (\code -> (code, False, B.empty)) (readRespCode line)
        | otherwise = do
            code <- readRespCode line
            more <- readRespMore line
            return (code, more, B.drop 4 line)


-- | Read the next SMTP response from a 'ByteString' stream that has
-- been split by 'netLines'.  Lines are collected for as long as they
-- are marked as having a continuation.  Throws an error if a line is
-- not a proper response line or if the response code changes within
-- one response.  Keeps at most the given number of response messages.

smtpResponse :: forall m r. Monad m =>
                Int -> MaybeT r (Iteratee ByteString m) SmtpResponse
smtpResponse maxMsgs =
    gather Nothing V.empty

    where
    protocolError :: MaybeT r (Iteratee ByteString m) a
    protocolError = lift $ throwError (userError "Invalid SMTP response")

    -- First argument: the code of the preceding lines, if there were
    -- any.  Second argument: the messages collected so far.
    gather ::
        Maybe Integer -> Vector ByteString ->
        MaybeT r (Iteratee ByteString m) SmtpResponse
    gather expected acc = do
        (code, more, msg) <-
            smtpResponseLine >>= either (const protocolError) return
        unless (maybe True (== code) expected) protocolError

        let acc' = V.take maxMsgs (V.snoc acc msg)
        acc' `seq`
            if more
              then gather (Just code) acc'
              else return (SmtpResponse code acc')


-- | Convert a stream of 'ByteString' lines, as produced by 'netLines',
-- to a stream of SMTP responses.  Errors thrown by 'smtpResponse' on
-- protocol violations abort the enumeration.  When 'smtpResponse'
-- reports the end of the input, EOF is passed on to the inner iteratee.

smtpResponses :: forall b m. Monad m => Int -> Enumeratee ByteString SmtpResponse m b
smtpResponses maxMsgs =
    step

    where
    step :: Enumeratee ByteString SmtpResponse m b
    step (Continue k) =
        evalMaybeT (smtpResponse maxMsgs) >>= \mResp ->
            k (maybe EOF (\resp -> Chunks [resp]) mResp) >>== step
    step done = return done


-- | Convert an extension string to the corresponding 'Extension' value,
-- if the extension is known.  No extension strings are recognised at
-- present, so the result is always 'Nothing'.

stringToExtension :: ByteString -> Maybe Extension
stringToExtension = const Nothing