-- |
-- Module:     Data.Enumerator.NetLines
-- Copyright:  (c) 2010 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez <es@ertes.de>
-- Stability:  beta
--
-- Enumerator tools for working with text-based network protocols.

{-# LANGUAGE DeriveDataTypeable, ScopedTypeVariables #-}

module Data.Enumerator.NetLines
    ( -- * Iteratees
      netLine,
      netLineEmpty,
      netWord,
      netWordEmpty,

      -- * Enumeratees
      netLines,
      netLinesEmpty,
      netWords,
      netWordsEmpty,

      -- * General stream splitters
      netSplitBy,
      netSplitsBy,

      -- * Enumerators
      TimeoutError(..),
      enumHandleSession,
      enumHandleTimeout
    )
    where

import qualified Data.ByteString as B
import Control.ContStuff as Monad
import Control.Exception as Ex
import Data.ByteString (ByteString)
import Data.Enumerator as E
import Data.Time.Clock
import Data.Typeable
import Data.Word
import System.IO
import System.IO.Error as IOErr


-- | Exception signalling that an IO operation did not complete in
-- time.  The wrapped string describes the operation that timed out.

newtype TimeoutError = TimeoutError { timeoutErrorMessage :: String }
    deriving (Typeable)

instance Show TimeoutError where
    show = ("Operation timed out: " ++) . timeoutErrorMessage

instance Ex.Exception TimeoutError


-- | Enumerate the contents of a handle.  The arguments are, in order:
-- the buffer size for each read, the read timeout in milliseconds, the
-- session timeout in milliseconds and the handle to read from.
--
-- Before each read the time elapsed since the enumerator was started
-- is subtracted from the session timeout; the wait is limited to the
-- smaller of that remainder and the read timeout.  If that limit is not
-- positive, or no input arrives within it, a 'TimeoutError' is thrown
-- via 'throwError'.  Other IO errors are rethrown via 'throwError' as
-- well, except for EOF, which ends the enumeration with the iteratee
-- still awaiting input.

enumHandleSession ::
    forall b m. MonadIO m =>
    Int -> Int -> Int -> Handle -> Enumerator ByteString m b
enumHandleSession bufSize readTime sessionTime h step =
    liftIO getCurrentTime >>= \began -> feed began step

    where
    timedOut :: TimeoutError
    timedOut = TimeoutError "Reading from handle"

    feed :: UTCTime -> Enumerator ByteString m b
    feed began (Continue k) = do
        now <- liftIO getCurrentTime
        let elapsed = round (1000 * diffUTCTime now began)
            budget  = min (sessionTime - elapsed) readTime
        when (budget <= 0) $ throwError timedOut
        ready <- liftIO $ IOErr.try (hWaitForInput h budget)
        case ready of
          Right True -> do
              got <- liftIO $ IOErr.try (B.hGetNonBlocking h bufSize)
              case got of
                Left err -> throwError err
                Right chunk
                    -- An empty read after a successful wait ends the
                    -- enumeration without feeding anything further.
                    | B.null chunk -> continue k
                    | otherwise    -> k (Chunks [chunk]) >>== feed began
          Right False -> throwError timedOut
          Left err
              | isEOFError err -> continue k
              | otherwise      -> throwError err
    feed _ other = returnI other


-- | Enumerate the contents of a handle.  The arguments are, in order:
-- the buffer size for each read, the timeout in milliseconds and the
-- handle to read from.  If no input arrives within the timeout, a
-- 'TimeoutError' is thrown via 'throwError'.  Other IO errors are
-- rethrown via 'throwError' as well, except for EOF, which ends the
-- enumeration with the iteratee still awaiting input.
--
-- The timeout applies to every single read operation, not to the
-- enumeration as a whole.  It therefore guards against dead or
-- unresponsive peers, but not against peers that (perhaps on purpose)
-- send their data very slowly.

enumHandleTimeout ::
    forall b m. MonadIO m =>
    Int -> Int -> Handle -> Enumerator ByteString m b
enumHandleTimeout bufSize timeout h = pump
    where
    pump :: Enumerator ByteString m b
    pump (Continue k) = do
        ready <- liftIO $ IOErr.try (hWaitForInput h timeout)
        case ready of
          Right True  -> readChunk k
          Right False -> throwError $ TimeoutError "Reading from handle"
          Left err
              | isEOFError err -> continue k
              | otherwise      -> throwError err
    pump other = returnI other

    -- Fetch whatever is available (at most 'bufSize' bytes) and hand it
    -- to the iteratee; an empty read ends the enumeration.
    readChunk ::
        (Stream ByteString -> Iteratee ByteString m b) ->
        Iteratee ByteString m b
    readChunk k = do
        got <- liftIO $ IOErr.try (B.hGetNonBlocking h bufSize)
        case got of
          Left err -> throwError err
          Right chunk
              | B.null chunk -> continue k
              | otherwise    -> k (Chunks [chunk]) >>== pump


-- | Test whether a byte is ASCII whitespace: the space character (32)
-- or one of the control characters 9 through 13 (tab, line feed,
-- vertical tab, form feed, carriage return).

isSpace :: Word8 -> Bool
isSpace w
    | w == 32   = True
    | otherwise = 9 <= w && w <= 13


-- | Get the next nonempty line from the stream, length-limited by the
-- given 'Int'.  Lines are read with 'netLineEmpty'; empty ones are
-- skipped.

netLine :: forall m. Monad m => Int -> Iteratee ByteString m (Maybe ByteString)
netLine maxLen = nonEmpty (netLineEmpty maxLen)


-- | Get the next line from the stream, length-limited by the given
-- 'Int'.  To be tolerant of sloppy peers, a single LF terminates the
-- line and every CR character is silently dropped.

netLineEmpty :: Monad m => Int -> Iteratee ByteString m (Maybe ByteString)
netLineEmpty maxLen = netSplitBy isLF notCR maxLen
    where
    isLF, notCR :: Word8 -> Bool
    isLF  = (10 ==)
    notCR = (13 /=)


-- | Turn a raw byte stream into a stream of lines, each obtained with
-- 'netLine' using the given length limit.

netLines :: Monad m => Int -> Enumeratee ByteString ByteString m b
netLines maxLen = netSplitsBy (netLine maxLen)


-- | Turn a raw byte stream into a stream of lines, each obtained with
-- 'netLineEmpty' using the given length limit.

netLinesEmpty :: Monad m => Int -> Enumeratee ByteString ByteString m b
netLinesEmpty maxLen = netSplitsBy (netLineEmpty maxLen)


-- | Get the next token from the stream.  A token ends at the first
-- byte satisfying the first predicate (that byte is consumed, but not
-- part of the token); of the bytes before it only those satisfying the
-- second predicate are kept.
--
-- Tokens are length-limited by the given 'Int': once that many bytes
-- have been kept, further bytes of the token are discarded, but the
-- stream is still consumed up to and including the terminator, so the
-- following token starts at the right place.  A negative limit is
-- treated as 0.
--
-- At the end of the stream the bytes accumulated so far are returned;
-- if there are none, the result is 'Nothing'.

netSplitBy ::
    forall m. Monad m =>
    (Word8 -> Bool) -> (Word8 -> Bool) -> Int ->
    Iteratee ByteString m (Maybe ByteString)
netSplitBy breakP filterP n =
    continue (loop (max 0 n) B.empty)

    where
    -- The 'Int' is the number of bytes that may still be appended to
    -- the token accumulated so far (the 'ByteString').
    loop :: Int -> ByteString -> Stream ByteString ->
            Iteratee ByteString m (Maybe ByteString)
    loop _ line' EOF = yield (if B.null line' then Nothing else Just line') EOF
    loop room line' (Chunks []) = continue (loop room line')
    loop room line' (Chunks (str:strs)) =
        case B.uncons rest of
          -- No terminator in this chunk: keep accumulating.  Forcing
          -- the counter and the token avoids building up thunks.
          Nothing         -> room' `seq` line `seq` loop room' line (Chunks strs)
          -- Terminator found: drop it and hand back the leftover input.
          Just (_, after) -> yield (Just line) (Chunks (after:strs))

        where
        (before, rest) = B.break breakP str
        -- Only the newly kept bytes are truncated; what has been
        -- accumulated already always fits within the limit.
        kept           = B.take room (B.filter filterP before)
        line           = B.append line' kept
        room'          = room - B.length kept


-- | Split the stream into tokens using the supplied iteratee.  The
-- iteratee is run repeatedly; every @'Just' token@ it returns is passed
-- on as a single chunk, and a 'Nothing' result is passed on as 'EOF'.
-- This goes on for as long as the inner iteratee asks for more input.

netSplitsBy ::
    forall b m. Monad m =>
    Iteratee ByteString m (Maybe ByteString) -> Enumeratee ByteString ByteString m b
netSplitsBy getLine = go
    where
    go :: Enumeratee ByteString ByteString m b
    go (Continue k) =
        getLine >>= \mToken -> k (toStream mToken) >>== go
    go done = return done

    toStream :: Maybe ByteString -> Stream ByteString
    toStream = maybe EOF (\token -> Chunks [token])


-- | Get the next nonempty word from the stream, length-limited by the
-- given 'Int'.  Words are read with 'netWordEmpty'; empty ones are
-- skipped.

netWord :: Monad m => Int -> Iteratee ByteString m (Maybe ByteString)
netWord maxLen = nonEmpty (netWordEmpty maxLen)


-- | Get the next word from the stream, length-limited by the given
-- 'Int'.  To be tolerant of sloppy peers, any single ASCII whitespace
-- byte terminates the word; no bytes are filtered out.

netWordEmpty :: Monad m => Int -> Iteratee ByteString m (Maybe ByteString)
netWordEmpty maxLen = netSplitBy isSpace keepAll maxLen
    where
    keepAll :: Word8 -> Bool
    keepAll _ = True


-- | Turn a raw byte stream into a stream of words, each obtained with
-- 'netWord' using the given length limit.

netWords :: Monad m => Int -> Enumeratee ByteString ByteString m b
netWords maxLen = netSplitsBy (netWord maxLen)


-- | Turn a raw byte stream into a stream of words, each obtained with
-- 'netWordEmpty' using the given length limit.

netWordsEmpty :: Monad m => Int -> Enumeratee ByteString ByteString m b
netWordsEmpty maxLen = netSplitsBy (netWordEmpty maxLen)


-- | Run the given iteratee over and over until it produces a nonempty
-- string, which is then returned.  If the iteratee produces 'Nothing',
-- the result is 'Nothing' as well.

nonEmpty ::
    forall a m. Monad m =>
    Iteratee a m (Maybe ByteString) -> Iteratee a m (Maybe ByteString)
nonEmpty getStr = evalMaybeT retry
    where
    retry :: MaybeT r (Iteratee a m) ByteString
    retry =
        liftF getStr >>= \str ->
            if B.null str then retry else return str