module Data.Conduit.Cereal.Internal
( ConduitErrorHandler
, SinkErrorHandler
, SinkTerminationHandler
, mkConduitGet
, mkSinkGet
) where
import Control.Monad (forever, when)
import qualified Data.ByteString as BS
import Data.Conduit (ConduitT, await, leftover, yield)
import Data.Serialize hiding (get, put)
-- | Run by 'mkConduitGet' when decoding fails. It receives the failure
-- message and takes over as the final action of the conduit.
type ConduitErrorHandler m out =
  String -> ConduitT BS.ByteString out m ()

-- | Run by 'mkSinkGet' when decoding fails. It receives the failure message
-- and its result becomes the sink's result. It must be usable at any
-- downstream output type.
type SinkErrorHandler m res =
  forall downstream. String -> ConduitT BS.ByteString downstream m res

-- | Run by 'mkSinkGet' when upstream runs out of input before decoding
-- finishes. It receives the decoder's pending continuation and its result
-- becomes the sink's result. It must be usable at any downstream output type.
type SinkTerminationHandler m res =
  forall downstream.
  (BS.ByteString -> Result res) -> ConduitT BS.ByteString downstream m res
-- | Build a conduit that repeatedly runs a 'Get' over the incoming
-- 'BS.ByteString' stream and yields every decoded value downstream.
--
-- * On a decoding failure, all input fed to the failed attempt (including
--   the chunk that triggered the failure) is pushed back with 'leftover',
--   and then @errorHandler@ runs with the failure message.
-- * When upstream ends part-way through a value, the input buffered for that
--   value is pushed back with 'leftover' and the conduit finishes. The error
--   handler is not called in this case.
-- * If the 'Get' succeeds before it has been fed any input, its result is
--   yielded forever without awaiting.
--
-- An empty 'leftover' is never emitted.
mkConduitGet :: Monad m
             => ConduitErrorHandler m o
             -> Get o
             -> ConduitT BS.ByteString o m ()
mkConduitGet errorHandler get = consume True (runGetPartial get) [] BS.empty
  where
    -- Feed @s@ to the decoder continuation @f@. If @s@ is empty, await a new
    -- chunk instead, so empty chunks are never fed here. @b@ holds the
    -- chunks already fed to the current attempt, newest first.
    pull f b s
      | BS.null s = await >>= maybe (pushBack b) (pull f b)
      | otherwise = consume False f b s

    -- Run one decoder step on @s@. @initial@ is True only for the very first
    -- step, which is fed an empty chunk.
    consume initial f b s =
      case f s of
        Fail msg _ -> do
          -- Restore everything this attempt consumed, including @s@ itself.
          -- Previously @s@ was dropped when it was the attempt's only chunk.
          pushBack consumed
          errorHandler msg
        Partial p -> pull p consumed BS.empty
        Done a s'
          -- The Get succeeded without being fed any input, so it would never
          -- consume input on later runs either. Yield its value forever
          -- instead of awaiting. This assumes the Get either always or never
          -- consumes input.
          | initial -> forever $ yield a
          | otherwise -> yield a >> pull (runGetPartial get) [] s'
      where
        consumed = s : b

    -- Push the given chunks (newest first) back upstream as one leftover,
    -- emitting nothing when they contain no bytes.
    pushBack chunks =
      let bytes = BS.concat (reverse chunks)
      in when (not $ BS.null bytes) (leftover bytes)
-- | Build a sink that runs a 'Get' once over the incoming 'BS.ByteString'
-- stream and returns its result.
--
-- * On success, any input the 'Get' did not consume is pushed back with
--   'leftover'.
-- * On a decoding failure, all input fed to the attempt (including the chunk
--   that triggered the failure) is pushed back with 'leftover', and then
--   @errorHandler@ produces the result.
-- * When upstream ends before the 'Get' finishes, the buffered input is
--   pushed back with 'leftover'. Then @terminationHandler@ receives the
--   decoder's pending continuation and produces the result.
--
-- An empty 'leftover' is never emitted.
mkSinkGet :: Monad m
          => SinkErrorHandler m r
          -> SinkTerminationHandler m r
          -> Get r
          -> ConduitT BS.ByteString o m r
mkSinkGet errorHandler terminationHandler get = consume (runGetPartial get) [] BS.empty
  where
    -- Feed @s@ to the decoder continuation @f@. If @s@ is empty, await a new
    -- chunk instead, so empty chunks are never fed here. @b@ holds the
    -- chunks already fed to the decoder, newest first.
    pull f b s
      | BS.null s =
          await >>= maybe (pushBack b >> terminationHandler f) (pull f b)
      | otherwise = consume f b s

    -- Run one decoder step on @s@.
    consume f b s =
      case f s of
        -- Restore everything consumed, including @s@ itself. Previously @s@
        -- was dropped when it was the only chunk fed so far.
        Fail msg _ -> pushBack consumed >> errorHandler msg
        Partial p -> pull p consumed BS.empty
        Done result rest -> pushBack [rest] >> return result
      where
        consumed = s : b

    -- Push the given chunks (newest first) back upstream as one leftover,
    -- emitting nothing when they contain no bytes.
    pushBack chunks =
      let bytes = BS.concat (reverse chunks)
      in when (not $ BS.null bytes) (leftover bytes)