#if defined(__GLASGOW_HASKELL__) && !defined(__HADDOCK__)
#include "MachDeps.h"
#endif
module Data.Serialize.Get (
    
      Get
    , runGet
    , runGetLazy
    , runGetState
    , runGetLazyState
    , Result(..)
    , runGetPartial
    
    , ensure
    , isolate
    , label
    , skip
    , uncheckedSkip
    , lookAhead
    , lookAheadM
    , lookAheadE
    , uncheckedLookAhead
    
    , getBytes
    , remaining
    , isEmpty
    
    , getWord8
    
    , getByteString
    , getLazyByteString
    
    , getWord16be
    , getWord32be
    , getWord64be
    
    , getWord16le
    , getWord32le
    , getWord64le
    
    , getWordhost
    , getWord16host
    , getWord32host
    , getWord64host
    
    , getTwoOf
    , getListOf
    , getIArrayOf
    , getTreeOf
    , getSeqOf
    , getMapOf
    , getIntMapOf
    , getSetOf
    , getIntSetOf
    , getMaybeOf
    , getEitherOf
  ) where
import Control.Applicative (Applicative(..),Alternative(..))
import Control.Monad (unless,when,ap,MonadPlus(..),liftM2)
import Data.Array.IArray (IArray,listArray)
import Data.Ix (Ix)
import Data.List (intercalate)
import Data.Maybe (isNothing,fromMaybe)
import Foreign
import qualified Data.ByteString          as B
import qualified Data.ByteString.Internal as B
import qualified Data.ByteString.Unsafe   as B
import qualified Data.ByteString.Lazy     as L
import qualified Data.IntMap              as IntMap
import qualified Data.IntSet              as IntSet
import qualified Data.Map                 as Map
import qualified Data.Sequence            as Seq
import qualified Data.Set                 as Set
import qualified Data.Tree                as T
#if defined(__GLASGOW_HASKELL__) && !defined(__HADDOCK__)
import GHC.Base
import GHC.Word
#endif
-- | The outcome of running a 'Get' parser incrementally.
data Result r = Fail String
              -- ^ The parser failed with the given error message.
              | Partial (B.ByteString -> Result r)
              -- ^ The parser needs more input. Apply the continuation to
              -- the next chunk; an empty chunk is treated by 'demandInput'
              -- as the end of the input and makes it fail.
              | Done r B.ByteString
              -- ^ The parser succeeded with the value @r@; the
              -- 'B.ByteString' is the input that was not consumed.
-- | Debug rendering. A 'Partial' continuation cannot be shown, so it is
-- printed as a fixed placeholder.
instance Show r => Show (Result r) where
    show res = case res of
      Fail msg    -> "Fail " ++ show msg
      Partial _   -> "Partial _"
      Done x rest -> concat ["Done ", show x, " ", show rest]
-- | Map over the final value, pushing the function through any pending
-- continuation.
instance Functor Result where
    fmap f res = case res of
      Fail msg    -> Fail msg
      Partial k   -> Partial (\chunk -> fmap f (k chunk))
      Done x rest -> Done (f x) rest
-- | A parser over strict 'B.ByteString' input in continuation-passing
-- style. Each step receives the current input, the backtracking buffer,
-- the 'More' flag, a failure continuation and a success continuation.
newtype Get a = Get
  { unGet :: forall r. Input -> Buffer -> More
                    -> Failure r -> Success a r
                    -> Result r }
-- | The portion of the input not yet consumed.
type Input  = B.ByteString
-- | Bytes fetched by 'demandInput' while a backtracking point ('mplus',
-- 'lookAhead') is active, so they can be replayed on rewind. 'Nothing'
-- means no backtracking point is active and nothing is retained.
type Buffer = Maybe B.ByteString
-- | Concatenate two buffers. The result is 'Nothing' unless both sides are
-- present; the second buffer is not inspected when the first is absent.
append :: Buffer -> Buffer -> Buffer
append (Just l) (Just r) = Just (B.append l r)
append _        _        = Nothing
-- | The bytes held in a buffer, or the empty string when none is kept.
bufferBytes :: Buffer -> B.ByteString
bufferBytes = maybe B.empty id
-- | Failure continuation: the parser state at the point of failure, the
-- list of enclosing 'label' names collected so far, and the error message.
type Failure   r = Input -> Buffer -> More -> [String] -> String -> Result r
-- | Success continuation: the parser state after the step, and its value.
type Success a r = Input -> Buffer -> More -> a                  -> Result r
-- | Whether more input may be supplied after the current chunk.
data More
  = Complete
  -- ^ No further input: 'demandInput' fails immediately.
  | Incomplete (Maybe Int)
  -- ^ More input may arrive through 'Partial'. The optional 'Int' is the
  -- number of bytes still to come when the caller knows it ('runGetLazy'
  -- supplies it); 'demandInput' updates it as chunks arrive.
    deriving (Eq)
-- | Number of input bytes known to be still outstanding; zero when the
-- input is complete or the amount is unknown.
moreLength :: More -> Int
moreLength Complete              = 0
moreLength (Incomplete Nothing)  = 0
moreLength (Incomplete (Just n)) = n
-- | Apply a function to the parse result by wrapping the success
-- continuation; the parser state is passed through unchanged.
instance Functor Get where
    fmap f g = Get $ \s0 b0 m0 kf ks ->
      unGet g s0 b0 m0 kf (\s1 b1 m1 x -> ks s1 b1 m1 (f x))
-- | 'pure' succeeds without touching the state; '<*>' runs the function
-- parser first, then the argument parser.
instance Applicative Get where
    pure x = Get $ \s0 b0 m0 _ ks -> ks s0 b0 m0 x
    mf <*> mx = do
      f <- mf
      x <- mx
      return (f x)
-- | Choice with backtracking, delegating to the 'MonadPlus' instance.
instance Alternative Get where
    empty   = failDesc "empty"
    p <|> q = mplus p q
-- | Bind passes the parser state (input, buffer, More flag) from the first
-- parser to the second through the success continuation; both share the
-- same failure continuation.
--
-- NOTE(review): fail is no longer a Monad method from GHC 8.8 onwards; a
-- MonadFail instance would be needed there. Confirm the supported GHC range.
instance Monad Get where
    return x = Get $ \s0 b0 m0 _ ks -> ks s0 b0 m0 x
    g >>= f  = Get $ \s0 b0 m0 kf ks ->
      unGet g s0 b0 m0 kf (\s1 b1 m1 x -> unGet (f x) s1 b1 m1 kf ks)
    fail     = failDesc
-- | Choice with backtracking. The left parser runs with a fresh buffer so
-- every byte it fetches is recorded. If it fails, the right parser is
-- started on the original input plus those bytes.
instance MonadPlus Get where
    mzero     = failDesc "mzero"
    mplus a b =
      Get $ \s0 b0 m0 kf ks ->
        -- On success the bytes recorded while running a must be added to
        -- the enclosing buffer b0. Otherwise an outer backtracking point
        -- would lose everything it had recorded before this mplus started.
        -- When b0 is Nothing, append yields Nothing and recording stops.
        let ks' s1 b1 = ks s1 (b0 `append` b1)
            kf' _ b1 m1 _ _ = unGet b (s0 `B.append` bufferBytes b1)
                                      (b0 `append` b1) m1 kf ks
         in unGet a s0 (Just B.empty) m0 kf' ks'
-- | Render a label stack for error messages, one label per line after a
-- "From:" header.
formatTrace :: [String] -> String
formatTrace ls
  | null ls   = "Empty call stack"
  | otherwise = "From:\t" ++ intercalate "\n\t" ls ++ "\n"
-- | Return the remaining input without consuming it.
get :: Get B.ByteString
get = Get $ \input buf more _ ks -> ks input buf more input
-- | Replace the remaining input; the buffer and More flag are unchanged.
put :: B.ByteString -> Get ()
put s = Get $ \_ buf more _ ks -> ks s buf more ()
-- | Run a parser and, if it fails, push the given name onto the label
-- list reported with the error.
label :: String -> Get a -> Get a
label l m = Get $ \s0 b0 m0 kf ks ->
  unGet m s0 b0 m0 (\s1 b1 m1 ls -> kf s1 b1 m1 (l : ls)) ks
-- | Top-level success continuation: finish with the value and the input
-- that was not consumed.
finalK :: Success a a
finalK rest _ _ x = Done x rest
-- | Top-level failure continuation: combine the message with the rendered
-- label trace.
failK :: Failure a
failK _ _ _ stack msg = Fail (unlines (msg : [formatTrace stack]))
-- | Run a parser on a complete strict 'B.ByteString'. Requests for more
-- input fail, so a 'Partial' result is reported as an internal error.
runGet :: Get a -> B.ByteString -> Either String a
runGet m str = toEither (unGet m str Nothing Complete failK finalK)
  where
  toEither res = case res of
    Done a _  -> Right a
    Fail err  -> Left err
    Partial{} -> Left "Failed reading: Internal error: unexpected Partial."
-- | Start a parser on a first chunk of input; further chunks are supplied
-- through the 'Partial' continuations of the returned 'Result'.
runGetPartial :: Get a -> B.ByteString -> Result a
runGetPartial m str = unGet m str noBuffer start failK finalK
  where
  noBuffer = Nothing
  start    = Incomplete Nothing
-- | Like 'runGet', but begin @off@ bytes into the input and also return the
-- unconsumed remainder.
runGetState :: Get a -> B.ByteString -> Int
            -> Either String (a, B.ByteString)
runGetState m str off = toEither (unGet m input Nothing Complete failK finalK)
  where
  input = B.drop off str
  toEither res = case res of
    Done a rest -> Right (a, rest)
    Fail err    -> Left err
    Partial{}   -> Left "Failed reading: Internal error: unexpected Partial."
-- | Drive a parser over the chunks of a lazy 'L.ByteString'. Returns the
-- result together with the unconsumed input; on failure the remainder is
-- empty.
--
-- The parser is told how many bytes remain after each chunk, so
-- 'remaining' and 'isEmpty' see the whole input. Empty input is run as a
-- single empty chunk, so parsers that consume nothing still succeed.
-- When the chunks run out, an empty chunk is fed to the continuation:
-- 'demandInput' treats that as end of input and fails with its normal,
-- labelled error.
runGetLazy' :: Get a -> L.ByteString -> (Either String a,L.ByteString)
runGetLazy' m lstr = case L.toChunks lstr of
    []     -> step (run B.empty) []
    c : cs -> step (run c) cs
  where
  remLen c = fromIntegral (L.length lstr) - B.length c
  run str  = unGet m str Nothing (Incomplete (Just (remLen str))) failK finalK
  step r cs = case r of
    Fail err   -> (Left err, L.empty)
    Done a c'  -> (Right a, L.fromChunks (c' : cs))
    Partial k  -> case cs of
      c : rest -> step (k c) rest
      []       -> step (k B.empty) []
-- | Run a parser over a lazy 'L.ByteString', discarding the leftover input.
runGetLazy :: Get a -> L.ByteString -> Either String a
runGetLazy m = fst . runGetLazy' m
-- | Run a parser over a lazy 'L.ByteString', returning the value together
-- with the input left after it.
runGetLazyState :: Get a -> L.ByteString -> Either String (a,L.ByteString)
runGetLazyState m lstr =
  let (res, rest) = runGetLazy' m lstr
   in either Left (\a -> Right (a, rest)) res
ensure :: Int -> Get B.ByteString
ensure n = n `seq` Get $ \ s0 b0 m0 kf ks ->
    if B.length s0 >= n
    then ks s0 b0 m0 s0
    else unGet (demandInput >> ensureRec n) s0 b0 m0 kf ks
-- | Keep requesting input until at least @n@ bytes are present, then
-- return the current input unchanged.
ensureRec :: Int -> Get B.ByteString
ensureRec n = Get $ \s0 b0 m0 kf ks ->
  case compare (B.length s0) n of
    LT -> unGet (demandInput >> ensureRec n) s0 b0 m0 kf ks
    _  -> ks s0 b0 m0 s0
-- | Run a parser on exactly the next @n@ bytes.
--
-- Fails if @n@ is negative, if fewer than @n@ bytes are available, or if
-- the inner parser leaves any of them unconsumed. The inner parser runs as
-- though the input were complete. It therefore cannot pull in bytes past
-- the isolated region, and 'remaining' / 'isEmpty' report only that region.
isolate :: Int -> Get a -> Get a
isolate n m = do
  when (n < 0) (fail "Attempted to isolate a negative number of bytes")
  s <- ensure n
  let (s',rest) = B.splitAt n s
  put s'
  a    <- sealed m
  used <- get
  unless (B.null used) (fail "not all bytes parsed in isolate")
  put rest
  return a
  where
    -- Run a parser with the More flag forced to Complete, restoring the
    -- surrounding flag on both the success and the failure path.
    sealed :: Get b -> Get b
    sealed g = Get $ \s0 b0 m0 kf ks ->
      let kf' s1 b1 _ = kf s1 b1 m0
          ks' s1 b1 _ = ks s1 b1 m0
       in unGet g s0 b0 Complete kf' ks'
-- | Request another chunk of input.
--
-- With 'Complete' input this fails at once. Otherwise it suspends with
-- 'Partial'; an empty chunk is taken to mean end of input and fails. A new
-- chunk is appended to the current input and to the backtracking buffer.
-- The outstanding byte count, when known, is reduced by the chunk length.
demandInput :: Get ()
demandInput = Get $ \s0 b0 m0 kf ks ->
  case m0 of
    Complete      -> kf s0 b0 m0 ["demandInput"] "too few bytes"
    Incomplete mb -> Partial $ \s ->
      if B.null s
      then kf s0 b0 m0 ["demandInput"] "too few bytes"
      else let update l = l - B.length s
               s1 = s0 `B.append` s
               b1 = b0 `append` Just s
            in ks s1 b1 (Incomplete (update `fmap` mb)) ()
-- | Fail with the message prefixed by "Failed reading: " and an empty
-- label list.
failDesc :: String -> Get a
failDesc err = Get $ \s0 b0 m0 kf _ -> kf s0 b0 m0 [] msg
  where
    msg = "Failed reading: " ++ err
-- | Discard the next @n@ bytes, failing if they are not available.
skip :: Int -> Get ()
skip n = ensure n >>= put . B.drop n
-- | Discard up to @n@ bytes of the current input without requesting more.
uncheckedSkip :: Int -> Get ()
uncheckedSkip n = get >>= put . B.drop n
-- | Run a parser and then rewind the input to where it started, keeping
-- any bytes fetched while it ran.
--
-- The parser runs with a fresh buffer. On both success and failure, that
-- buffer is added to the enclosing buffer b0, so an outer 'mplus' can
-- still replay every byte it has seen.
lookAhead :: Get a -> Get a
lookAhead ga = Get $ \ s0 b0 m0 kf ks ->
  let kf' s1 b1 = kf s1 (b0 `append` b1)
      ks' _ b1  = ks (s0 `B.append` bufferBytes b1) (b0 `append` b1)
   in unGet ga s0 (Just B.empty) m0 kf' ks'
-- | Run a parser. If it returns 'Nothing', rewind the input to where it
-- started; if it returns 'Just', keep what it consumed.
--
-- The rewind includes bytes fetched by 'demandInput' while the parser ran.
-- Simply restoring the saved input with 'put' would drop those bytes when
-- parsing incrementally.
lookAheadM :: Get (Maybe a) -> Get (Maybe a)
lookAheadM gma = Get $ \s0 b0 m0 kf ks ->
  let kf' s1 b1 = kf s1 (b0 `append` b1)
      ks' s1 b1 m1 ma
        | isNothing ma = ks (s0 `B.append` bufferBytes b1) (b0 `append` b1) m1 ma
        | otherwise    = ks s1 (b0 `append` b1) m1 ma
   in unGet gma s0 (Just B.empty) m0 kf' ks'
-- | Run a parser. If it returns 'Left', rewind the input to where it
-- started; if it returns 'Right', keep what it consumed.
--
-- The rewind includes bytes fetched by 'demandInput' while the parser ran.
-- Simply restoring the saved input with 'put' would drop those bytes when
-- parsing incrementally.
lookAheadE :: Get (Either a b) -> Get (Either a b)
lookAheadE gea = Get $ \s0 b0 m0 kf ks ->
  let kf' s1 b1 = kf s1 (b0 `append` b1)
      ks' s1 b1 m1 ea = case ea of
        Left _  -> ks (s0 `B.append` bufferBytes b1) (b0 `append` b1) m1 ea
        Right _ -> ks s1 (b0 `append` b1) m1 ea
   in unGet gea s0 (Just B.empty) m0 kf' ks'
-- | Peek at up to @n@ bytes of the current input without consuming them or
-- requesting more.
uncheckedLookAhead :: Int -> Get B.ByteString
uncheckedLookAhead n = B.take n `fmap` get
-- | Number of bytes left: the current input plus any outstanding count
-- recorded in the More flag.
remaining :: Get Int
remaining = Get $ \s0 b0 m0 _ ks ->
  let avail = B.length s0 + moreLength m0
   in ks s0 b0 m0 avail
-- | True when the current input is empty and no outstanding bytes are
-- recorded in the More flag.
isEmpty :: Get Bool
isEmpty = Get $ \s0 b0 m0 _ ks ->
  let exhausted = B.null s0 && moreLength m0 == 0
   in ks s0 b0 m0 exhausted
-- | Read @n@ bytes as a freshly copied 'B.ByteString', so the result does
-- not keep the larger input buffer alive.
getByteString :: Int -> Get B.ByteString
getByteString n = getBytes n >>= \bs -> return $! B.copy bs
-- | Read @n@ bytes as a single-chunk lazy 'L.ByteString'.
getLazyByteString :: Int64 -> Get L.ByteString
getLazyByteString n = do
  chunk <- getByteString (fromIntegral n)
  return (L.fromChunks [chunk])
-- | Consume the next @n@ bytes, returning them as a slice of the input
-- without copying.
--
-- A negative count is rejected up front. 'ensure' accepts any negative
-- count, and it would then reach 'B.unsafeTake' / 'B.unsafeDrop', which do
-- no bounds checking. Counts are often decoded from untrusted length
-- prefixes (for example via 'getLazyByteString'), so this must fail cleanly.
getBytes :: Int -> Get B.ByteString
getBytes n
  | n < 0     = failDesc "getBytes: negative length requested"
  | otherwise = do
      s <- ensure n
      -- ensure guarantees at least n bytes, so the unchecked slices are safe.
      let consume = B.unsafeTake n s
          rest    = B.unsafeDrop n s
      put rest
      return consume
-- | Consume @n@ bytes and read a 'Storable' value from them with 'peek',
-- using the native memory layout of the host.
--
-- Callers in this module pass @n@ equal to the 'sizeOf' of the result type;
-- a smaller @n@ would let 'peek' read past the consumed bytes.
--
-- NOTE(review): 'B.inlinePerformIO' is deprecated or removed in newer
-- bytestring releases; confirm the supported bytestring version.
getPtr :: Storable a => Int -> Get a
getPtr n = do
    -- The consumed slice shares the input buffer; o is its offset in it.
    (fp,o,_) <- B.toForeignPtr `fmap` getBytes n
    let k p = peek (castPtr (p `plusPtr` o))
    -- Pure read of the buffer kept alive by fp; presumably safe because
    -- ByteString contents are never mutated after construction.
    return (B.inlinePerformIO (withForeignPtr fp k))
-- | Read a single byte.
getWord8 :: Get Word8
getWord8 = B.unsafeHead `fmap` getBytes 1
-- | Read a 'Word16' stored most significant byte first.
getWord16be :: Get Word16
getWord16be = do
    bytes <- getBytes 2
    let hi = fromIntegral (bytes `B.index` 0)
        lo = fromIntegral (bytes `B.index` 1)
    return $! (hi `shiftl_w16` 8) .|. lo
-- | Read a 'Word16' stored least significant byte first.
getWord16le :: Get Word16
getWord16le = do
    bytes <- getBytes 2
    let hi = fromIntegral (bytes `B.index` 1)
        lo = fromIntegral (bytes `B.index` 0)
    return $! (hi `shiftl_w16` 8) .|. lo
-- | Read a 'Word32' stored most significant byte first.
getWord32be :: Get Word32
getWord32be = do
    s <- getBytes 4
    -- byteAt i sh: byte i of the slice, shifted left by sh bits.
    let byteAt i sh = fromIntegral (s `B.index` i) `shiftl_w32` sh
    return $! byteAt 0 24 .|. byteAt 1 16 .|. byteAt 2 8
          .|. fromIntegral (s `B.index` 3)
-- | Read a 'Word32' stored least significant byte first.
getWord32le :: Get Word32
getWord32le = do
    s <- getBytes 4
    -- byteAt i sh: byte i of the slice, shifted left by sh bits.
    let byteAt i sh = fromIntegral (s `B.index` i) `shiftl_w32` sh
    return $! byteAt 3 24 .|. byteAt 2 16 .|. byteAt 1 8
          .|. fromIntegral (s `B.index` 0)
-- | Read a 'Word64' stored most significant byte first.
getWord64be :: Get Word64
getWord64be = do
    s <- getBytes 8
    -- byteAt i sh: byte i of the slice, shifted left by sh bits.
    let byteAt i sh = fromIntegral (s `B.index` i) `shiftl_w64` sh
    return $! byteAt 0 56 .|. byteAt 1 48 .|. byteAt 2 40 .|. byteAt 3 32
          .|. byteAt 4 24 .|. byteAt 5 16 .|. byteAt 6  8
          .|. fromIntegral (s `B.index` 7)
-- | Read a 'Word64' stored least significant byte first.
getWord64le :: Get Word64
getWord64le = do
    s <- getBytes 8
    -- byteAt i sh: byte i of the slice, shifted left by sh bits.
    let byteAt i sh = fromIntegral (s `B.index` i) `shiftl_w64` sh
    return $! byteAt 7 56 .|. byteAt 6 48 .|. byteAt 5 40 .|. byteAt 4 32
          .|. byteAt 3 24 .|. byteAt 2 16 .|. byteAt 1  8
          .|. fromIntegral (s `B.index` 0)
-- | Read a machine-sized 'Word' in host byte order.
getWordhost :: Get Word
getWordhost = getPtr width
  where
    width = sizeOf (undefined :: Word)
-- | Read a 'Word16' in host byte order.
getWord16host :: Get Word16
getWord16host = getPtr width
  where
    width = sizeOf (undefined :: Word16)
-- | Read a 'Word32' in host byte order.
getWord32host :: Get Word32
getWord32host = getPtr width
  where
    width = sizeOf (undefined :: Word32)
-- | Read a 'Word64' in host byte order.
getWord64host :: Get Word64
getWord64host = getPtr width
  where
    width = sizeOf (undefined :: Word64)
-- | Left shifts used by the fixed-endian word readers. On the GHC path the
-- shift amount is not range-checked; every caller in this module passes an
-- amount smaller than the word width.
shiftl_w16 :: Word16 -> Int -> Word16
shiftl_w32 :: Word32 -> Int -> Word32
shiftl_w64 :: Word64 -> Int -> Word64
#if defined(__GLASGOW_HASKELL__) && !defined(__HADDOCK__)
-- GHC: shift the unboxed words with the unchecked primops, presumably to
-- avoid the range check of the generic shiftL.
-- NOTE(review): on GHC 9.2 and later W16# and W32# wrap sized primitive
-- types (Word16#, Word32#), so these equations likely need updating there.
shiftl_w16 (W16# w) (I# i) = W16# (w `uncheckedShiftL#`   i)
shiftl_w32 (W32# w) (I# i) = W32# (w `uncheckedShiftL#`   i)
#if WORD_SIZE_IN_BITS < 64
-- Targets with words narrower than 64 bits need the 64-bit shift primop.
shiftl_w64 (W64# w) (I# i) = W64# (w `uncheckedShiftL64#` i)
#if __GLASGOW_HASKELL__ <= 606
-- GHC 6.6 and earlier: the 64-bit shift is imported as a foreign call.
foreign import ccall unsafe "stg_uncheckedShiftL64"
    uncheckedShiftL64#     :: Word64# -> Int# -> Word64#
#endif
#else
shiftl_w64 (W64# w) (I# i) = W64# (w `uncheckedShiftL#` i)
#endif
#else
-- Other compilers and Haddock: fall back to the portable Data.Bits shift.
shiftl_w16 = shiftL
shiftl_w32 = shiftL
shiftl_w64 = shiftL
#endif
-- | Run two parsers in sequence and pair their results.
getTwoOf :: Get a -> Get b -> Get (a,b)
getTwoOf ma mb = do
  a <- ma
  b <- mb
  return (a, b)
-- | Read a big-endian 'Word64' element count, then that many elements.
-- Each element is forced to weak head normal form as it is read.
getListOf :: Get a -> Get [a]
getListOf m = go [] =<< getWord64be
  where
  go acc 0 = return (reverse acc)
  go acc i = do x <- m
                x `seq` go (x:acc) (i - 1)
-- | Read an immutable array: the (lower, upper) bounds, then the elements
-- as a counted list.
getIArrayOf :: (Ix i, IArray a e) => Get i -> Get e -> Get (a i e)
getIArrayOf ix e = do
  bnds <- getTwoOf ix ix
  xs   <- getListOf e
  return (listArray bnds xs)
-- | Read a big-endian 'Word64' element count, then that many elements into
-- a 'Seq.Seq', appending each one at the right end.
getSeqOf :: Get a -> Get (Seq.Seq a)
getSeqOf m = go Seq.empty =<< getWord64be
  where
  go xs 0 = return $! xs
  -- Force the accumulator and counter so no thunk chain builds up.
  go xs n = xs `seq` n `seq` do
              x <- m
              go (xs Seq.|> x) (n - 1)
-- | Read a rose tree: the root label, then its subtrees as a counted list.
getTreeOf :: Get a -> Get (T.Tree a)
getTreeOf m = do
  root     <- m
  children <- getListOf (getTreeOf m)
  return (T.Node root children)
-- | Read a 'Map.Map' as a counted list of key/value pairs.
--
-- Uses 'Map.fromList' rather than 'Map.fromDistinctAscList'. Serialized
-- input cannot be trusted to be sorted and duplicate-free, and violating
-- that precondition yields a map whose lookups give wrong answers. For
-- input that is already ascending, fromList still builds in linear time.
getMapOf :: Ord k => Get k -> Get a -> Get (Map.Map k a)
getMapOf k m = Map.fromList `fmap` getListOf (getTwoOf k m)
-- | Read an 'IntMap.IntMap' as a counted list of key/value pairs.
--
-- Uses 'IntMap.fromList' so that unsorted or duplicated keys in untrusted
-- input still produce a well-formed map; for duplicates the last value wins.
getIntMapOf :: Get Int -> Get a -> Get (IntMap.IntMap a)
getIntMapOf i m = IntMap.fromList `fmap` getListOf (getTwoOf i m)
-- | Read a 'Set.Set' as a counted list of elements.
--
-- Uses 'Set.fromList' so that unsorted or duplicated elements in untrusted
-- input still produce a valid set; already-ascending input stays linear.
getSetOf :: Ord a => Get a -> Get (Set.Set a)
getSetOf m = Set.fromList `fmap` getListOf m
-- | Read an 'IntSet.IntSet' as a counted list of elements.
--
-- Uses 'IntSet.fromList' so that unsorted or duplicated elements in
-- untrusted input still produce a well-formed set.
getIntSetOf :: Get Int -> Get IntSet.IntSet
getIntSetOf m = IntSet.fromList `fmap` getListOf m
-- | Read a 'Maybe': a tag byte of 0 means 'Nothing'; any other tag is
-- followed by the payload of a 'Just'.
getMaybeOf :: Get a -> Get (Maybe a)
getMaybeOf m = do
  tag <- getWord8
  if tag == 0 then return Nothing else fmap Just m
-- | Read an 'Either': a tag byte of 0 selects the 'Left' parser; any other
-- tag selects the 'Right' parser.
getEitherOf :: Get a -> Get b -> Get (Either a b)
getEitherOf ma mb = do
  tag <- getWord8
  if tag == 0 then fmap Left ma else fmap Right mb