{-# LANGUAGE Rank2Types #-}

module Data.Enumerator
    ( -- Enumerators
      bytesEnum,
      chunkEnum,
      partialSocketEnum,
      socketEnum,

      -- Combining enumerators
      compose
    ) where

import Control.Monad (liftM)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as C (unpack)
import Data.Word (Word8)
import Network.Socket (Socket)
import Network.Socket.ByteString (recv)
import Numeric (readHex)

-- | An iteratee: given the current seed and the next block of input it
-- returns @Right seed'@ to ask for more input or @Left seed'@ to stop.
type IterateeM s m = s -> S.ByteString -> m (Either s s)

-- | An enumerator: feeds blocks of input to any iteratee, starting from
-- the given seed, and returns the final seed.
type EnumeratorM m = forall s. IterateeM s m -> s -> m s

-- -----------------------------------------------------------
-- Enumerators

-- | Enumerates a 'ByteString': the whole string is handed to the
-- iteratee in a single step, and the seed it returns is the result
-- whether the iteratee asked to continue ('Right') or to stop ('Left').
bytesEnum :: Monad m => S.ByteString -> EnumeratorM m
bytesEnum bs f seed = either id id `liftM` f seed bs


-- | The ASCII line-feed byte (newline), used to find the end of chunk
-- header lines.
nl :: Word8
nl = 0x0A

-- | Enumerates chunks of data encoded using HTTP chunked encoding.
--
-- @chunkEnum enum@ reads the encoded stream produced by @enum@ and feeds
-- the iteratee only the decoded chunk data.  Each chunk starts with a
-- header line holding its size in hexadecimal (parsed by 'pHeader'); the
-- line break that follows each chunk's data is skipped.  A chunk of size
-- zero marks the end of the body: enumeration stops there, and anything
-- after it (e.g. trailers) is not examined by this function.
--
-- Calls 'error' on a header line that 'pHeader' cannot parse or that
-- yields a negative size.
chunkEnum :: Monad m => EnumeratorM m -> EnumeratorM m
chunkEnum enum f initSeed = fst `liftM` enum go (initSeed, Left S.empty)
    where
      -- The inner seed pairs the iteratee's seed with the parser state:
      -- @Left acc@ while reading a header line (@acc@ holds the part of
      -- the line seen in earlier blocks) and @Right n@ while @n@ bytes of
      -- chunk data are still expected.
      go (seed, Left acc) bs =
        case S.elemIndex nl bs of
          Nothing -> return $ Right (seed, Left (S.append acc bs))
          Just ix ->
            let (line, rest) = S.splitAt (ix + 1) bs
                hdr          = S.append acc line
            in if isBlank hdr
                 -- The CRLF terminating the previous chunk's data.
                 then go (seed, Left S.empty) rest
                 else case pHeader hdr of
                        Just n | n > 0 -> go (seed, Right n) rest
                        -- Last chunk: tell the underlying enumerator to stop.
                        Just 0         -> return $ Left (seed, Left rest)
                        _              -> error $ "malformed header: " ++ show hdr
      go (seed, Right n) bs
        -- Nothing to feed (e.g. a header ended exactly at a block
        -- boundary); wait for more input rather than passing empty data.
        | S.null bs = return $ Right (seed, Right n)
        | len < n   = do
            r <- f seed bs
            return $ case r of
              Right seed' -> Right (seed', Right $! n - len)
              Left  seed' -> Left  (seed', Right $! n - len)
        | otherwise = do
            let (bs', rest) = S.splitAt n bs
            r <- f seed bs'
            case r of
              Right seed' -> go (seed', Left S.empty) rest
              Left  seed' -> return $ Left (seed', Left rest)
        where
          len = S.length bs

      -- A line made up only of CR/LF bytes.
      isBlank = S.all (\w -> w == nl || w == 13)

-- | Parses an HTTP chunk-size line (the hexadecimal size followed by a
-- line break) into the chunk length.  Any chunk extension (everything
-- from the first @;@ on), the line terminator (LF or CRLF) and
-- surrounding spaces or tabs are ignored.  Returns 'Nothing' unless what
-- remains is a hexadecimal number that fits in an 'Int'.
pHeader :: S.ByteString -> Maybe Int
pHeader bs =
    case readHex (C.unpack sizeField) :: [(Integer, String)] of
      [(n, "")] | n <= toInteger (maxBound :: Int) -> Just (fromInteger n)
      _                                            -> Nothing
    where
      -- Chunk extensions start at ';' (byte 59); drop them.
      sizeField = trim (S.takeWhile (/= 59) bs)
      -- Strip surrounding whitespace, including the CR/LF terminator.
      trim      = fst . S.spanEnd isWs . S.dropWhile isWs
      isWs w    = w == 32 || w == 9 || w == 13 || w == 10

-- | Upper bound, in bytes, on how much a single socket send or receive
-- operation handles (4 KiB).
blockSize :: Int
blockSize = 4096

-- | @partialSocketEnum sock numBytes@ enumerates at most @numBytes@ bytes
-- received through the given socket.  Does not close the socket.
--
-- Each receive asks for no more than the bytes still outstanding (capped
-- at 'blockSize'), so data beyond @numBytes@ is not requested from the
-- socket.  Enumeration ends when @numBytes@ bytes have been delivered,
-- when a receive returns no data (peer closed its end), or when the
-- iteratee returns 'Left'.  A non-positive @numBytes@ delivers nothing.
partialSocketEnum :: Socket -> Int -> EnumeratorM IO
partialSocketEnum sock numBytes f initSeed = go initSeed numBytes
  where
    go seed n
      -- '<=' rather than '== 0': never loop on a non-positive count.
      | n <= 0    = return seed
      | otherwise = do
          bs <- recv sock (min blockSize n)
          if S.null bs
            then return seed
            else do
              seed' <- f seed bs
              case seed' of
                Right seed'' -> go seed'' $! n - S.length bs
                Left seed''  -> return seed''

-- | Enumerates data received through the given socket, reading up to
-- 'blockSize' bytes at a time.  Stops when a receive returns no data
-- (peer closed its end) or when the iteratee returns 'Left'.  Does not
-- close the socket.
socketEnum :: Socket -> EnumeratorM IO
socketEnum sock f initSeed = loop initSeed
  where
    loop acc = recv sock blockSize >>= step acc
    step acc block
      | S.null block = return acc
      | otherwise    = f acc block >>= either return loop

-- -----------------------------------------------------------
-- Combining enumerators

-- | Makes two enumerators behave like one: the iteratee first receives
-- the input produced by @enum1@ and then, unless it asked to stop
-- ('Left') along the way, the input produced by @enum2@.
compose :: Monad m => EnumeratorM m -> EnumeratorM m -> EnumeratorM m
compose enum1 enum2 f initSeed = enum1 step (Right initSeed) >>= finish
    where
      -- The wrapped seed is 'Right' while the iteratee wants more input
      -- and 'Left' once it has stopped.
      step (Left done) _ = return (Left (Left done))  -- Cannot happen.
      step (Right s) bs  = liftM (either (Left . Left) (Right . Right)) (f s bs)
      -- Only hand over to enum2 if the iteratee never stopped.
      finish = either return (enum2 f)