{-# LANGUAGE DeriveDataTypeable #-}
module Data.Serialize.Get.Enumerator
    ( ParseError(..)
    , iterGet
    ) where

import Prelude as P
import Control.Exception
import Data.Monoid
import Data.Typeable

import Data.Serialize
import Data.Enumerator
import Data.ByteString as BS

-- | Exception raised (through the enumerator 'throwError') by 'iterGet'.
-- This happens when the wrapped 'Get' fails, or when it is still asking
-- for input after end of stream.
--
-- The 'String' carries either the parser's own failure message or
-- @"iterGet: divergent parser"@.
--
-- A 'newtype' is used because there is exactly one constructor with one
-- field. It has the same constructor, 'Show' output and 'Exception'
-- behaviour as a one-field 'data', without the extra box.
newtype ParseError = ParseError String
                   deriving (Show, Typeable)

instance Exception ParseError

-- | Run a cereal 'Get' as an enumerator 'Iteratee'. If parsing fails, the
-- iteratee fails with a 'ParseError'.
--
-- How input is fed:
--
-- * Each non-empty chunk is handed to the parser in order.
-- * Empty chunks are discarded first. The code uses empty input as its
--   end-of-input signal (see the 'EOF' case).
-- * On success, any unconsumed bytes go back in front of the remaining
--   chunks for whoever runs next.
-- * On 'EOF' the parser is fed empty input. If it is still 'Partial'
--   after that, it gets empty input once more before being reported as
--   divergent.
-- * Leftover bytes from a parse that completes at 'EOF' are not passed on.
iterGet :: Monad m => Get a -> Iteratee ByteString m a
iterGet getter = continue (onStream (runGetPartial getter))
  where
    -- React to the next piece of the stream with the current parser state.
    onStream parser stream = case stream of
        Chunks chunks -> feedChunks parser (P.filter (not . BS.null) chunks)
        EOF           -> finish (feed mempty (parser mempty))

    -- Turn the parser's final state at end of stream into a result.
    finish result = case result of
        Fail msg     -> parseError msg
        Partial{}    -> parseError "iterGet: divergent parser"
        Done value _ -> yield value EOF

    -- Push chunks through the parser until it finishes or needs more data.
    feedChunks parser []             = continue (onStream parser)
    feedChunks parser (chunk : rest) = case parser chunk of
        Fail msg        -> parseError msg
        Partial parser' -> feedChunks parser' rest
        Done value leftover
            | BS.null leftover -> yield value (Chunks rest)
            | otherwise        -> yield value (Chunks (leftover : rest))

    parseError msg = throwError (ParseError msg)

-- | Hand one more chunk of input to a parser result.
--
-- Only a 'Partial' result is advanced: its continuation is applied to the
-- chunk. 'Done' and 'Fail' are final and are returned unchanged.
feed :: ByteString -> Result a -> Result a
feed input result = case result of
    Partial resume -> resume input
    _              -> result