{-# LANGUAGE RankNTypes #-}

module Data.Conduit.Cereal.Internal
  ( ConduitErrorHandler
  , SinkErrorHandler
  , SinkTerminationHandler
  , mkConduitGet
  , mkSinkGet
  ) where

import Control.Monad (forever, when)
import qualified Data.ByteString as BS
import Data.Conduit (ConduitT, await, leftover, yield)
import Data.Serialize hiding (get, put)

-- | What should we do if the Get fails?
--
-- The handler receives the failure message reported by the 'Get' and runs
-- as the final step of the conduit.
type ConduitErrorHandler m o = String -> ConduitT BS.ByteString o m ()

-- | What should we do if the Get fails? (sink variant)
--
-- The handler receives the failure message reported by the 'Get'; it runs in
-- place of the rest of the sink, so its result becomes the sink's result.
type SinkErrorHandler m r = forall o. String -> ConduitT BS.ByteString o m r

-- | What should we do if the stream is done before the Get is done?
--
-- The handler receives the decoder's pending continuation (the function that
-- would have been fed the next chunk of input).
type SinkTerminationHandler m r =
  forall o. (BS.ByteString -> Result r) -> ConduitT BS.ByteString o m r

-- | Construct a conduitGet with the specified 'ConduitErrorHandler'.
--
-- The 'Get' is run repeatedly over the input stream and every decoded value
-- is yielded downstream.
--
-- * If an attempt fails, all input consumed by that attempt is pushed back
--   upstream with 'leftover' and then the error handler is run.
-- * If the stream ends while an attempt still wants more input, the input
--   consumed by that attempt is pushed back and the conduit finishes.
-- * If the 'Get' succeeds on the very first attempt without being given any
--   input, its result is yielded forever.
mkConduitGet :: Monad m
             => ConduitErrorHandler m o
             -> Get o
             -> ConduitT BS.ByteString o m ()
-- Start by running the Get on no input at all, so a Get that never consumes
-- input can be detected (see the 'Done' branch below).
mkConduitGet errorHandler get = consume True (runGetPartial get) [] BS.empty
  where
    -- Feed the next non-empty chunk to the decoder continuation @f@.
    -- @b@ holds every chunk consumed by the current attempt, newest first.
    pull f b s
      | BS.null s = await >>= maybe (pushBackChunks b) (pull f b)
      | otherwise = consume False f b s

    consume initial f b s = case f s of
      Fail msg _ -> do
        -- Checks the bytes, not the chunk list: an attempt that starts
        -- after a previous success begins with @b = []@, and its failing
        -- chunk must still be returned upstream.
        pushBackChunks consumed
        errorHandler msg
      Partial p -> pull p consumed BS.empty
      Done a s' -> case initial of
        -- this only works because the Get will either _always_ consume no
        -- input, or _never_ consume no input.
        True -> forever $ yield a
        False -> yield a >> pull (runGetPartial get) [] s'
      where
        consumed = s : b

-- | Construct a sinkGet with the specified 'SinkErrorHandler' and
-- 'SinkTerminationHandler'.
--
-- The 'Get' is run once. On success, any unconsumed input is pushed back
-- with 'leftover' and the decoded value is returned. On failure, the consumed
-- input is pushed back and the error handler runs. If the stream ends first,
-- the consumed input is pushed back and the termination handler runs with
-- the pending continuation.
mkSinkGet :: Monad m
          => SinkErrorHandler m r
          -> SinkTerminationHandler m r
          -> Get r
          -> ConduitT BS.ByteString o m r
mkSinkGet errorHandler terminationHandler get =
  consume (runGetPartial get) [] BS.empty
  where
    -- Feed the next non-empty chunk to the decoder continuation @f@.
    -- @b@ holds every chunk consumed so far, newest first.
    pull f b s
      | BS.null s =
          await >>= maybe (pushBackChunks b >> terminationHandler f) (pull f b)
      | otherwise = consume f b s

    consume f b s = case f s of
      Fail msg _ -> do
        pushBackChunks consumed
        errorHandler msg
      Partial p -> pull p consumed BS.empty
      Done r s' -> pushBackChunks [s'] >> return r
      where
        consumed = s : b

-- | Push the given chunks (ordered newest first) back upstream as a single
-- 'leftover', skipping the call entirely when they contain no bytes.
pushBackChunks :: [BS.ByteString] -> ConduitT BS.ByteString o m ()
pushBackChunks chunks = when (not (BS.null bytes)) (leftover bytes)
  where
    bytes = BS.concat (reverse chunks)