-----------------------------------------------------------------------------
-- |
-- Module      : Codec.ByteString.Parser
-- Copyright   : Lennart Kolmodin, George Giorgidze
-- License     : BSD3
--
-- Maintainer  : George Giorgidze <http://cs.nott.ac.uk/~ggg/>
-- Stability   : experimental
-- Portability : Portable
--
-- A monad for efficiently building structures from
-- encoded lazy ByteStrings.
--
-----------------------------------------------------------------------------

module Codec.ByteString.Parser (

    -- * The Parser type
      Parser
    , runParser
    , runParserState

    -- * Parsing
    , choice
    , expect
    , skip
    , lookAhead
    , lookAheadM
    , lookAheadE

    -- * Utility
    , bytesRead
    , getBytes
    , remaining
    , isEmpty

    -- * Parsing particular types
    , satisfy
    , getString
    , getStringNul
    , string
    , getWord8
    , getInt8
    , word8
    , int8

    -- ** ByteStrings
    , getByteString
    , getLazyByteString
    , getLazyByteStringNul
    , getRemainingLazyByteString

    -- ** Big-endian reads
    , getWord16be
    , word16be
    , getWord24be
    , word24be
    , getWord32be
    , word32be
    , getWord64be
    , word64be

    , getInt16be
    , int16be
    , getInt32be
    , int32be
    , getInt64be
    , int64be

    -- ** Little-endian reads
    , getWord16le
    , word16le
    , getWord24le
    , word24le
    , getWord32le
    , word32le
    , getWord64le
    , word64le

    , getInt16le
    , int16le
    , getInt32le
    , int32le
    , getInt64le
    , int64le

    -- ** Host-endian, unaligned reads
    , getWordHost
    , wordHost
    , getWord16host
    , word16host
    , getWord32host
    , word32host
    , getWord64host
    , word64host

    -- ** Variable length reads
    , getVarLenBe
    , varLenBe
    , getVarLenLe
    , varLenLe
  ) where

import qualified Data.ByteString               as B
import qualified Data.ByteString.Lazy          as L
import qualified Data.ByteString.Internal      as B
import qualified Data.ByteString.Lazy.Internal as L

import Foreign.Storable        (Storable, peek, sizeOf)
import Foreign.Ptr             (plusPtr, castPtr)
import Foreign.ForeignPtr      (withForeignPtr)
import Control.Monad.ST        (runST)
import Control.Monad.ST.Unsafe (unsafeInterleaveST)

import Control.Monad
import Control.Applicative
import Data.STRef
import Data.Word
import Data.Int
import Data.Bits
import Data.Maybe

import System.IO.Unsafe        (unsafePerformIO)

-- | The parse state threaded through every 'Parser'.
data S = S {-# UNPACK #-} !B.ByteString  -- current chunk: the not-yet-consumed part of the chunk being parsed
           L.ByteString                  -- the rest of the input (a lazy field, so later chunks are not forced when an S is built)
           {-# UNPACK #-} !Int64         -- bytes read: running byte offset, starting from the value passed to 'mkState'

-- | A parser is a state-passing function over the parse state 'S': it either
-- fails with an error message ('Left') or yields a result together with the
-- updated state ('Right'). A failure carries no state, so an alternative
-- tried after a failure restarts from the state the failed branch started
-- with.
newtype Parser a = Parser { unParser :: S -> Either String (a, S) }

-- | Apply a function to the result of a successful parse; failures and the
-- resulting state pass through untouched.
instance Functor Parser where
    fmap f p = Parser $ either Left (\(x, st) -> Right (f x, st)) . unParser p
    
-- | Sequencing threads the parse state from one step to the next and stops
-- at the first failure. 'fail' appends the current byte offset to the
-- message.
--
-- NOTE(review): in GHC >= 8.8 'fail' is no longer a 'Monad' method (it lives
-- in 'MonadFail'); confirm which compilers this module targets.
instance Monad Parser where
    return x = Parser $ \st -> Right (x, st)
    p >>= f  = Parser $ \st ->
        either Left (\(x, st') -> unParser (f x) st') (unParser p st)
    fail msg = Parser $ \(S _ _ pos) ->
        Left (msg ++ ". Failed reading at byte position " ++ show pos)
-- | 'mzero' fails with an empty message. 'mplus' runs the second parser on
-- the /original/ state when the first fails (full backtracking); if both
-- fail, their error messages are joined with a newline.
instance MonadPlus Parser where
  mzero = Parser (const (Left ""))
  mplus p1 p2 = Parser $ \s ->
    case unParser p1 s of
      Right ok -> Right ok
      Left e1  -> case unParser p2 s of
        Right ok -> Right ok
        Left e2  -> Left (e1 ++ "\n" ++ e2)

-- | Applicative operations are the monadic ones: 'pure' is 'return' and
-- '<*>' runs the function parser first, then the argument parser.
instance Applicative Parser where
  pure x    = return x
  pf <*> px = do
    f <- pf
    x <- px
    return (f x)
  
-- | Choice is the 'MonadPlus' choice: backtracking '<|>' and a failing
-- 'empty'.
instance Alternative Parser where
  empty   = mzero
  p <|> q = mplus p q

------------------------------------------------------------------------

-- | Return the current parse state without changing it.
get :: Parser S
get = Parser (\st -> Right (st, st))

-- | Replace the current parse state.
put :: S -> Parser ()
put st = Parser (const (Right ((), st)))

------------------------------------------------------------------------

-- | The state for parsing a whole input from byte offset 0.
initState :: L.ByteString -> S
initState = flip mkState 0

-- | Build a parse state from a lazy input and a starting byte offset. The
-- first chunk (if any) becomes the current chunk; the rest stays lazy.
mkState :: L.ByteString -> Int64 -> S
mkState L.Empty          = S B.empty L.empty
mkState (L.Chunk hd tl)  = S hd tl

-- | Run a parser over a lazy ByteString, starting at byte offset 0. Returns
-- the error message on failure, or the parsed value on success; any
-- unconsumed input is discarded.
runParser :: Parser a -> L.ByteString -> Either String a
runParser m str = either Left (\(a, _) -> Right a) (unParser m (initState str))

-- | Run a parser over a lazy ByteString whose first byte is at offset @off@.
-- On success, returns the parsed value, the unconsumed rest of the input,
-- and the final byte offset (@off@ plus the number of bytes consumed).
runParserState :: Parser a -> L.ByteString -> Int64 -> Either String (a, L.ByteString, Int64)
runParserState m str off =
    either Left finish (unParser m (mkState str off))
  where
    -- The final state is only inspected when the caller demands the
    -- leftover input or the offset.
    finish (a, st) = Right (a, leftover st, offset st)
    leftover (S s ss _) = s `bsJoin` ss
    offset (S _ _ o) = o

------------------------------------------------------------------------

-- | Try each parser in order on the same input, returning the result of the
-- first that succeeds. If every alternative fails, their error messages are
-- joined with newlines; @choice []@ is 'mzero'.
--
-- A right fold is used so that no 'mzero' seed contributes an empty
-- message (and a leading newline) to the combined error. It also means
-- later alternatives are only built when earlier ones fail.
choice :: [Parser a] -> Parser a
choice [] = mzero
choice ps = foldr1 (<|>) ps

-- | Skip ahead @n@ bytes. Fails if fewer than @n@ bytes are available.
--
-- Counts that do not fit in an 'Int' are rejected up front. Converting them
-- would wrap to a negative 'Int', and the skip would then "succeed" without
-- consuming anything.
skip :: Word64 -> Parser ()
skip n
    | n > fromIntegral (maxBound :: Int) = fail "too few bytes"
    | otherwise                          = readN (fromIntegral n) (const ())

-- | Run @ga@ and return its result, then restore the state it started from
-- so no input is consumed. Fails if @ga@ fails.
lookAhead :: Parser a -> Parser a
lookAhead ga =
    get >>= \saved ->
    ga >>= \result ->
    put saved >> return result

-- | Like 'lookAhead', but keep the consumed input when @gma@ returns
-- @'Just' _@; on 'Nothing' the starting state is restored.
-- Fails if @gma@ fails.
lookAheadM :: Parser (Maybe a) -> Parser (Maybe a)
lookAheadM gma = do
    saved  <- get
    result <- gma
    maybe (put saved) (const (return ())) result
    return result

-- | Like 'lookAhead', but keep the consumed input when @gea@ returns
-- @'Right' _@; on @'Left' _@ the starting state is restored.
-- Fails if @gea@ fails.
lookAheadE :: Parser (Either a b) -> Parser (Either a b)
lookAheadE gea = do
    saved  <- get
    result <- gea
    either (const (put saved)) (const (return ())) result
    return result

-- | Run @p@ and accept its result only if @f@ holds for it. Otherwise fail
-- with a message naming the value. Because the check happens after @p@ has
-- run, the reported byte position is the one after @p@'s input.
expect :: (Show a, Eq a) => (a -> Bool) -> Parser a -> Parser a
expect f p = p >>= \v ->
    if f v
        then return v
        else fail (show v ++ " was not expected.")

-- | Read up to @l@ bytes (via 'getLazyByteString', so short input does not
-- fail) and turn each byte into one 'Char' with 'B.w2c'.
getString :: Int -> Parser String
getString l = do
  raw <- getLazyByteString (fromIntegral l)
  let chars = map B.w2c (L.unpack raw)
  return $! chars

-- | Read a NUL-terminated string (the NUL is consumed but not returned),
-- turning each byte into one 'Char' with 'B.w2c'. Fails if no NUL is found.
getStringNul :: Parser String
getStringNul = do
  raw <- getLazyByteStringNul
  let chars = map B.w2c (L.unpack raw)
  return $! chars

-- | Match the literal @s@: read @length s@ bytes as characters and fail
-- unless they equal @s@.
string :: String -> Parser String
string s = expect (== s) $ getString (length s)

-- Utility

-- | The current byte offset: the starting offset plus every byte consumed
-- so far.
bytesRead :: Parser Int64
bytesRead = get >>= \(S _ _ offset) -> return offset

-- | Number of input bytes not yet consumed. Useful for checking whether all
-- input has been used. Note that this forces the whole rest of the input.
remaining :: Parser Int64
remaining = get >>= \(S chunk rest _) ->
    let total = fromIntegral (B.length chunk) + L.length rest
    in return $! total

-- | 'True' when no unconsumed input is left (both the current chunk and the
-- remaining chunks are empty).
isEmpty :: Parser Bool
isEmpty = get >>= \(S chunk rest _) -> return $! B.null chunk && L.null rest

------------------------------------------------------------------------
-- Utility with ByteStrings

-- | Read exactly @n@ bytes as a strict ByteString. Fails if fewer than @n@
-- bytes are left in the input.
getByteString :: Int -> Parser B.ByteString
getByteString = flip readN id

-- | An efficient 'get' method for lazy ByteStrings. Does not fail if fewer
-- than @n@ bytes are left in the input: whatever is available (possibly
-- nothing) is returned. A non-positive @n@ returns an empty ByteString.
--
-- The offset reported by 'bytesRead' advances by the number of bytes
-- actually returned, not by @n@.
getLazyByteString :: Int64 -> Parser L.ByteString
getLazyByteString n = do
    S s ss bytes <- get
    let big = s `bsJoin` ss
    case splitAtST n big of
      (consume, rest) -> do
        -- Use the real length: the input may be shorter than @n@, and a
        -- negative @n@ must not move the offset backwards. This forces the
        -- chunks of @consume@ once the new state is inspected.
        put $ mkState rest (bytes + L.length consume)
        return consume

-- | Read bytes up to (not including) the next NUL byte and consume the NUL
-- as well. Fails, without consuming anything, if the input ends before a
-- NUL is found.
getLazyByteStringNul :: Parser L.ByteString
getLazyByteStringNul = do
    S chunk more pos <- get
    let (body, fromNul) = L.break (== 0) (chunk `bsJoin` more)
    if L.null fromNul
        then fail "too few bytes"
        else do
            -- skip the terminating NUL itself
            put $ mkState (L.drop 1 fromNul) (pos + L.length body + 1)
            return body

-- | Return all unconsumed input as a lazy ByteString. The parse state is
-- not modified, so the bytes are still available to later parsers and
-- 'bytesRead' is unchanged.
getRemainingLazyByteString :: Parser L.ByteString
getRemainingLazyByteString = get >>= \(S chunk rest _) -> return $! chunk `bsJoin` rest

------------------------------------------------------------------------
-- Helpers

-- | Pull @n@ bytes from the input as a strict ByteString. Fails with
-- "too few bytes" (reporting the offset where the read started) if fewer
-- than @n@ bytes remain. A non-positive @n@ returns an empty ByteString and
-- consumes nothing.
getBytes :: Int -> Parser B.ByteString
getBytes n
    -- Negative counts would otherwise move the byte offset backwards.
    | n <= 0 = return B.empty
    | otherwise = do
        S s ss bytes <- get
        if n <= B.length s
            then do let (consume, rest) = B.splitAt n s
                    put $! S rest ss (bytes + fromIntegral n)
                    return $! consume
            else
              case L.splitAt (fromIntegral n) (s `bsJoin` ss) of
                (consuming, rest) -> do
                    -- Spans several chunks: copy them into one strict chunk.
                    let now = B.concat . L.toChunks $ consuming
                    -- Check before updating the state, so that 'fail'
                    -- reports the offset at which this read started.
                    when (B.length now < n) $ fail "too few bytes"
                    -- forces the next chunk before this one is returned
                    put $! mkState rest (bytes + fromIntegral n)
                    return now

-- | Prepend a strict chunk to a lazy ByteString. An empty chunk is dropped
-- rather than wrapped, so no empty 'L.Chunk' is ever created.
bsJoin :: B.ByteString -> L.ByteString -> L.ByteString
bsJoin bb lb = if B.null bb then lb else L.Chunk bb lb

-- | Split a lazy ByteString after @i@ bytes. If the first result is consumed
-- before the second, this runs in constant heap space.
--
-- You must force the returned tuple for that to work, e.g.
--
-- > case splitAtST n xs of
-- >    (ys,zs) -> consume ys ... consume zs
--
-- A non-positive @i@ yields @(L.empty, ps)@. If @ps@ has fewer than @i@
-- bytes, the first component is all of @ps@ and the second is empty.
splitAtST :: Int64 -> L.ByteString -> (L.ByteString, L.ByteString)
splitAtST i ps | i <= 0 = (L.empty, ps)
splitAtST i ps          = runST (
     do r  <- newSTRef undefined   -- placeholder; 'first' writes r before returning
        xs <- first r i ps
        ys <- unsafeInterleaveST (readSTRef r)
        return (xs, ys))

  where
        -- 'first r n xs' returns the first @n@ bytes of @xs@, producing
        -- later chunks lazily, and stores in @r@ the suffix that follows.
        -- Each step writes the suffix (@L.drop (n - l) xs@) before deferring
        -- the recursion. So reading @r@ early still gives the right suffix;
        -- later steps only refine it.
        first r 0 xs@(L.Chunk _ _) = writeSTRef r xs    >> return L.Empty
        first r _ L.Empty          = writeSTRef r L.Empty >> return L.Empty

        first r n (L.Chunk x xs)
          -- the split point falls inside this chunk
          | n < l     = do writeSTRef r (L.Chunk (B.drop (fromIntegral n) x) xs)
                           return $! L.Chunk (B.take (fromIntegral n) x) L.Empty
          -- take this whole chunk and continue lazily with the remainder
          | otherwise = do writeSTRef r (L.drop (n - l) xs)
                           liftM (L.Chunk x) $ unsafeInterleaveST (first r (n - l) xs)

         where l = fromIntegral (B.length x)

-- | Read @n@ bytes with 'getBytes' and convert them with @f@. Fails (via
-- 'getBytes') when fewer than @n@ bytes are available.
readN :: Int -> (B.ByteString -> a) -> Parser a
readN n f = f <$> getBytes n


------------------------------------------------------------------------
-- Primitives

-- | Helper: read @n@ bytes with 'readN' and 'peek' a 'Storable' value from
-- the start of the resulting strict ByteString.
--
-- The caller must pass an @n@ equal to the 'sizeOf' of the result type;
-- nothing here checks it. 'unsafePerformIO' is used only to read the
-- ByteString's memory, which is not mutated.
-- NOTE(review): the peek happens at the slice's offset and may be
-- unaligned; confirm that is acceptable on all target platforms.
getPtr :: Storable a => Int -> Parser a
getPtr n = readN n B.toForeignPtr >>= \(fptr, off, _) ->
    return $ unsafePerformIO $
        withForeignPtr fptr $ \base -> peek (castPtr (base `plusPtr` off))

------------------------------------------------------------------------

-- | Read one byte and succeed only if it satisfies @f@.
--
-- A mismatch fails through 'expect', like the other byte matchers. The
-- error therefore names the offending byte and the byte position; a bare
-- 'guard' would fail with an empty message.
satisfy :: (Word8 -> Bool) -> Parser Word8
satisfy f = expect f getWord8

-- | Read a single byte.
getWord8 :: Parser Word8
getWord8 = getPtr (sizeOf (0 :: Word8))

-- | Read a byte and require it to equal @w@.
word8 :: Word8 -> Parser Word8
word8 w = expect (== w) getWord8

-- | Read a Word16 in big endian format (most significant byte first).
getWord16be :: Parser Word16
getWord16be = do
    raw <- readN 2 id
    return $! B.foldl' (\acc b -> (acc `shiftL` 8) .|. fromIntegral b) 0 raw

-- | Read a big endian Word16 and require it to equal @w@.
word16be :: Word16 -> Parser Word16
word16be w = expect (== w) getWord16be

-- | Read a Word16 in little endian format (least significant byte first).
getWord16le :: Parser Word16
getWord16le = do
    raw <- readN 2 id
    -- folding from the right puts the last byte in the highest position
    return $! B.foldr (\b acc -> (acc `shiftL` 8) .|. fromIntegral b) 0 raw

-- | Read a little endian Word16 and require it to equal @w@.
word16le :: Word16 -> Parser Word16
word16le w = expect (== w) getWord16le

-- | Read a 3-byte (24 bit) big endian unsigned value into a Word32.
getWord24be :: Parser Word32
getWord24be = do
    raw <- readN 3 id
    return $! B.foldl' (\acc b -> (acc `shiftL` 8) .|. fromIntegral b) 0 raw

-- | Read a big endian 24 bit value and require it to equal @w@.
word24be :: Word32 -> Parser Word32
word24be w = expect (== w) getWord24be

-- | Read a 3-byte (24 bit) little endian unsigned value into a Word32.
getWord24le :: Parser Word32
getWord24le = do
    raw <- readN 3 id
    -- folding from the right puts the last byte in the highest position
    return $! B.foldr (\b acc -> (acc `shiftL` 8) .|. fromIntegral b) 0 raw

-- | Read a little endian 24 bit value and require it to equal @w@.
word24le :: Word32 -> Parser Word32
word24le w = expect (== w) getWord24le

-- | Read a Word32 in big endian format (most significant byte first).
getWord32be :: Parser Word32
getWord32be = do
    raw <- readN 4 id
    return $! B.foldl' (\acc b -> (acc `shiftL` 8) .|. fromIntegral b) 0 raw

-- | Read a big endian Word32 and require it to equal @w@.
word32be :: Word32 -> Parser Word32
word32be w = expect (== w) getWord32be

-- | Read a Word32 in little endian format (least significant byte first).
getWord32le :: Parser Word32
getWord32le = do
    raw <- readN 4 id
    -- folding from the right puts the last byte in the highest position
    return $! B.foldr (\b acc -> (acc `shiftL` 8) .|. fromIntegral b) 0 raw

-- | Read a little endian Word32 and require it to equal @w@.
word32le :: Word32 -> Parser Word32
word32le w = expect (== w) getWord32le


-- | Read a Word64 in big endian format (most significant byte first).
getWord64be :: Parser Word64
getWord64be = do
    raw <- readN 8 id
    return $! B.foldl' (\acc b -> (acc `shiftL` 8) .|. fromIntegral b) 0 raw

-- | Read a big endian Word64 and require it to equal @w@.
word64be :: Word64 -> Parser Word64
word64be w = expect (== w) getWord64be

-- | Read a Word64 in little endian format (least significant byte first).
getWord64le :: Parser Word64
getWord64le = do
    raw <- readN 8 id
    -- folding from the right puts the last byte in the highest position
    return $! B.foldr (\b acc -> (acc `shiftL` 8) .|. fromIntegral b) 0 raw

-- | Read a little endian Word64 and require it to equal @w@.
word64le :: Word64 -> Parser Word64
word64le w = expect (== w) getWord64le
------------------------------------------------------------------------
-- | Read one byte and reinterpret it as a signed 8 bit integer.
getInt8 :: Parser Int8
getInt8 = fromIntegral <$> getWord8

-- | Read a signed byte and require it to equal @i@.
int8 :: Int8 -> Parser Int8
int8 i = expect (== i) getInt8

-- | Read a little endian Word16 and reinterpret it as an Int16.
getInt16le :: Parser Int16
getInt16le = fromIntegral <$> getWord16le

-- | Read a little endian Int16 and require it to equal @i@.
int16le :: Int16 -> Parser Int16
int16le i = expect (== i) getInt16le

-- | Read a big endian Word16 and reinterpret it as an Int16.
getInt16be :: Parser Int16
getInt16be = fromIntegral <$> getWord16be

-- | Read a big endian Int16 and require it to equal @i@.
int16be :: Int16 -> Parser Int16
int16be i = expect (== i) getInt16be

-- | Read a little endian Word32 and reinterpret it as an Int32.
getInt32le :: Parser Int32
getInt32le = fromIntegral <$> getWord32le

-- | Read a little endian Int32 and require it to equal @i@.
int32le :: Int32 -> Parser Int32
int32le i = expect (== i) getInt32le

-- | Read a big endian Word32 and reinterpret it as an Int32.
getInt32be :: Parser Int32
getInt32be = fromIntegral <$> getWord32be

-- | Read a big endian Int32 and require it to equal @i@.
int32be :: Int32 -> Parser Int32
int32be i = expect (== i) getInt32be

-- | Read a little endian Word64 and reinterpret it as an Int64.
getInt64le :: Parser Int64
getInt64le = fromIntegral <$> getWord64le

-- | Read a little endian Int64 and require it to equal @i@.
int64le :: Int64 -> Parser Int64
int64le i = expect (== i) getInt64le

-- | Read a big endian Word64 and reinterpret it as an Int64.
getInt64be :: Parser Int64
getInt64be = fromIntegral <$> getWord64be

-- | Read a big endian Int64 and require it to equal @i@.
int64be :: Int64 -> Parser Int64
int64be i = expect (== i) getInt64be

------------------------------------------------------------------------
-- Host-endian reads

-- | /O(1)./ Read a single native machine word in host byte order. Its size
-- is @sizeOf@ a 'Word' on the current machine (8 bytes on a 64 bit machine,
-- 4 on a 32 bit one).
getWordHost :: Parser Word
getWordHost = getPtr (sizeOf (0 :: Word))

-- | Read a host-order 'Word' and require it to equal @w@.
wordHost :: Word -> Parser Word
wordHost w = expect (== w) getWordHost

-- | /O(1)./ Read a 2 byte Word16 in host byte order.
getWord16host :: Parser Word16
getWord16host = getPtr (sizeOf (0 :: Word16))

-- | Read a host-order Word16 and require it to equal @w@.
word16host :: Word16 -> Parser Word16
word16host w = expect (== w) getWord16host

-- | /O(1)./ Read a 4 byte Word32 in host byte order.
getWord32host :: Parser Word32
getWord32host = getPtr (sizeOf (0 :: Word32))

-- | Read a host-order Word32 and require it to equal @w@.
word32host :: Word32 -> Parser Word32
word32host w = expect (== w) getWord32host

-- | /O(1)./ Read an 8 byte Word64 in host byte order (host endianness).
getWord64host :: Parser Word64
getWord64host = getPtr (sizeOf (0 :: Word64))

-- | Read a host-order Word64 and require it to equal @w@.
word64host :: Word64 -> Parser Word64
word64host w = expect (== w) getWord64host

-- Variable length numbers

-- | Read a big endian base-128 variable-length number. Each byte contributes
-- its low 7 bits, most significant group first. A set bit 7 means another
-- byte follows. Bits shifted past the top of the 'Word64' are silently lost.
getVarLenBe :: Parser Word64
getVarLenBe = go 0
  where
    go :: Word64 -> Parser Word64
    go acc = do
      byte <- fromIntegral <$> getWord8
      let acc' = (acc `shiftL` 7) .|. clearBit byte 7
      if testBit byte 7
        then go $! acc'
        else return $! acc'

-- | Read a big endian variable-length number and require it to equal @a@.
varLenBe :: Word64 -> Parser Word64
varLenBe a = expect (== a) getVarLenBe

-- | Read a little endian base-128 variable-length number. Each byte
-- contributes its low 7 bits, least significant group first. A set bit 7
-- means another byte follows.
getVarLenLe :: Parser Word64
getVarLenLe = do
  byte <- fromIntegral <$> getWord8
  if not (testBit byte 7)
    then return $! byte
    else do
      higher <- getVarLenLe
      return $! clearBit byte 7 .|. (higher `shiftL` 7)

-- | Read a little endian variable-length number and require it to equal @a@.
varLenLe :: Word64 -> Parser Word64
varLenLe a = expect (== a) getVarLenLe