-- |
-- Module:     Data.Enumerator.NetLines
-- Copyright:  (c) 2010 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez <es@ertes.de>
-- Stability:  beta
--
-- Enumerator tools for working with text-based network protocols.

{-# LANGUAGE ScopedTypeVariables #-}

module Data.Enumerator.NetLines
    ( -- * Iteratees
      netLine,
      netLineEmpty,
      netWord,
      netWordEmpty,

      -- * Enumeratees
      netLines,
      netLinesEmpty,
      netWords,
      netWordsEmpty,

      -- * General stream splitters
      netSplitBy,
      netSplitsBy,

      -- * Enumerators
      enumHandleTimeout
    )
    where

import qualified Data.ByteString as B
import Control.Arrow
import Control.ContStuff as Monad
import Data.ByteString (ByteString)
import Data.Enumerator as E
import Data.Word
import System.IO
import System.IO.Error as IOErr


-- | Enumerate chunks read from a handle.  The first argument is the
-- buffer size handed to 'B.hGetNonBlocking', the second is the timeout
-- (in milliseconds) handed to 'hWaitForInput' before every read.
--
-- If a wait times out, a 'userError' carrying the message
-- @Handle timed out@ is thrown via 'throwError'.  Any other I/O error
-- raised while waiting or reading is rethrown the same way, except for
-- an EOF error while waiting: that one, like an empty read, simply ends
-- the enumeration by handing back the still-hungry iteratee (no 'EOF'
-- is sent to it).

enumHandleTimeout :: forall b m. MonadIO m =>
                     Int -> Int -> Handle -> Enumerator ByteString m b
enumHandleTimeout bufSize timeout h = go
    where
    -- Wait for input, then either read a chunk or stop.
    go :: Enumerator ByteString m b
    go (Continue feed) = do
        waited <- liftIO $ IOErr.try (hWaitForInput h timeout)
        case waited of
          Right True  -> readChunk feed
          Right False -> throwError $ userError "Handle timed out"
          Left ioErr
              | isEOFError ioErr -> continue feed
              | otherwise        -> throwError ioErr
    -- The iteratee is already finished (or failed): nothing to feed.
    go finished = returnI finished

    -- Read whatever is available and pass it on; an empty read ends
    -- the enumeration without feeding anything.
    readChunk ::
        (Stream ByteString -> Iteratee ByteString m b) ->
        Iteratee ByteString m b
    readChunk feed = do
        readResult <- liftIO $ IOErr.try (B.hGetNonBlocking h bufSize)
        case readResult of
          Left ioErr -> throwError ioErr
          Right bytes
              | B.null bytes -> continue feed
              | otherwise    -> feed (Chunks [bytes]) >>== go


-- | Predicate for ASCII whitespace: the space character (32) and the
-- control characters 9 through 13 (HT, LF, VT, FF and CR).

isSpace :: Word8 -> Bool
isSpace byte
    | byte == 32 = True
    | otherwise  = 9 <= byte && byte <= 13


-- | Get the next nonempty line from the stream.  Lines are read with
-- 'netLineEmpty' (using the given length limit) and empty ones are
-- skipped by 'nonEmpty'.

netLine :: forall m. Monad m => Int -> Iteratee ByteString m (Maybe ByteString)
netLine maxLen = nonEmpty (netLineEmpty maxLen)

-- | Get the next (possibly empty) line from the stream, length-limited
-- by the given 'Int'.  Built on 'netSplitBy': LF (byte 10) terminates a
-- line and every CR (byte 13) is dropped, which makes the iteratee
-- tolerant of both LF and CRLF line endings.

netLineEmpty :: Monad m => Int -> Iteratee ByteString m (Maybe ByteString)
netLineEmpty maxLen = netSplitBy isLF notCR maxLen
    where
    -- Line terminator.
    isLF :: Word8 -> Bool
    isLF byte = byte == 10

    -- Keep everything except carriage returns.
    notCR :: Word8 -> Bool
    notCR byte = byte /= 13


-- | Turn a raw byte stream into a stream of nonempty lines.  Each line
-- is obtained with 'netLine' using the given length limit and fed to
-- the inner iteratee by 'netSplitsBy'.

netLines :: Monad m => Int -> Enumeratee ByteString ByteString m b
netLines maxLen = netSplitsBy (netLine maxLen)


-- | Turn a raw byte stream into a stream of lines, empty lines
-- included.  Each line is obtained with 'netLineEmpty' using the given
-- length limit and fed to the inner iteratee by 'netSplitsBy'.

netLinesEmpty :: Monad m => Int -> Enumeratee ByteString ByteString m b
netLinesEmpty maxLen = netSplitsBy (netLineEmpty maxLen)


-- | Get the next token, where tokens are split by the first given
-- predicate and filtered by the second.  Tokens are length-limited by
-- the given 'Int' and are truncated safely in constant space.
--
-- Details:
--
-- * The byte matching the split predicate ends the token and is
--   dropped; the input following it is left over for the next iteratee.
--
-- * From every chunk at most the remaining budget of raw bytes is
--   taken, then the filter predicate is applied, and the budget is
--   reduced by the number of bytes kept.  Once the budget is used up
--   the rest of the token is skipped.
--
-- * At end of stream the result is 'Just' the bytes collected so far
--   if a partial token is pending and 'Nothing' otherwise (with a limit
--   of 0 no partial token is ever pending).  In both cases 'EOF' is
--   yielded as left-over input.

netSplitBy ::
    forall m. Monad m =>
    (Word8 -> Bool) ->
    (Word8 -> Bool) ->
    Int ->
    Iteratee ByteString m (Maybe ByteString)
netSplitBy breakP filterP n =
    continue (loop n)

    where
    -- @loop budget@ handles one stream element; @budget@ is the number
    -- of bytes that may still be added to the current token.
    loop :: Int -> Stream ByteString -> Iteratee ByteString m (Maybe ByteString)
    -- Yield 'EOF' as the left-over instead of using 'return' (which
    -- leaves @Chunks []@ behind), so that iteratees sequenced after
    -- this one still see the end of the stream.
    loop _ EOF = yield Nothing EOF
    loop budget (Chunks strs)
        -- Nothing useful arrived (also covers @Chunks []@): ask again.
        | B.null str = continue (loop budget)
        -- Separator found: finish the token, drop the separator byte
        -- and hand back the remaining input.  @rest@ is nonempty here,
        -- so 'B.tail' is safe.
        | not (B.null rest) = yield (Just pfx) (Chunks [B.tail rest])
        -- No separator yet and no budget left: keep skipping input.
        | budget == 0 = continue (loop 0)
        -- No separator yet: remember this piece (forced, so the chunk
        -- is not retained) and prepend it to whatever follows.  If the
        -- stream ends first, the piece collected so far is the token.
        | otherwise =
            pfx `seq`
            continue (loop (budget - B.length pfx) >=>
                      maybe (yield (Just pfx) EOF)
                            (return . Just . B.append pfx))

        where
        str          = B.concat strs
        (line, rest) = B.break breakP str
        pfx          = B.filter filterP (B.take budget line)


-- | Split the stream using the supplied iteratee.  As long as the inner
-- iteratee wants input, the supplied iteratee is run to get the next
-- piece: a 'Just' result is fed to the inner iteratee as a single
-- chunk, a 'Nothing' result is fed to it as 'EOF'.  Once the inner
-- iteratee stops asking for input, its step is returned.

netSplitsBy ::
    forall b m. Monad m =>
    Iteratee ByteString m (Maybe ByteString) -> Enumeratee ByteString ByteString m b
netSplitsBy nextPiece = go
    where
    go :: Enumeratee ByteString ByteString m b
    go (Continue feed) =
        nextPiece >>= \piece ->
            feed (maybe EOF (\bytes -> Chunks [bytes]) piece) >>== go
    go finished = return finished


-- | Get the next nonempty word from the stream with the given maximum
-- length.  Words are read with 'netWordEmpty' and empty ones are
-- skipped by 'nonEmpty'.

netWord :: Monad m => Int -> Iteratee ByteString m (Maybe ByteString)
netWord maxLen = nonEmpty (netWordEmpty maxLen)


-- | Get the next (possibly empty) word from the stream with the given
-- maximum length.  Built on 'netSplitBy': any ASCII whitespace byte (see
-- 'isSpace') ends a word and no bytes are filtered out.

netWordEmpty :: Monad m => Int -> Iteratee ByteString m (Maybe ByteString)
netWordEmpty maxLen = netSplitBy isSpace keepAll maxLen
    where
    -- Words are passed through unfiltered.
    keepAll :: Word8 -> Bool
    keepAll _ = True


-- | Split the raw byte stream into nonempty words.  Each word is
-- obtained with 'netWord' using the given length limit and fed to the
-- inner iteratee by 'netSplitsBy'.

netWords :: Monad m => Int -> Enumeratee ByteString ByteString m b
netWords maxLen = netSplitsBy (netWord maxLen)


-- | Split the raw byte stream into words, empty words included.  Each
-- word is obtained with 'netWordEmpty' using the given length limit and
-- fed to the inner iteratee by 'netSplitsBy'.

netWordsEmpty :: Monad m => Int -> Enumeratee ByteString ByteString m b
netWordsEmpty maxLen = netSplitsBy (netWordEmpty maxLen)


-- | Run the given iteratee repeatedly until it produces a nonempty
-- string, which becomes the result.
--
-- NOTE(review): a 'Nothing' from the iteratee presumably ends the loop
-- with a 'Nothing' result through 'liftF' / 'evalMaybeT' -- confirm
-- against the contstuff documentation.

nonEmpty ::
    forall a m. Monad m =>
    Iteratee a m (Maybe ByteString) -> Iteratee a m (Maybe ByteString)
nonEmpty getStr = evalMaybeT retry
    where
    -- Fetch one string; try again for as long as it is empty.
    retry :: MaybeT r (Iteratee a m) ByteString
    retry =
        liftF getStr >>= \candidate ->
            if B.null candidate
              then retry
              else return candidate