{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

-- This file is included by "Data.ByteString.ReadInt", after defining
-- "BYTESTRING_STRICT".  The two modules are largely identical, except for the
-- choice of ByteString type and the @consume@ loop in `_readDecimal` (which
-- backs both `readNatural` and `readInteger`), where the lazy version needs to
-- nest the inner digit loop inside a loop over the constituent chunks.

#ifdef BYTESTRING_STRICT
module Data.ByteString.ReadNat
#else
module Data.ByteString.Lazy.ReadNat
#endif
    ( readInteger
    , readNatural
    ) where

import qualified Data.ByteString.Internal as BI
#ifdef BYTESTRING_STRICT
import Data.ByteString
#else
import Data.ByteString.Lazy
import Data.ByteString.Lazy.Internal
#endif
import Data.Bits (finiteBitSize)
import Data.ByteString.Internal (pattern BS, plusForeignPtr)
import Data.Word
import Foreign.ForeignPtr (ForeignPtr)
import Foreign.Ptr (Ptr, minusPtr, plusPtr)
import Foreign.Storable (Storable(..))
import Numeric.Natural (Natural)

----- Public API

-- | 'readInteger' reads an 'Integer' from the beginning of the 'ByteString'.
-- If there is no 'Integer' at the beginning of the string, it returns
-- 'Nothing', otherwise it just returns the 'Integer' read, and the rest of
-- the string.
--
-- 'readInteger' does not ignore leading whitespace, the value must start
-- immediately at the beginning of the input string.  An optional leading
-- @-@ or @+@ sign is accepted, but it must be followed directly by at least
-- one decimal digit.
--
-- ==== __Examples__
-- >>> readInteger "-000111222333444555666777888999 all done"
-- Just (-111222333444555666777888999," all done")
-- >>> readInteger "+1: readInteger also accepts a leading '+'"
-- Just (1,": readInteger also accepts a leading '+'")
-- >>> readInteger "not a decimal number"
-- Nothing
--
readInteger :: ByteString -> Maybe (Integer, ByteString)
readInteger = \ bs -> do
    (w, s) <- uncons bs
    let d = fromDigit w
    if | d <=    9 -> unsigned d s -- leading digit
       | w == 0x2d -> negative s   -- minus sign
       | w == 0x2b -> positive s   -- plus sign
       | otherwise -> Nothing      -- not a number
  where
    -- Finish an unsigned number whose first digit (value @d@) was consumed.
    unsigned :: Word -> ByteString -> Maybe (Integer, ByteString)
    unsigned d s =
         let (!n, rest) = _readDecimal d s
             !i = toInteger n
          in Just (i, rest)

    -- After a @+@ sign: a digit must follow immediately.
    positive :: ByteString -> Maybe (Integer, ByteString)
    positive bs = do
        (w, s) <- uncons bs
        let d = fromDigit w
        if | d <=    9 -> unsigned d s
           | otherwise -> Nothing

    -- After a @-@ sign: a digit must follow immediately; the magnitude read
    -- is negated.
    negative :: ByteString -> Maybe (Integer, ByteString)
    negative bs = do
        (w, s) <- uncons bs
        let d = fromDigit w
        if | d >     9 -> Nothing
           | otherwise -> let (!n, rest) = _readDecimal d s
                              !i = negate $ toInteger n
                           in Just (i, rest)

-- | 'readNatural' parses a non-negative decimal number from the front of a
-- 'ByteString'.  It yields 'Nothing' unless the very first byte is a decimal
-- digit; otherwise it yields the 'Natural' that was read paired with the
-- remainder of the input.
--
-- No whitespace is skipped, and no sign is permitted: a leading @+@ (or any
-- other non-digit byte) makes the result 'Nothing'.
--
-- ==== __Examples__
-- >>> readNatural "000111222333444555666777888999 all done"
-- Just (111222333444555666777888999," all done")
-- >>> readNatural "+000111222333444555666777888999 explicit sign"
-- Nothing
-- >>> readNatural "not a decimal number"
-- Nothing
--
readNatural :: ByteString -> Maybe (Natural, ByteString)
readNatural bs = case uncons bs of
    Nothing -> Nothing
    Just (w, s)
        | digit <= 9 -> Just $! _readDecimal digit s
        | otherwise  -> Nothing
      where
        -- Non-digit bytes map to values above 9 (see 'fromDigit').
        digit = fromDigit w

----- Internal implementation

-- | State reported after scanning a single chunk of input.  Once every chunk
-- has been scanned, the final 'Natural' is assembled from this state by
-- @convert@ in '_readDecimal'.
data Result = Result
    !Int       -- how many bytes of the chunk were consumed
    !Word      -- value of the partially filled least-significant group (LSW)
    !Int       -- digit count associated with the LSW
    [Natural]  -- completed digit groups, least significant first

-- | Finish reading a decimal number whose leading digit (already converted
-- to its value) has been consumed, returning the number together with the
-- unconsumed rest of the input.
--
-- Digits are gathered by 'natdigits' into groups of at most 'safeLog'
-- digits, each of which fits in a 'Word', and the groups are then merged
-- into a single 'Natural' by @combine@.
_readDecimal :: Word -> ByteString -> (Natural, ByteString)
_readDecimal =
    -- Having read one digit, we're about to read the 2nd, so the digit count
    -- checked against 'safeLog' starts at 2.  Until the first group is
    -- flushed the count is therefore one more than the number of digits in
    -- the accumulator.  This is harmless: the first flushed group becomes the
    -- most significant (last) element of the little-endian list, whose width
    -- does not matter, and 'combine' ignores the count when no group has
    -- been flushed.  After a flush the count is exactly the digit count.
    consume [] 2
  where
    consume :: [Natural] -> Int -> Word -> ByteString
            -> (Natural, ByteString)
#ifdef BYTESTRING_STRICT
    -- A strict ByteString is a single chunk, so one scan suffices.
    consume ns cnt acc (BS fp len) =
        case natdigits fp len acc cnt ns of
            Result used acc' cnt' ns'
                | used == len
                  -> convert acc' cnt' ns' $ empty
                | otherwise
                  -> convert acc' cnt' ns' $
                     BS (fp `plusForeignPtr` used) (len - used)
#else
    -- All done
    consume ns cnt acc Empty = convert acc cnt ns Empty
    -- Process next chunk
    consume ns cnt acc (Chunk (BS fp len) cs)
        = case natdigits fp len acc cnt ns of
            Result used acc' cnt' ns'
                | used == len -- process more chunks
                  -> consume ns' cnt' acc' cs
                | otherwise   -- ran into a non-digit
                  -> let c = Chunk (BS (fp `plusForeignPtr` used)
                                       (len - used)) cs
                      in convert acc' cnt' ns' c
#endif

    -- Build the final number from the scan state and pair it with the
    -- unconsumed input.
    convert :: Word -> Int -> [Natural] -> ByteString -> (Natural, ByteString)
    convert !acc !cnt !ns rest =
        let !n = combine acc cnt ns
         in (n, rest)

    -- Merge the least-significant word with the reduction of the
    -- little-endian tail.
    --
    -- The input is:
    --
    -- * the least significant digits as a 'Word' (LSW),
    -- * the number of digits that went into the LSW,
    -- * all the remaining digit groups ('safeLog' digits each, except
    --   possibly the most significant one), in little-endian order.
    --
    -- The result is obtained by pairwise recursive combining of the digit
    -- groups, followed by multiplication by @10^cnt@ and addition of the LSW.
    combine :: Word      -- value of LSW
            -> Int       -- count of digits in LSW
            -> [Natural] -- tail elements (base @10^'safeLog'@)
            -> Natural
    {-# INLINE combine #-}
    combine !acc !_   [] = wordToNatural acc
    combine !acc !cnt ns =
        wordToNatural (10^cnt) * combine1 safeBase ns + wordToNatural acc

    -- Recursive reduction of a little-endian sequence of 'Natural'-valued
    -- /digits/ in base @base@ (a power of 10).  The base is squared after
    -- each round.  This shows better asymptotic performance than one word
    -- at a time multiply-add folds.  See:
    -- <https://gmplib.org/manual/Multiplication-Algorithms>
    --
    -- The empty sequence denotes 0.  'combine' never passes it, but without
    -- this clause the reduction would loop forever on @[]@.
    combine1 :: Natural -> [Natural] -> Natural
    combine1 _    []  = 0
    combine1 _    [n] = n
    combine1 base ns  = combine1 (base * base) (combine2 base ns)

    -- One round of pairwise merging of numbers in base @base@: adjacent
    -- pairs @n, m@ (least significant first) become @m * base + n@; an odd
    -- trailing element is kept as is.
    combine2 :: Natural -> [Natural] -> [Natural]
    combine2 base (n:m:ns) = let !t = m * base + n in t : combine2 base ns
    combine2 _    ns       = ns

-- | Scan the decimal digits at the start of one input chunk, picking up from
-- the partial state left by the previous chunk (if any).  Scanning stops at
-- the first non-digit byte or at the end of the chunk; the returned 'Result'
-- records how many bytes were consumed together with the updated state.
--
-- The state is a little-endian list of completed digit groups in base
-- @10^'safeLog'@, preceded by a partially filled group kept in a 'Word'
-- along with its digit count.  The final number is obtained later by
-- recursive pairwise merging of the list, followed by a multiplication by
-- @10^cnt@ and addition of the partial group.
natdigits :: ForeignPtr Word8 -- ^ Input chunk
          -> Int              -- ^ Chunk length
          -> Word             -- ^ accumulated (partial) group
          -> Int              -- ^ partial digit count
          -> [Natural]        -- ^ completed groups, least significant first
          -> Result
{-# INLINE natdigits #-}
natdigits fp len = \ acc cnt ns ->
    BI.accursedUnutterablePerformIO $
        BI.unsafeWithForeignPtr fp $ \ base -> do
            let stop = base `plusPtr` len
            scan base stop acc cnt ns base
  where
    -- @first@ is where the chunk begins (used to count consumed bytes) and
    -- @stop@ is one past its last byte.
    scan !first !stop = step
      where
        step :: Word -> Int -> [Natural] -> Ptr Word8 -> IO Result
        step !grp !ndig groups !cur = do
            !d <- nextDigit
            if | d > 9
                 -> pure (Result (cur `minusPtr` first) grp ndig groups)
               | ndig < safeLog
                 -> step (10 * grp + d) (ndig + 1) groups (cur `plusPtr` 1)
               | otherwise
                 -> let !full = wordToNatural grp
                     in step d 1 (full : groups) (cur `plusPtr` 1)
          where
            nextDigit :: IO Word
            nextDigit
              | cur == stop = pure 10  -- End of input (10 is not a digit)
              | otherwise   = fromDigit <$> peek cur
            {-# NOINLINE nextDigit #-}
            -- 'nextDigit' makes it possible to implement a single success
            -- exit point from the loop.  If instead we return 'Result'
            -- from multiple places, when 'natdigits' is inlined we get (at
            -- least GHC 8.10 through 9.2) for each exit path a separate
            -- join point implementing the continuation code.  GHC ticket
            -- <https://gitlab.haskell.org/ghc/ghc/-/issues/20739>.
            --
            -- The NOINLINE pragma is required to avoid inlining branches
            -- that would restore multiple exit points.

----- Misc functions

-- | The largest number of decimal digits that can never overflow a 'Word'
-- accumulator.
--
-- Because @log10 2 ~ 0.30103 > 0.3@, an @n@-bit word holds every value of
-- at most @floor (0.3 * n)@ decimal digits: all such values are below
-- @10^floor (0.3 * n)@, which does not exceed @2^n@.  This gives 19 for a
-- 64-bit 'Word' and 9 for a 32-bit one.
safeLog :: Int
safeLog = (finiteBitSize (0 :: Word) * 3) `div` 10

-- | Radix @10^'safeLog'@ of the little-endian sequence of Word-sized
-- "digits" (digit groups) produced while scanning.
safeBase :: Natural
safeBase = fromInteger (10 ^ safeLog :: Integer)

-- | Map an ASCII byte to its decimal digit value.  Bytes 0x30 to 0x39
-- ("0" to "9") map to 0..9; every other byte maps to a value greater than 9
-- (bytes below 0x30 wrap around to very large 'Word's), so a single @<= 9@
-- comparison decides whether the byte was a digit.
fromDigit :: Word8 -> Word
{-# INLINE fromDigit #-}
fromDigit byte = byte `seq` (fromIntegral byte - 0x30) -- i.e. byte - '0'

-- | Widen a machine 'Word' to an arbitrary-precision 'Natural'; the
-- conversion is exact for every 'Word' value.
wordToNatural :: Word -> Natural
{-# INLINE wordToNatural #-}
wordToNatural w = fromIntegral w