module Text.ProtocolBuffers.Get
(Get,runGet,Result(..)
,ensureBytes,getStorable,getLazyByteString,suspendUntilComplete
,getAvailable,putAvailable
,lookAhead,lookAheadM,lookAheadE
,skip,bytesRead,isEmpty,isReallyEmpty,remaining,spanOf
,getWord8,getByteString
,getWord16be,getWord32be,getWord64be
,getWord16le,getWord32le,getWord64le
,getWordhost,getWord16host,getWord32host,getWord64host
) where
import Control.Applicative(Applicative(pure,(<*>)),Alternative(empty,(<|>)))
import Control.Monad(MonadPlus(mzero,mplus),when)
import Control.Monad.Error.Class(MonadError(throwError,catchError),Error(strMsg))
import Control.Monad(ap)
import Data.Bits(Bits((.|.)))
import qualified Data.ByteString as S(concat,length,null,splitAt)
import qualified Data.ByteString.Internal as S(ByteString,toForeignPtr,inlinePerformIO)
import qualified Data.ByteString.Unsafe as S(unsafeIndex)
import qualified Data.ByteString.Lazy as L(take,drop,length,span,toChunks,fromChunks,null)
import qualified Data.ByteString.Lazy.Internal as L(ByteString(..),chunk)
import qualified Data.Foldable as F(foldr,foldr1)
import Data.Int(Int64)
import Data.Monoid(Monoid(mempty,mappend))
import Data.Sequence(Seq,null,(|>))
import Data.Word(Word,Word8,Word16,Word32,Word64)
import Foreign.ForeignPtr(withForeignPtr)
import Foreign.Ptr(castPtr,plusPtr)
import Foreign.Storable(Storable(peek,sizeOf))
#if defined(__GLASGOW_HASKELL__) && !defined(__HADDOCK__)
import GHC.Base(Int(..),uncheckedShiftL#)
import GHC.Word(Word16(..),Word32(..),Word64(..),uncheckedShiftL64#)
#else
-- Needed only by the portable shiftl_w16/32/64 fallback definitions.
import Data.Bits(shiftL)
#endif
-- | Outcome of running a 'Get' parser with 'runGet'.
data Result a = Failed !Int64 String
                -- ^ Parsing failed.  Carries the 'consumed' count of the
                -- state handed to the outermost error continuation, and the
                -- error message.
              | Finished !L.ByteString !Int64 a
                -- ^ Parsing succeeded.  Carries the unconsumed input, the
                -- number of bytes consumed, and the parsed value.
              | Partial (Maybe L.ByteString -> Result a)
                -- ^ The parser wants more input.  Apply the continuation to
                -- @Just chunk@ to append more bytes, or to @Nothing@ to
                -- signal that no more input will arrive.
-- | Internal parser state: the unconsumed input, split into its first strict
-- chunk and the lazy remainder, plus a running count of consumed bytes.
data S = S { top :: !S.ByteString      -- ^ first chunk of the unconsumed input
           , current :: !L.ByteString  -- ^ unconsumed input following 'top'
           , consumed :: !Int64        -- ^ bytes consumed so far
           } deriving Show
-- | Stack of frames consulted by 'throwError' and 'suspend'.  Every stack
-- ends in an 'ErrorFrame'.
data FrameStack b = ErrorFrame (String -> S -> Result b)
                               Bool
                    -- ^ Outermost frame: the final error continuation, and a
                    -- flag that is 'True' while 'suspend' may still ask the
                    -- caller for more input ('runGet' starts with 'True';
                    -- it becomes 'False' once the caller supplies 'Nothing').
                  | HandlerFrame (Maybe ( S -> FrameStack b -> String -> Result b ))
                                 S
                                 (Seq L.ByteString)
                                 (FrameStack b)
                    -- ^ A pushed frame: an optional error handler ('Just'
                    -- from 'catchError', 'Nothing' from 'setCheckpoint'), the
                    -- parser state saved when the frame was pushed, the input
                    -- chunks that 'suspend' received since then, and the
                    -- enclosing frames.
-- | Success continuation of a 'Get' parser: receives the produced value, the
-- current parser state, and the current frame stack.
type Success b a = (a -> S -> FrameStack b -> Result b)
-- | A continuation-passing parser over lazy ByteString input.  'unGet' takes
-- the success continuation, the current state, and the frame stack used for
-- error handling and suspension, and produces the overall 'Result'.
newtype Get a = Get {
    unGet :: forall b.
             Success b a
          -> S
          -> FrameStack b
          -> Result b
  }
-- | Checkpoint primitives behind the look-ahead combinators.
--
-- 'useCheckpoint' and 'clearCheckpoint' pattern-match on a
-- @HandlerFrame Nothing@ at the top of the frame stack and fail at runtime
-- if anything else is there, so each must follow a matching 'setCheckpoint'.
setCheckpoint,useCheckpoint,clearCheckpoint :: Get ()
-- Push a handler-less frame that records the current state.
setCheckpoint = Get $ \ sc s pc -> sc () s (HandlerFrame Nothing s mempty pc)
-- Pop the checkpoint frame and rewind: the current state is dropped and the
-- recorded one is restored, extended with every chunk that arrived through
-- 'suspend' after the checkpoint was set.
useCheckpoint = Get $ \ sc (S _ _ _) (HandlerFrame Nothing s future pc) ->
  let (S {top=ss, current=bs, consumed=n}) = collect s future
  in sc () (S ss bs n) pc
-- Pop the checkpoint frame and keep the current state.
clearCheckpoint = Get $ \ sc s (HandlerFrame Nothing _s _future pc) -> sc () s pc
-- | Run a parser, then rewind the input to where it started and return the
-- parser's result.  Input that arrived while the parser was suspended stays
-- available.  If the parser throws, the error propagates (the checkpoint
-- frame is skipped during unwinding).
lookAhead :: Get a -> Get a
lookAhead todo =
  setCheckpoint >> (todo >>= \result -> useCheckpoint >> return result)
-- | Like 'lookAhead', but rewinds only when the parser yields 'Nothing'.  A
-- 'Just' result keeps the input it consumed.
lookAheadM :: Get (Maybe a) -> Get (Maybe a)
lookAheadM todo = do
  setCheckpoint
  result <- todo
  case result of
    Nothing -> useCheckpoint
    Just _  -> clearCheckpoint
  return result
-- | Like 'lookAhead', but rewinds only when the parser yields 'Left'.  A
-- 'Right' result keeps the input it consumed.
lookAheadE :: Get (Either a b) -> Get (Either a b)
lookAheadE todo = do
  setCheckpoint
  result <- todo
  case result of
    Left _  -> useCheckpoint
    Right _ -> clearCheckpoint
  return result
-- | Append every chunk of @future@, in order, to the end of the state's
-- input.  'top' and 'consumed' are left unchanged.
collect :: S -> Seq L.ByteString -> S
collect s future
  | Data.Sequence.null future = s
  | otherwise = s { current = current s `mappend` F.foldr1 mappend future }
-- | Debug rendering of a 'Result'.  Each alternative is shown in parentheses.
-- A 'Partial' continuation cannot be shown, so a placeholder is printed.
instance (Show a) => Show (Result a) where
  showsPrec _ (Failed n msg) = ("(Failed "++) . shows n . (' ':) . shows msg . (")"++)
  showsPrec _ (Finished bs n a) =
    -- The label matches the constructor name so the output reads like the value.
    ("(Finished ("++)
    . shows bs . (") ("++)
    . shows n . (") ("++)
    . shows a . ("))"++)
  showsPrec _ (Partial {}) = ("(Partial <Maybe Data.ByteString.Lazy.ByteString-> Result a)"++)
-- | Debug rendering of a frame stack.  Continuations cannot be shown, so
-- placeholders stand in for them.  Saved states, pending chunks, and the
-- enclosing frames are each shown in parentheses.
instance Show (FrameStack b) where
  showsPrec _ (ErrorFrame _ p) =(++) "(ErrorFrame <e->s->m b> " . shows p . (")"++)
  showsPrec _ (HandlerFrame _ s future pc) = ("(HandlerFrame <> ("++)
    . shows s . (") ("++) . shows future . (") ("++)
    -- Two closing parens: one for the enclosing frame and one for the whole
    -- HandlerFrame.  A single ")" left the rendering unbalanced.
    . shows pc . ("))"++)
-- | Run a parser over a lazy ByteString.  Parsing starts with 'consumed' at
-- 0 and with further input still allowed, so 'suspend' may produce a
-- 'Partial' that asks the caller for more.
runGet :: Get a -> L.ByteString -> Result a
runGet parser input = unGet parser onSuccess initial (ErrorFrame onError True)
  where
    onSuccess value st _frames =
      Finished (L.chunk (top st) (current st)) (consumed st) value
    onError msg st = Failed (consumed st) msg
    initial = case input of
      L.Empty -> S mempty mempty 0
      L.Chunk first rest -> S first rest 0
-- | Return all currently buffered, unconsumed input without consuming it.
getAvailable :: Get L.ByteString
getAvailable =
  Get (\k st@(S first rest _) frames -> k (L.chunk first rest) st frames)
-- | Replace all currently buffered, unconsumed input with @bsNew@, keeping
-- the 'consumed' count.  Every 'HandlerFrame' on the stack is rebuilt so that
-- its saved state holds the bytes consumed between its saved position and the
-- current position, followed by @bsNew@.  Its pending chunks are cleared.
-- A later rewind or catch therefore never sees the discarded input.
putAvailable :: L.ByteString -> Get ()
putAvailable bsNew = Get $ \ sc (S _ss _bs n) pc ->
  let s' = case bsNew of
             L.Empty -> S mempty mempty n
             L.Chunk ss' bs' -> S ss' bs' n
      rebuild (HandlerFrame catcher (S ss1 bs1 n1) future pc') =
          HandlerFrame catcher sNew mempty (rebuild pc')
        where
          -- Bytes consumed since this frame saved its state.  This used to be
          -- written @n n1@, which applies an Int64 to an argument and does not
          -- type-check; the intended value is the difference.
          balance = n - n1
          whole | balance < 0 = error "Impossible? Cannot rebuild HandlerFrame in MyGet.putAvailable: balance is negative!"
                | otherwise = L.take balance $ L.chunk ss1 bs1 `mappend` F.foldr mappend mempty future
          sNew | balance /= L.length whole = error "Impossible? MyGet.putAvailable.rebuild.sNew HandlerFrame assertion failed."
               | otherwise = case mappend whole bsNew of
                               L.Empty -> S mempty mempty n1
                               L.Chunk ss2 bs2 -> S ss2 bs2 n1
      rebuild x@(ErrorFrame {}) = x
  in sc () s' (rebuild pc)
-- | Return the complete parser state without changing it.
getFull :: Get S
getFull = Get (\k st frames -> k st st frames)
-- | Replace the parser state, discarding the previous one.
putFull :: S -> Get ()
putFull s = Get (\k _old frames -> k () s frames)
-- | Keep requesting input until 'suspend' reports that no more is coming.
suspendUntilComplete :: Get ()
suspendUntilComplete =
  suspend >>= \more -> if more then suspendUntilComplete else return ()
-- | Request more input, failing with @msg@ when no more input is available.
suspendMsg :: String -> Get ()
suspendMsg msg = suspend >>= \more -> when (not more) (throwError msg)
-- | Wait until at least @n@ bytes are buffered, requesting more input as
-- needed.  Fails with "ensureBytes failed" if the input ends first.
-- Consumes nothing.
ensureBytes :: Int64 -> Get ()
ensureBytes n = do
  S first rest _ <- getFull
  -- Check the first strict chunk cheaply before measuring the lazy input.
  let enough = n < fromIntegral (S.length first)
            || L.length (L.take n (L.chunk first rest)) == n
  if enough
    then return ()
    else suspendMsg "ensureBytes failed" >> ensureBytes n
-- | Consume exactly @n@ bytes and return them as a lazy ByteString.  If
-- fewer than @n@ bytes are buffered, more input is requested.  If the input
-- ends first, the parser fails with "getLazyByteString failed".  A
-- non-positive @n@ returns an empty string and consumes nothing.
getLazyByteString :: Int64 -> Get L.ByteString
getLazyByteString n
  | n <= 0 = return mempty
  | otherwise = getFull >>= \(S first rest offset) ->
      maybe retry (advance offset) (splitAtOrDie n (L.chunk first rest))
  where
    -- Not enough input yet: wait for more, then start over.
    retry = suspendMsg "getLazyByteString failed" >> getLazyByteString n
    -- Store the leftover input as the new state and hand back the prefix.
    advance offset (wanted, leftover) = do
      case leftover of
        L.Empty -> putFull (S mempty mempty (offset + n))
        L.Chunk c cs -> putFull (S c cs (offset + n))
      return wanted
-- | Monads that can pause and ask their caller for more input.
class MonadSuspend m where
  -- | Ask for more input.  Yields 'True' when more input was supplied and
  -- 'False' when no more input is available.
  suspend :: m Bool
-- | Suspension for 'Get': yields a 'Partial' to the caller of 'runGet' while
-- more input is still allowed.
instance MonadSuspend Get where
  -- If the flag in the bottom 'ErrorFrame' still allows input, return a
  -- 'Partial' continuation.  Otherwise answer 'False' at once.
  -- On @Nothing@ the flag is cleared, so later suspends answer 'False'
  -- without asking.  On @Just bs'@ the chunk is appended to the current
  -- input and recorded in every 'HandlerFrame', so a rewind or catch keeps it.
  suspend = Get $ \ sc sIn pcIn ->
    if checkBool pcIn
      then
        let f Nothing =
              let pcOut = rememberFalse pcIn
              in sc False sIn pcOut
            f (Just bs') =
              let sOut = appendBS sIn bs'
                  pcOut = addFuture bs' pcIn
              in sc True sOut pcOut
        in Partial f
      else sc False sIn pcIn
    where
      -- Append a new chunk to the end of the unconsumed input.
      appendBS (S ss bs n) bs' = S ss (mappend bs bs') n
      -- Record a new chunk as pending input in every HandlerFrame.
      addFuture bs (HandlerFrame catcher s future pc) =
        HandlerFrame catcher s (future |> bs) (addFuture bs pc)
      addFuture _bs x@(ErrorFrame {}) = x
      -- Read the "more input allowed" flag from the bottom ErrorFrame.
      checkBool (ErrorFrame _ b) = b
      checkBool (HandlerFrame _ _ _ pc) = checkBool pc
      -- Clear that flag, leaving every other frame unchanged.
      rememberFalse (ErrorFrame ec _) = ErrorFrame ec False
      rememberFalse (HandlerFrame catcher s future pc) =
        HandlerFrame catcher s future (rememberFalse pc)
-- | Pop the innermost 'HandlerFrame' from the frame stack.  The bottom
-- 'ErrorFrame' is never removed.  The parser state is unchanged.
discardInnerHandler :: Get ()
discardInnerHandler = Get (\k st frames -> k () st (popFrame frames))
  where
    popFrame frames@(ErrorFrame {}) = frames
    popFrame (HandlerFrame _ _ _ parent) = parent
-- | Discard @m@ bytes of input, first waiting (via 'ensureBytes') until they
-- are buffered.  A non-positive @m@ does nothing.
skip :: Int64 -> Get ()
skip m
  | m <= 0 = return ()
  | otherwise =
      ensureBytes m >> getFull >>= \(S first rest n) ->
        case L.drop m (L.chunk first rest) of
          L.Empty -> putFull (S mempty mempty (n + m))
          L.Chunk c cs -> putFull (S c cs (n + m))
-- | Number of bytes consumed so far.
bytesRead :: Get Int64
bytesRead = do
  st <- getFull
  return (consumed st)
-- | Number of buffered, unconsumed bytes.  Does not request more input.
remaining :: Get Int64
remaining = do
  S first rest _ <- getFull
  let headLen = fromIntegral (S.length first)
  return (headLen + L.length rest)
-- | 'True' when no unconsumed input is buffered.  Does not request more input.
isEmpty :: Get Bool
isEmpty = getFull >>= \(S first rest _) -> return (S.null first && L.null rest)
-- | 'True' only when nothing is buffered and no more input is available.
-- While the buffer is empty, more input is requested.  The result is 'False'
-- as soon as non-empty input is buffered, and 'True' once input has ended.
isReallyEmpty :: Get Bool
isReallyEmpty = isEmpty >>= \nothingBuffered ->
  if nothingBuffered then waitForInput else return False
  where
    waitForInput = suspend >>= \more ->
      if more then isReallyEmpty else return True
-- | Consume and return the longest prefix of the input whose bytes all
-- satisfy @f@.  If the buffered input runs out while bytes still match, more
-- input is requested and the scan continues.  The scan stops at the first
-- non-matching byte or at the end of input.
spanOf :: (Word8 -> Bool) -> Get (L.ByteString)
spanOf f = fmap L.fromChunks loop
  where
    loop = do
      (S ss bs n) <- getFull
      let (pre,post) = L.span f (L.chunk ss bs)
      case post of
        L.Empty -> putFull (S mempty mempty (n + L.length pre))
        L.Chunk ss' bs' -> putFull (S ss' bs' (n + L.length pre))
      if L.null post
        then fmap ((L.toChunks pre)++) $ do
          continue <- suspend
          if continue
            then loop
            -- End of input.  The enclosing fmap already prepends the chunks
            -- of pre, so add nothing here; returning pre again would
            -- duplicate it in the result.
            else return []
        else return (L.toChunks pre)
-- | Consume exactly @nIn@ bytes and return them as a strict ByteString.
-- Otherwise this delegates to 'getLazyByteString', which requests more input
-- as needed and fails with "getLazyByteString failed" if the input ends
-- early.  A non-positive @nIn@ returns an empty string.
getByteString :: Int -> Get S.ByteString
getByteString nIn
  | nIn <= 0 = return mempty
  | otherwise = getFull >>= fromState
  where
    -- Fast path: the request lies strictly inside the first strict chunk, so
    -- it can be sliced off without touching the lazy remainder.
    fromState (S first rest n)
      | nIn < S.length first =
          let (wanted, leftover) = S.splitAt nIn first
          in putFull (S leftover rest (n + fromIntegral nIn)) >> return wanted
      | otherwise =
          fmap (S.concat . L.toChunks) (getLazyByteString (fromIntegral nIn))
-- | Read a machine-sized 'Word' in the host's native byte order.
getWordhost :: Get Word
getWordhost = getPtr (sizeOf (undefined :: Word))
-- | Read a single byte.
getWord8 :: Get Word8
getWord8 = getStorable
-- | 16-bit readers.  The @be@ and @le@ variants assemble the word from two
-- bytes in big- and little-endian order.  The @host@ variant reads the bytes
-- in native order via 'getStorable'.  'S.unsafeIndex' is in bounds because
-- 'getByteString' returns exactly the requested number of bytes or fails.
getWord16be,getWord16le,getWord16host :: Get Word16
-- Most significant byte first.
getWord16be = do
  s <- getByteString 2
  let byte i = fromIntegral (s `S.unsafeIndex` i) :: Word16
  return $! (byte 0 `shiftl_w16` 8) .|. byte 1
-- Least significant byte first.
getWord16le = do
  s <- getByteString 2
  let byte i = fromIntegral (s `S.unsafeIndex` i) :: Word16
  return $! (byte 1 `shiftl_w16` 8) .|. byte 0
getWord16host = getStorable
-- | 32-bit readers.  The @be@ and @le@ variants assemble the word from four
-- bytes in big- and little-endian order.  The @host@ variant reads the bytes
-- in native order via 'getStorable'.  'S.unsafeIndex' is in bounds because
-- 'getByteString' returns exactly the requested number of bytes or fails.
getWord32be,getWord32le,getWord32host :: Get Word32
-- Most significant byte first.
getWord32be = do
  s <- getByteString 4
  let byte i = fromIntegral (s `S.unsafeIndex` i) :: Word32
  return $! (byte 0 `shiftl_w32` 24) .|. (byte 1 `shiftl_w32` 16)
        .|. (byte 2 `shiftl_w32` 8) .|. byte 3
-- Least significant byte first.
getWord32le = do
  s <- getByteString 4
  let byte i = fromIntegral (s `S.unsafeIndex` i) :: Word32
  return $! (byte 3 `shiftl_w32` 24) .|. (byte 2 `shiftl_w32` 16)
        .|. (byte 1 `shiftl_w32` 8) .|. byte 0
getWord32host = getStorable
-- | 64-bit readers.  The @be@ and @le@ variants assemble the word from eight
-- bytes in big- and little-endian order.  The @host@ variant reads the bytes
-- in native order via 'getStorable'.  'S.unsafeIndex' is in bounds because
-- 'getByteString' returns exactly the requested number of bytes or fails.
getWord64be,getWord64le,getWord64host :: Get Word64
-- Most significant byte first.
getWord64be = do
  s <- getByteString 8
  let byte i = fromIntegral (s `S.unsafeIndex` i) :: Word64
  return $! (byte 0 `shiftl_w64` 56) .|. (byte 1 `shiftl_w64` 48)
        .|. (byte 2 `shiftl_w64` 40) .|. (byte 3 `shiftl_w64` 32)
        .|. (byte 4 `shiftl_w64` 24) .|. (byte 5 `shiftl_w64` 16)
        .|. (byte 6 `shiftl_w64` 8) .|. byte 7
-- Least significant byte first.
getWord64le = do
  s <- getByteString 8
  let byte i = fromIntegral (s `S.unsafeIndex` i) :: Word64
  return $! (byte 7 `shiftl_w64` 56) .|. (byte 6 `shiftl_w64` 48)
        .|. (byte 5 `shiftl_w64` 40) .|. (byte 4 `shiftl_w64` 32)
        .|. (byte 3 `shiftl_w64` 24) .|. (byte 2 `shiftl_w64` 16)
        .|. (byte 1 `shiftl_w64` 8) .|. byte 0
getWord64host = getStorable
-- | Map over a parser's result by post-composing its success continuation.
instance Functor Get where
  fmap f (Get run) = Get (\k -> run (\x -> k (f x)))
-- | Sequencing in continuation-passing style.  'fail' raises its message
-- through 'throwError'.
instance Monad Get where
  return x = Get (\k -> k x)
  Get run >>= next = Get (\k -> run (\x -> unGet (next x) k))
  fail msg = throwError (strMsg msg)
-- | Errors are plain strings.  'throwError' unwinds the frame stack to the
-- nearest frame able to handle the error; 'catchError' pushes such a frame.
instance MonadError String Get where
  -- Walk outward through the frames:
  --   * an 'ErrorFrame' ends parsing through its error continuation, given
  --     the current state;
  --   * a 'HandlerFrame' with a catcher resumes in that catcher, given the
  --     state saved when the frame was pushed (extended with the chunks
  --     received since then) and the enclosing frames;
  --   * a checkpoint frame (no catcher) is skipped.
  throwError msg = Get $ \_sc s pcIn ->
    let go (ErrorFrame ec _) = ec msg s
        go (HandlerFrame (Just catcher) s1 future pc1) = catcher (collect s1 future) pc1 msg
        go (HandlerFrame Nothing _s1 _future pc1) = go pc1
    in go pcIn
  -- Push a handler frame that remembers the current state, then run the
  -- action and pop the frame again when the action succeeds.  The handler
  -- continues with the same success continuation as the protected action.
  catchError mayFail handler = Get $ \sc s pc ->
    let pcWithHandler =
          let catcher s1 pc1 e1 = unGet (handler e1) sc s1 pc1
          in HandlerFrame (Just catcher) s mempty pc
        actionWithCleanup = mayFail >>= \a -> discardInnerHandler >> return a
    in unGet actionWithCleanup sc s pcWithHandler
-- | 'mzero' fails with a fixed message.  'mplus' runs the second parser if
-- the first throws, with the input rewound to where the first one started.
instance MonadPlus Get where
  mzero = throwError $ strMsg "[mzero:no message]"
  mplus first second = first `catchError` \_ -> second
-- | Applicative interface derived from the 'Monad' instance.
instance Applicative Get where
  pure x = return x
  mf <*> mx = do
    f <- mf
    x <- mx
    return (f x)
-- | Alternative interface derived from the 'MonadPlus' instance.
instance Alternative Get where
  empty = mzero
  a <|> b = a `mplus` b
-- | Split a lazy ByteString after exactly @i@ bytes.  Returns 'Nothing' when
-- the input holds fewer than @i@ bytes; no partial prefix is returned.  A
-- non-positive @i@ yields an empty prefix and the whole input.
splitAtOrDie :: Int64 -> L.ByteString -> Maybe (L.ByteString, L.ByteString)
splitAtOrDie i ps | i <= 0 = Just (L.Empty, ps)
splitAtOrDie _i L.Empty = Nothing
splitAtOrDie i (L.Chunk x xs)
  -- The split point lies strictly inside this chunk, so both halves are
  -- non-empty and form valid single chunks.
  | i < len = let (pre,post) = S.splitAt (fromIntegral i) x
              in Just (L.Chunk pre L.Empty, L.Chunk post xs)
  -- Take this whole chunk and look for the remaining (i - len) bytes in the
  -- rest.  This was written as the unbound name @ilen@.
  | otherwise = case splitAtOrDie (i - len) xs of
                  Nothing -> Nothing
                  Just (y1,y2) -> Just (L.Chunk x y1,y2)
  where len = fromIntegral (S.length x)
-- | Consume @n@ bytes and read a 'Storable' value from their start, using
-- the host's native representation.
getPtr :: (Storable a) => Int -> Get a
getPtr n = fmap S.toForeignPtr (getByteString n) >>= \(fp, off, _) ->
  return (S.inlinePerformIO (withForeignPtr fp (\p -> peek (castPtr (p `plusPtr` off)))))
-- | Read any 'Storable' value by consuming as many bytes as its 'sizeOf'
-- reports, interpreted in the host's native representation.
getStorable :: forall a. (Storable a) => Get a
getStorable = getPtr (sizeOf (undefined :: a))
-- | Left shifts used to assemble multi-byte words.  Under GHC they use the
-- unchecked shift primitives directly, so the shift amount must stay below
-- the bit width; every caller in this module shifts by such a constant.
-- Other compilers, and Haddock builds, fall back to 'shiftL'.
-- NOTE(review): the GHC branch does not narrow the result, so it relies on
-- callers shifting values that still fit the target width.  This holds for
-- the byte-assembly callers here; confirm before adding new uses.
shiftl_w16 :: Word16 -> Int -> Word16
shiftl_w32 :: Word32 -> Int -> Word32
shiftl_w64 :: Word64 -> Int -> Word64
#if defined(__GLASGOW_HASKELL__) && !defined(__HADDOCK__)
shiftl_w16 (W16# w) (I# i) = W16# (w `uncheckedShiftL#` i)
shiftl_w32 (W32# w) (I# i) = W32# (w `uncheckedShiftL#` i)
#if WORD_SIZE_IN_BITS < 64
shiftl_w64 (W64# w) (I# i) = W64# (w `uncheckedShiftL64#` i)
#if __GLASGOW_HASKELL__ <= 606
foreign import ccall unsafe "stg_uncheckedShiftL64"
    uncheckedShiftL64# :: Word64# -> Int# -> Word64#
#endif
#else
shiftl_w64 (W64# w) (I# i) = W64# (w `uncheckedShiftL#` i)
#endif
#else
-- Portable fallback.  shiftL must come from Data.Bits on this branch; the
-- plain Bits((.|.)) import does not bring it into scope.
shiftl_w16 = shiftL
shiftl_w32 = shiftL
shiftl_w64 = shiftL
#endif