{-# LANGUAGE DeriveDataTypeable #-}
module Data.Serialize.Get.Enumerator
    ( ParseError(..)
    , iterGet
    ) where

import Prelude as P
import Control.Exception
import Data.Monoid
import Data.Typeable

import Data.Serialize.Get
import Data.Enumerator
import Data.ByteString as BS

-- | Errors an 'Iteratee' built by 'iterGet' can fail with.
data ParseError
    = ParseError String
      -- ^ The parser reported a failure; the 'String' is the parser's
      -- own failure message.
    | EOFError
      -- ^ The input stream ended while the parser was still asking for
      -- more input.
    deriving (Typeable, Show)

-- | Uses the default 'Exception' methods, so a 'ParseError' can be
-- thrown and caught like any other exception.
instance Exception ParseError

-- | Convert a 'Get' to an 'Iteratee'.
--
-- Input chunks are fed to the parser one at a time. Once the parser
-- finishes, its result is yielded together with the leftover input: the
-- unconsumed remainder of the chunk that completed the parse (omitted
-- when it is empty) followed by every chunk that was not fed to the
-- parser.
--
-- The resulting 'Iteratee' fails with 'ParseError' when the parser
-- reports a failure, and with 'EOFError' when the stream ends while the
-- parser is still asking for more input.
iterGet :: Monad m => Get a -> Iteratee ByteString m a
iterGet = continue . step . runGetPartial
  where
    -- Handle one stream element with the current parser continuation.
    step k (Chunks cs) = feed k cs
    -- At end of stream the continuation is handed an empty string; a
    -- parser that still wants input after that cannot complete.
    step k EOF = case k mempty of
        Done r _  -> yield r EOF
        Partial{} -> throwError EOFError
        Fail msg  -> throwError (ParseError msg)

    -- Feed the chunks to the parser in order; when they run out, ask the
    -- enumerator for more.
    feed k [] = continue (step k)
    feed k (c:cs)
        -- Empty chunks are never passed to the continuation.
        -- NOTE(review): presumably because the parser treats empty input
        -- as end of input; confirm against the cereal version in use.
        | BS.null c = feed k cs
        | otherwise = case k c of
            Done r rest -> yield r (Chunks (leftover rest cs))
            Partial k'  -> feed k' cs
            Fail msg    -> throwError (ParseError msg)

    -- Build the leftover stream without an empty leading chunk, so that
    -- a parse that consumed its whole last chunk does not hand an empty
    -- 'ByteString' to whatever iteratee runs next.
    leftover rest cs
        | BS.null rest = cs
        | otherwise    = rest : cs