{-# LANGUAGE DeriveDataTypeable #-}
module Data.Serialize.Get.Enumerator
    ( ParseError(..)
    , iterGet
    ) where

import Prelude as P
import Control.Exception
import Data.Monoid
import Data.Typeable

import Data.Serialize
import Data.Enumerator
import Data.ByteString as BS

-- | Error raised (via 'throwError') by the 'Iteratee' built with 'iterGet'.
--
-- It is raised when the wrapped decoder reports a failure, or when the
-- decoder is still asking for input after end of stream. The 'String' is
-- the decoder's own failure message. In the latter case it is
-- @"iterGet: divergent parser"@.
newtype ParseError = ParseError String
                   deriving (Show, Typeable)

-- | Allows a 'ParseError' to be thrown and caught as an ordinary exception.
instance Exception ParseError

-- | Run a cereal 'Get' decoder as an 'Iteratee' over strict 'ByteString'
-- chunks.
--
-- Chunks are handed to the decoder one at a time. If the decoder finishes
-- partway through the input, the bytes it did not consume, plus any chunks
-- not yet fed, are returned as leftover input.
--
-- At 'EOF' the decoder is sent an empty chunk to tell it no more input is
-- coming. A successful result is then yielded with 'EOF' as the remainder,
-- so any bytes the decoder reports as unconsumed at that point are dropped.
--
-- The resulting 'Iteratee' fails with a 'ParseError' if the decoder fails,
-- or if it still asks for input after end of stream.
iterGet :: Monad m => Get a -> Iteratee ByteString m a
iterGet getter = continue (onInput (runGetPartial getter))
  where

    -- React to one stream element, given the decoder's current
    -- continuation.
    onInput resume stream = case stream of
        -- An empty chunk is what the EOF branch uses to signal end of
        -- input, so genuine empty chunks are discarded before decoding.
        Chunks chunks -> consumeChunks resume (P.filter nonEmpty chunks)
        EOF           -> finish (signalEnd (resume mempty))

    -- Feed non-empty chunks in order until the decoder stops. If the batch
    -- runs out first, wait for the next stream element.
    consumeChunks resume [] = continue (onInput resume)
    consumeChunks resume (chunk : rest) = case resume chunk of
        Partial next    -> consumeChunks next rest
        Done value left -> yield value (Chunks (leftover left rest))
        Fail msg        -> failWith msg

    -- Interpret the decoder's final answer once the stream has ended.
    finish result = case result of
        Done value _ -> yield value EOF
        Partial _    -> failWith "iterGet: divergent parser"
        Fail msg     -> failWith msg

    -- Put unconsumed bytes back in front of the untouched chunks,
    -- without introducing an empty chunk.
    leftover left rest
        | BS.null left = rest
        | otherwise    = left : rest

    nonEmpty = not . BS.null

    failWith = throwError . ParseError

    -- Give a decoder that is still waiting one more empty chunk.
    -- NOTE(review): presumably needed because a fresh 'runGetPartial'
    -- decoder that has seen no chunks yet does not treat its first empty
    -- input as end of input. Confirm against cereal.
    signalEnd (Partial next) = next mempty
    signalEnd other          = other