{-# LANGUAGE CPP, RankNTypes, MagicHash, BangPatterns, TypeFamilies #-}

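-- CPP: C-style preprocessing, used by the #if MIN_VERSION_base(...) lines
-- RankNTypes: the forall r. in the definition of Get
-- BangPatterns: the ! strictness annotations on function arguments
-- MagicHash: allows GHC primitive names ending in #; note that the (# ... #)
-- unboxed-tuple syntax is enabled by UnboxedTuples, not MagicHash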

module Data.Binary.Get.Ext.Internal (

    -- * The Get e type
      Get
    , runCont
    , Decoder(..)
    , runGetIncremental

    , readN
    , readNWith

    -- * Parsing
    , bytesRead
    , totalBytesRead
    , isolate

    -- * With input chunks
    , withInputChunks
    , Consume
    , failOnEOF

    , get
    , put
    , ensureN

    -- * Utility
    , isEmpty
    , failG
    , lookAhead
    , lookAheadM
    , lookAheadE
    , label
    , onError
    , withError

    -- ** ByteStrings
    , getByteString

    ) where

import Foreign
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B

import Control.Applicative
import Control.Monad
#if MIN_VERSION_base(4,9,0)
import qualified Control.Monad.Fail as Fail
#endif

import Data.Binary.Internal ( accursedUnutterablePerformIO )

-- Kolmodin 20100427: at zurihac we discussed having Partial take a
-- "Maybe ByteString" and implemented it in this way.
-- The reasoning was that you could accidentally provide an empty bytestring,
-- and it should not terminate the decoding (empty would mean eof).
-- However, I'd say that it's also a risk that you get stuck in a loop,
-- where you keep providing an empty string. Anyway, no new input should be
-- rare, as the RTS should only wake you up if you actually have some data
-- to read from your fd.

-- | A decoder produced by running a 'Get' monad.
--
-- @e@ is the user-defined error type: failures raised with 'failG' are
-- reported as @Right e@, while failures raised with 'fail' carry a 'String'
-- message as @Left@.
data Decoder e a = Fail !B.ByteString (Either String e)
              -- ^ The decoder ran into an error, either through 'fail'
              -- (a @Left@ message), through 'failG' (a @Right@ error), or
              -- because it was not provided enough input. The 'B.ByteString'
              -- is the input that was still unconsumed when decoding stopped.
              | Partial (Maybe B.ByteString -> Decoder e a)
              -- ^ The decoder has consumed the available input and needs
              -- more to continue. Provide 'Just' if more input is available
              -- and 'Nothing' otherwise, and you will get a new 'Decoder'.
              | Done !B.ByteString a
              -- ^ The decoder has successfully finished. Except for the
              -- output value you also get the unused input.
              | BytesRead {-# UNPACK #-} !Int64 (Int64 -> Decoder e a)
              -- ^ The decoder needs to know the current position in the input.
              -- Given the number of bytes remaining in the decoder, the outer
              -- decoder runner needs to calculate the position and
              -- resume the decoding.

-- | The decoding monad: an unrolled codensity/state monad.
--
-- A computation receives a base offset (which 'totalBytesRead' adds to the
-- relative 'bytesRead' position), the current input chunk, and a success
-- continuation, and produces a 'Decoder'. @e@ is the error type reported
-- through 'failG'.
newtype Get e a = C { runCont :: forall r. Int64 -> B.ByteString -> Success e a r -> Decoder e r }

-- | Success continuation: receives the remaining input and the produced value.
type Success e a r = B.ByteString -> a -> Decoder e r

-- | Sequencing for 'Get'; '>>=' is 'bindG'. Failures raised with 'fail' are
-- reported as a @Left@ message in the resulting 'Fail' (see 'failG_').
--
-- @fail@ stopped being a method of 'Monad' in base-4.13 (GHC 8.8), so it may
-- only be defined in this instance for older bases. From base-4.9 on, the
-- 'Fail.MonadFail' instance below provides it.
instance Monad (Get e) where
  return = pure
  (>>=) = bindG
#if !MIN_VERSION_base(4,9,0)
  fail = failG_
#elif !MIN_VERSION_base(4,13,0)
  fail = Fail.fail
#endif

#if MIN_VERSION_base(4,9,0)
instance Fail.MonadFail (Get e) where
  fail = failG_
#endif

-- | Monadic bind for 'Get': run @m@, then run the computation chosen by @f@
-- on the produced value and leftover input, with the same base offset.
bindG :: Get e a -> (a -> Get e b) -> Get e b
bindG m f = C $ \base inp ks ->
  runCont m base inp $ \inp' x -> runCont (f x) base inp' ks
{-# INLINE bindG #-}

-- | Abort decoding with a textual message, reported as @Left msg@.
-- The current input is kept in the resulting 'Fail'.
failG_ :: String -> Get e a
failG_ msg = C (\_ inp _ -> Fail inp (Left msg))

-- | Abort decoding with a user-defined error, reported as @Right err@.
-- The current input is kept in the resulting 'Fail'.
failG :: e -> Get e a
failG err = C (\_ inp _ -> Fail inp (Right err))

-- | Applicative application for 'Get': run the function decoder, then the
-- argument decoder, and apply. It is only inlined in phase 0, so the
-- @readN/readN merge@ rewrite rule can still see calls to it.
apG :: Get e (a -> b) -> Get e a -> Get e b
apG d e = d >>= \f -> e >>= \x -> return (f x)
{-# INLINE [0] apG #-}

-- | 'fmap' for 'Get': apply @f@ to the result before handing it to the
-- success continuation.
fmapG :: (a -> b) -> Get e a -> Get e b
fmapG f m = C $ \base inp ks ->
  let ks' inp' x = ks inp' (f x)
  in runCont m base inp ks'
{-# INLINE fmapG #-}

-- | @pure v@ succeeds with @v@ without touching the input; '<*>' is 'apG'.
instance Applicative (Get e) where
  pure = \v -> C (\_ inp ks -> ks inp v)
  {-# INLINE [0] pure #-}
  (<*>) = apG
  {-# INLINE (<*>) #-}

-- | 'mzero' and 'mplus' coincide with 'empty' and '<|>'.
instance MonadPlus (Get e) where
  mzero = empty
  mplus a b = a <|> b

-- | Maps over the decoded value; delegates to 'fmapG'.
instance Functor (Get e) where
  fmap f m = fmapG f m

-- | Maps over the result of a finished decoder, threading the function
-- through pending 'Partial' and 'BytesRead' continuations.
instance Functor (Decoder e) where
  fmap f d = case d of
    Fail rest err -> Fail rest err
    Done rest x -> Done rest (f x)
    Partial k -> Partial (\mi -> fmap f (k mi))
    BytesRead n k -> BytesRead n (\pos -> fmap f (k pos))

-- | Debug rendering: shows the error or the result, but neither leftover
-- input nor continuations.
instance (Show e, Show a) => Show (Decoder e a) where
  show d = case d of
    Fail _ err -> "Fail: " ++ show err
    Partial _ -> "Partial _"
    Done _ x -> "Done: " ++ show x
    BytesRead _ _ -> "BytesRead"

-- | Run a 'Get' monad. See 'Decoder' for what to do next, like providing
-- input, handling decoding errors and to get the output value.
--
-- The 'Int64' argument is the base offset that 'totalBytesRead' adds to the
-- relative 'bytesRead' position. Decoding starts with an empty input chunk.
runGetIncremental :: Int64 -> Get e a -> Decoder e a
runGetIncremental base g = noMeansNo (runCont g base B.empty Done)

-- | Make sure we don't have to pass Nothing to a Partial twice.
-- This way we don't need to pass around an EOF value in the Get monad, it
-- can safely ask several times if it needs to.
--
-- Once the caller answers a 'Partial' with 'Nothing', every later 'Partial'
-- is answered with 'Nothing' internally instead of being surfaced again.
noMeansNo :: Decoder e a -> Decoder e a
noMeansNo = go
  where
    go (Partial k) = Partial (feed k)
    go (BytesRead n k) = BytesRead n (go . k)
    go finished = finished

    feed k chunk@(Just _) = go (k chunk)
    feed k Nothing = neverAgain (k Nothing)

    neverAgain (Partial k) = neverAgain (k Nothing)
    neverAgain (BytesRead n k) = BytesRead n (neverAgain . k)
    neverAgain finished = finished

-- | Ask the caller for more input and prepend @inp@ to whatever chunk
-- arrives before continuing with @ks@. If the caller signals end of input,
-- continue with @kf@.
prompt :: B.ByteString -> Decoder e a -> (B.ByteString -> Decoder e a) -> Decoder e a
prompt inp kf ks = prompt' kf (ks . B.append inp)

-- | Suspend with 'Partial' until the caller either supplies a non-empty chunk
-- (passed to @ks@) or signals end of input with 'Nothing' (continue with
-- @kf@). Empty chunks are ignored and the decoder keeps waiting.
prompt' :: Decoder e a -> (B.ByteString -> Decoder e a) -> Decoder e a
prompt' kf ks = waiting
  where
    waiting = Partial onInput
    onInput Nothing = kf
    onInput (Just s)
      | B.null s = waiting
      | otherwise = ks s
  
-- | Read the base offset passed to the current computation (see
-- 'runGetIncremental' and 'isolate'). The input is left untouched.
getBaseOffset :: Get e Int64
getBaseOffset = C (\base inp ks -> ks inp base)

-- | Get the total number of bytes read to this point: the base offset (as
-- given to 'runGetIncremental' or set up by 'isolate') plus the relative
-- position reported by 'bytesRead'.
totalBytesRead :: Get e Int64
totalBytesRead = (+) <$> getBaseOffset <*> bytesRead

-- | Get the number of bytes consumed so far, relative to the start of the
-- decoder (or of the enclosing 'isolate'). The decoder emits a 'BytesRead'
-- request carrying the length of its unconsumed input. Whoever drives the
-- 'Decoder' computes the position from that length and resumes decoding.
bytesRead :: Get e Int64
bytesRead = C $ \_ inp k ->
  let unconsumed = fromIntegral (B.length inp)
  in BytesRead unconsumed (k inp)

-- | Isolate a decoder to operate with a fixed number of bytes, and fail if
-- fewer bytes were consumed, or more bytes were attempted to be consumed.
-- If the given decoder fails, 'isolate' will also fail.
-- Offset from 'bytesRead' will be relative to the start of 'isolate', not the
-- absolute of the input.
--
-- The isolated decoder is started with a base offset equal to the enclosing
-- base offset plus the current 'bytesRead' position. As a result,
-- 'totalBytesRead' inside it still reports an absolute position.
--
-- Failure modes:
--
-- * A negative size fails via 'fail' (a @Left@ message).
-- * If the isolated decoder finishes having consumed fewer than @n0@ bytes,
--   its unused bytes are pushed back. The number of bytes it did consume is
--   then passed to the error function, and the result is raised with 'failG'.
-- * Once @n0@ bytes have been fed, further input requests are answered with
--   end of input ('Nothing'). Attempts to read more therefore fail inside the
--   isolated decoder.
-- * Any failure of the isolated decoder is re-raised unchanged, after its
--   unconsumed input is pushed back.
isolate :: Int   -- ^ The number of bytes that must be consumed
        -> Get e a -- ^ The decoder to isolate
        -> (Int -> e) -- ^ Builds the error from the number of bytes actually
                      -- consumed, used when fewer bytes were consumed
        -> Get e a
isolate n0 act err
  | n0 < 0 = fail "isolate: negative size"
  | otherwise = do
    ge <- getBaseOffset
    offset <- bytesRead
    -- Run the isolated decoder from empty input, using the absolute position
    -- of this point as its base offset.
    go n0 (runCont act (ge + offset) B.empty Done)
  where
  -- n is the number of bytes that may still be fed to the isolated decoder.
  go !n (Done left x)
    | n == 0 && B.null left = return x
    | otherwise = do
        pushFront left
        let consumed = n0 - n - B.length left
        failG $ err consumed
  -- Budget exhausted: report end of input to the isolated decoder.
  go 0 (Partial resume) = go 0 (resume Nothing)
  -- Feed at most n bytes of the outer input to the isolated decoder,
  -- prompting the caller when the current chunk is empty. The remainder
  -- stays with the outer decoder.
  go n (Partial resume) = do
    inp <- C $ \_ inp k -> do
      -- inp' is the prefix handed to the isolated decoder; out is what the
      -- outer decoder keeps.
      let takeLimited str =
            let (inp', out) = B.splitAt n str in
            k out (Just inp')
      case not (B.null inp) of
        True -> takeLimited inp
        False -> prompt inp (k B.empty Nothing) takeLimited
    case inp of
      Nothing -> go n (resume Nothing)
      Just str -> go (n - B.length str) (resume (Just str))
  -- Propagate failures, restoring the isolated decoder's unconsumed input.
  go _ (Fail bs (Right ferr)) = pushFront bs >> failG ferr
  go _ (Fail bs (Left ferr)) = pushFront bs >> failG_ ferr
  -- Position relative to the start of isolate: the bytes fed so far (n0 - n)
  -- minus those the isolated decoder has not consumed yet (r).
  go n (BytesRead r resume) =
    go n (resume $! fromIntegral n0 - fromIntegral n - r)

-- | A chunk consumer for 'withInputChunks'. Given the current state and the
-- next input chunk, it returns one of two results:
--
-- * @Left s'@ requests more input, continuing with the updated state @s'@.
-- * @Right (want, rest)@ finishes. @want@ is the part of the chunk to keep,
--   and @rest@ becomes the remaining input.
type Consume s = s -> B.ByteString -> Either s (B.ByteString, B.ByteString)

-- | Feed input chunks to @consume@ until it is satisfied.
--
-- Each chunk that does not satisfy @consume@ is accumulated, and more input is
-- requested. On success, @onSucc@ receives the accumulated chunks in arrival
-- order, ending with the wanted part of the last chunk. The rest of that
-- chunk becomes the remaining input.
--
-- If the input ends first, decoding continues with @onFail@ applied to the
-- accumulated chunks, on empty input.
withInputChunks :: s -> Consume s -> ([B.ByteString] -> b) -> ([B.ByteString] -> Get e b) -> Get e b
withInputChunks initS consume onSucc onFail = loop initS []
  where
    loop st seen = C $ \base inp ks ->
      case consume st inp of
        Right (wanted, leftover) ->
          ks leftover (onSucc (reverse (wanted : seen)))
        Left st' ->
          let seen' = inp : seen
              atEnd = runCont (onFail (reverse seen')) base B.empty ks
              more chunk = runCont (loop st' seen') base chunk ks
          in prompt' atEnd more

-- | Fail with @Right ()@, reporting the concatenated chunks as the
-- unconsumed input. Its type fits the @onFail@ argument of 'withInputChunks'.
failOnEOF :: [B.ByteString] -> Get () a
failOnEOF chunks = C (\_ _ _ -> Fail (B.concat chunks) (Right ()))

-- | Test whether all input has been consumed, i.e. there are no remaining
-- undecoded bytes. If the current chunk is empty, the caller is asked for
-- more input, and end of input ('Nothing') yields 'True'. No input is
-- consumed.
isEmpty :: Get e Bool
isEmpty = C $ \_ inp ks ->
  if not (B.null inp)
    then ks inp False
    else prompt inp (ks inp True) (\more -> ks more False)

-- | Backtracking choice. @f <|> g@ runs @f@ while recording the input it
-- reads. If @f@ fails, that input is restored and @g@ runs instead.
instance Alternative (Get e) where
  empty = C $ \_ inp _ -> Fail inp (Left "Data.Binary.Get(Alternative).empty")
  {-# INLINE empty #-}
  f <|> g = runAndKeepTrack f >>= \(result, extra) ->
    case result of
      Done rest x -> C $ \_ _ ks -> ks rest x
      Fail _ _ -> pushBack extra >> g
      _ -> error "Binary: impossible"
  {-# INLINE (<|>) #-}
  some p = (:) <$> p <*> many p
  {-# INLINE some #-}
  many p =
    ((Just <$> p) <|> pure Nothing) >>= \next ->
      case next of
        Nothing -> pure []
        Just x -> (x :) <$> many p
  {-# INLINE many #-}

-- | Run a decoder and keep track of all the input it consumes.
-- Once it's finished, return the final decoder (always 'Done' or 'Fail'),
-- and unconsume all the input the decoder required to run.
-- Any additional chunks which were required to run the decoder
-- will also be returned.
--
-- The continuation is resumed with the input that was current when
-- 'runAndKeepTrack' started. The returned list holds, in arrival order, the
-- extra chunks supplied through 'Partial' while the decoder ran, so callers
-- can restore them with 'pushBack'.
runAndKeepTrack :: Get e a -> Get e (Decoder e a, [B.ByteString])
runAndKeepTrack g = C $ \ge inp ks ->
  let r0 = runCont g ge inp (\inp' a -> Done inp' a)
      -- acc collects the extra chunks, newest first.
      go !acc r = case r of
                    Done inp' a -> ks inp (Done inp' a, reverse acc)
                    Partial k -> Partial $ \minp -> go (maybe acc (:acc) minp) (k minp)
                    Fail inp' s -> ks inp (Fail inp' s, reverse acc)
                    BytesRead unused k -> BytesRead unused (go acc . k)
  in go [] r0
{-# INLINE runAndKeepTrack #-}

-- | Append previously consumed chunks to the end of the current input.
pushBack :: [B.ByteString] -> Get e ()
pushBack chunks = case chunks of
  [] -> C (\_ inp ks -> ks inp ())
  _ -> C (\_ inp ks -> ks (B.concat (inp : chunks)) ())
{-# INLINE pushBack #-}

-- | Prepend a chunk to the current input.
pushFront :: B.ByteString -> Get e ()
pushFront bs = C $ \_ inp ks -> ks (bs `B.append` inp) ()
{-# INLINE pushFront #-}

-- | Run the given decoder, but without consuming its input. If the given
-- decoder fails, then so will this function.
--
-- On success, the original input is restored together with any extra chunks
-- that arrived while the decoder ran.
--
-- /Since: 0.7.0.0/
lookAhead :: Get e a -> Get e a
lookAhead g = runAndKeepTrack g >>= \(result, extra) ->
  case result of
    Done _ x -> pushBack extra >> return x
    Fail rest err -> C (\_ _ _ -> Fail rest err)
    _ -> error "Binary: impossible"

-- | Run the given decoder, and only consume its input if it returns 'Just'.
-- If 'Nothing' is returned, the input will be unconsumed.
-- If the given decoder fails, then so will this function.
--
-- Implemented via 'lookAheadE' by mapping 'Nothing' to @Left ()@.
--
-- /Since: 0.7.0.0/
lookAheadM :: Get e (Maybe a) -> Get e (Maybe a)
lookAheadM g =
  fmap (either (const Nothing) Just) (lookAheadE (fmap (maybe (Left ()) Right) g))

-- | Run the given decoder, and only consume its input if it returns 'Right'.
-- If 'Left' is returned, the input will be unconsumed.
-- If the given decoder fails, then so will this function.
lookAheadE :: Get e (Either a b) -> Get e (Either a b)
lookAheadE g = do
  (result, extra) <- runAndKeepTrack g
  case result of
    Done rest (Right x) -> C (\_ _ ks -> ks rest (Right x))
    Done _ (Left x) -> pushBack extra >> return (Left x)
    Fail rest err -> C (\_ _ _ -> Fail rest err)
    _ -> error "Binary: impossible"

-- | Label a decoder. If the decoder fails with an error raised via 'failG'
-- (reported as @Right@), the label is appended to that error on a new line.
-- Failures raised with 'fail' (reported as @Left@) are passed through
-- unchanged.
label :: String -> Get String a -> Get String a
label msg = onError (++ ('\n' : msg))

-- | Convert decoder error. If the decoder fails with an error of type @e@
-- (reported as @Right@), the given function is applied to it. Failures raised
-- with 'fail' (reported as @Left@) are passed through unchanged. On success,
-- decoding continues normally with the leftover input.
onError :: (e -> e') -> Get e a -> Get e' a
onError f decoder = C $ \base inp ks ->
  let step d = case d of
        Done rest x -> ks rest x
        Partial k -> Partial (step . k)
        Fail rest (Left s) -> Fail rest (Left s)
        Fail rest (Right x) -> Fail rest (Right (f x))
        BytesRead u k -> BytesRead u (step . k)
  in step (runCont decoder base inp Done)
  
-- | Set decoder error. If the decoder fails with its unit error (reported as
-- @Right ()@), the given error is used instead. Failures raised with 'fail'
-- (reported as @Left@) are passed through unchanged.
withError :: Get () a -> e -> Get e a
withError decoder msg = onError replace decoder
  where
    replace _ = msg

------------------------------------------------------------------------
-- ByteStrings
--

-- | An efficient get method for strict ByteStrings. Fails if fewer than @n@
-- bytes are left in the input. If @n <= 0@ then the empty string is returned.
getByteString :: Int -> Get () B.ByteString
getByteString n
  | n <= 0 = return B.empty
  | otherwise = readN n (B.unsafeTake n)
{-# INLINE getByteString #-}

-- | Get the current input chunk without consuming it.
get :: Get e B.ByteString
get = C (\_ chunk k -> k chunk chunk)

-- | Replace the current input chunk with the given one.
put :: B.ByteString -> Get e ()
put chunk = C (\_ _ k -> k chunk ())

-- | Return at least @n@ bytes, maybe more. If not enough data is available
-- the computation will escape with 'Partial'.
--
-- @f@ is applied to the whole current input (at least @n@ bytes long), and
-- its result is forced before the first @n@ bytes are dropped. If the input
-- ends before @n@ bytes are available, decoding fails with @Right ()@.
readN :: Int -> (B.ByteString -> a) -> Get () a
readN !n f = do
  ensureN n
  unsafeReadN n f
{-# INLINE [0] readN #-}

-- Fuse two fixed-size reads combined with 'apG' (i.e. '<*>') into a single
-- 'readN' of the combined size. The first function sees the whole buffer;
-- the second sees it with the first @n@ bytes dropped. The rule can only
-- match while 'apG' and 'readN' are not yet inlined, which is why both are
-- marked INLINE [0].
{-# RULES

"readN/readN merge" forall n m f g.
  apG (readN n f) (readN m g) = readN (n+m) (\bs -> f bs $ g (B.unsafeDrop n bs)) #-}

-- | Ensure that there are at least @n@ bytes available. If not, the
-- computation will escape with 'Partial'.
--
-- If the current chunk is too short, chunks are gathered with
-- 'withInputChunks' until at least @n0@ bytes are available. The gathered
-- chunks are then concatenated and installed as the current input with
-- 'put'. If the input ends first, decoding fails with @Right ()@, keeping
-- the gathered bytes as the unconsumed input.
ensureN :: Int -> Get () ()
ensureN !n0 = C $ \ge inp ks -> do
  if B.length inp >= n0
    then ks inp ()
    else runCont (withInputChunks n0 enoughChunks onSucc onFail >>= put) ge inp ks
  where -- might look a bit funny, but plays very well with GHC's inliner.
        -- GHC won't inline recursive functions, so we make ensureN non-recursive
    -- The state is the number of bytes still missing. A chunk that covers
    -- them is kept whole; otherwise the count is reduced and more input is
    -- requested.
    enoughChunks n str
      | B.length str >= n = Right (str,B.empty)
      | otherwise = Left (n - B.length str)
    -- Sometimes we will produce leftovers lists of the form [B.empty, nonempty]
    -- where `nonempty` is a non-empty ByteString. In this case we can avoid a copy
    -- by simply dropping the empty prefix. In principle ByteString might want
    -- to gain this optimization as well
    onSucc = B.concat . dropWhile B.null
    -- End of input: fail with the unit error, keeping the gathered chunks.
    onFail bss = C $ \_ _ _ -> Fail (B.concat bss) $ Right ()
{-# INLINE ensureN #-}

-- | Apply @f@ to the current input, force its result, and drop the first @n@
-- bytes. Does not check that @n@ bytes are available; 'readN' calls
-- 'ensureN' first.
unsafeReadN :: Int -> (B.ByteString -> a) -> Get () a
unsafeReadN !n f = C $ \_ inp ks ->
  let !result = f inp
  in ks (B.unsafeDrop n inp) result

-- | @readNWith n f@ where @f@ must be deterministic and not have side effects.
-- @f@ receives a pointer to the start of the current input, which holds at
-- least @n@ bytes (ensured by 'readN').
readNWith :: Int -> (Ptr a -> IO a) -> Get () a
readNWith n f = readN n peekWith
  where
    -- It should be safe to use accursedUnutterablePerformIO here.
    -- The action must be deterministic and not have any external side effects.
    -- It depends on the value of the ByteString so the value dependencies look OK.
    peekWith s = accursedUnutterablePerformIO (B.unsafeUseAsCString s (f . castPtr))
{-# INLINE readNWith #-}