{-# LANGUAGE CPP #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE MagicHash                #-}
{-# LANGUAGE UnliftedFFITypes         #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MultiWayIf   #-}

module Data.JsonStream.Unescape (
  unescapeText
) where

import           Data.ByteString            as B
import           Data.ByteString.Internal   as B hiding (c2w)
import           Data.Text.Encoding.Error   (UnicodeException (..))
import           Data.Text.Internal         (Text (..))
import           Data.Text.Unsafe           (unsafeDupablePerformIO)
import           Data.Word                  (Word8, Word32)
import           Foreign.ForeignPtr         (withForeignPtr)
import           Foreign.Ptr                (Ptr, plusPtr)
import           Foreign.Storable           (peek)

#if MIN_VERSION_text(2,0,0)

import qualified Data.Primitive           as P
import qualified Data.Text.Array          as T
import qualified Data.Text.Internal       as T
import           Data.Bits                (shiftL, shiftR, (.&.), (.|.))
import           Control.Exception        (try, throwIO)
import Foreign.ForeignPtr (ForeignPtr)
import GHC.ForeignPtr (plusForeignPtr)

#else

import           Control.Exception          (evaluate, throw, try)
import           Control.Monad.ST.Unsafe    (unsafeIOToST, unsafeSTToIO)
import           Data.Text.Internal.Private (runText)
import           Foreign.Marshal.Utils      (with)
import qualified Data.Text.Array            as A
import           GHC.Base                   (MutableByteArray#)
import           Foreign.C.Types            (CInt (..), CSize (..))

#endif

#if !MIN_VERSION_text(2,0,0)

-- | Binding to the C routine @_jstream_decode_string@.
--
-- As used by the caller in this module, the arguments are, in order: the
-- destination text array, a pointer to a 'CSize' output offset (set to 0
-- before the call), the current input pointer and the end-of-input
-- pointer.  The caller treats a result of @0@ as success and then reads
-- the output length back through the 'CSize' pointer; any other result is
-- reported as a decode error.
--
-- NOTE(review): the exact contract of the C side is not visible from this
-- file - confirm against the C source before relying on more than the above.
foreign import ccall unsafe "_jstream_decode_string" c_js_decode
    :: MutableByteArray# s -> Ptr CSize
    -> Ptr Word8 -> Ptr Word8 -> IO CInt

-- | Unescape the contents of a JSON string into 'Text' using the C decoder.
--
-- A destination array of @len@ elements is allocated, the C routine is run
-- over the byte range @[ptr + off, ptr + off + len)@ of the input, and on
-- success the array is finished with the length the C routine reported
-- through the offset pointer.
--
-- This function is partial: when the C routine returns a non-zero status a
-- 'DecodeError' is thrown (as an imprecise exception, surfacing when the
-- resulting 'Text' is forced).
unescapeText' :: ByteString -> Text
unescapeText' (PS fp off len) = runText $ \done -> do
  let go dest = withForeignPtr fp $ \ptr ->
        -- The C side reports how much of the destination it filled through
        -- this counter, which starts at zero.
        with (0::CSize) $ \destOffPtr -> do
          let end = ptr `plusPtr` (off + len)
              loop curPtr = do
                res <- c_js_decode (A.maBA dest) destOffPtr curPtr end
                case res of
                  -- Success: freeze the array at the reported length.
                  0 -> do
                    n <- peek destOffPtr
                    unsafeSTToIO (done dest (fromIntegral n))
                  -- Any other status is treated as malformed input.
                  _ ->
                    throw (DecodeError desc Nothing)
          loop (ptr `plusPtr` off)
  (unsafeIOToST . go) =<< A.new len
 where
  desc = "Data.JsonStream.Unescape.unescapeText': Invalid UTF-8 stream"
{-# INLINE unescapeText' #-}

-- | Unescape the contents of a JSON string, returning 'Left' instead of
-- throwing when the input cannot be decoded.
--
-- The pure decoder is forced with 'evaluate' inside 'try' so that the
-- 'UnicodeException' it may raise is caught here rather than escaping when
-- the result is demanded later.
unescapeText :: ByteString -> Either UnicodeException Text
unescapeText = unsafeDupablePerformIO . try . evaluate . unescapeText'
{-# INLINE unescapeText #-}

#else

-- | Pass the payload pointer and byte length of a strict 'ByteString' to a
-- continuation.
--
-- With bytestring before 0.11 the constructor carries a separate offset;
-- it is folded into the pointer so the continuation always receives a
-- pointer to the first payload byte.
withBS :: ByteString -> (ForeignPtr Word8 -> Int -> r) -> r
#if MIN_VERSION_bytestring(0,11,0)
withBS bytes k = case bytes of
  BS !payload !size -> k payload size
#else
withBS bytes k = case bytes of
  PS !base !offset !size -> k (plusForeignPtr base offset) size
#endif
{-# INLINE withBS #-}

-- | Unescape the contents of a JSON string, returning 'Left' with the
-- 'UnicodeException' when the decoder rejects the input.
--
-- The decoder runs in 'IO' only because it fills a freshly allocated
-- mutable array; it is run with 'unsafeDupablePerformIO' and any
-- 'UnicodeException' it throws is captured by 'try'.
unescapeText :: ByteString -> Either UnicodeException Text
unescapeText input = unsafeDupablePerformIO (try (unescapeTextIO input))

-- | Abort decoding by throwing a 'DecodeError' in 'IO'.
--
-- No offending byte is reported (the second field is 'Nothing').
throwDecodeError :: IO a
throwDecodeError = throwIO (DecodeError desc Nothing)
  where
    -- Deliberately the same text as the implementation used with older
    -- text versions, so the reported error does not depend on which
    -- version of text the package was built against.  The previous
    -- message named Data.Text's decodeUtf8, which is never called here.
    desc :: String
    desc = "Data.JsonStream.Unescape.unescapeText': Invalid UTF-8 stream"

-- The following is copied from aeson-2.0 

-------------------------------------------------------------------------------
-- unescapeTextIO
-------------------------------------------------------------------------------

-- This function is generated using staged-streams
-- See: https://github.com/phadej/staged/blob/master/staged-streams-unicode/src/Unicode/JSON.hs
--
-- Because @aeson@ prefers not to depend on template-haskell itself,
-- the splice was dumped and then prettified a bit by hand.
--
-- | Decode the contents of a JSON string (UTF-8 bytes, possibly containing
-- backslash escapes) into 'Text'.
--
-- The input is walked once by a hand-written state machine.  Plain UTF-8
-- sequences are validated and copied through unchanged, escapes are
-- replaced by the UTF-8 encoding of the character they denote.  Any
-- malformed input (truncated or invalid UTF-8, unknown escape, bad hex
-- digit, unpaired surrogate, input ending inside an escape or a multi-byte
-- sequence) ends in 'throwDecodeError'.
--
-- Naming of the local states: @state_input*@ handle raw multi-byte UTF-8,
-- @state_u*@ the four hex digits of a unicode escape, and @state_s*@ the
-- second escape of a surrogate pair.  Every state takes the output index
-- @out@ and the input pointer @inp@.
unescapeTextIO :: ByteString -> IO Text
unescapeTextIO bs = withBS bs $ \fptr len ->
  withForeignPtr fptr $ \begin -> do
    let end :: Ptr Word8
        end = plusPtr begin len

    -- Output buffer, sized to the input length.  Every transition below
    -- writes no more bytes than it has consumed (a six byte unicode escape
    -- yields at most three bytes, a twelve byte surrogate pair four), so
    -- the buffer cannot overflow; it is shrunk to the real size at the end.
    arr <- P.newPrimArray len

        -- Emit the given pending bytes at @out@, then resume normal
        -- scanning at @inp@.
    let write3bytes :: Int -> Word8 -> Word8 -> Word8 -> Ptr Word8 -> IO Text
        write3bytes !out !b1 !b2 !b3 !inp = do
          P.writePrimArray arr out b1
          write2bytes (out + 1) b2 b3 inp

        write2bytes :: Int -> Word8 -> Word8 -> Ptr Word8 -> IO Text
        write2bytes !out !b1 !b2 !inp = do
          P.writePrimArray arr out b1
          write1byte (out + 1) b2 inp

        write1byte :: Int -> Word8 -> Ptr Word8 -> IO Text
        write1byte !out !b1 !inp = do
          P.writePrimArray arr out b1
          state_start (out + 1) inp

        -- Encode the code point @acc@ as one to four UTF-8 bytes.  The
        -- callers pass @inp@ still pointing at the last hex digit of the
        -- escape, hence the @plusPtr inp 1@ in every branch.
        writeCodePoint :: Int -> Ptr Word8 -> Word32 -> IO Text
        writeCodePoint !out !inp !acc
          | acc <= 127 = do
            P.writePrimArray arr out (fromIntegral acc :: Word8)
            state_start (out + 1) (plusPtr inp 1)

          | acc <= 2047 = do
            let b1 = fromIntegral (shiftR acc 6 .|. 192) :: Word8
            let b2 = fromIntegral ((acc .&. 63) .|. 128) :: Word8
            P.writePrimArray arr out b1
            write1byte (out + 1) b2 (plusPtr inp 1)

          | acc <= 65535 = do
            let b1 = fromIntegral (shiftR acc 12 .|. 224) :: Word8
            let b2 = fromIntegral ((shiftR acc 6 .&. 63) .|.  128) :: Word8
            let b3 = fromIntegral ((acc .&. 63) .|. 128) :: Word8
            P.writePrimArray arr out b1
            write2bytes (out + 1) b2 b3 (plusPtr inp 1)

          | otherwise = do
            let b1 = fromIntegral (shiftR acc 18 .|. 240) :: Word8
            let b2 = fromIntegral ((shiftR acc 12 .&. 63) .|. 128) :: Word8
            let b3 = fromIntegral ((shiftR acc 6 .&. 63) .|. 128) :: Word8
            let b4 = fromIntegral ((acc .&. 63) .|. 128) :: Word8
            P.writePrimArray arr out b1
            write3bytes (out + 1) b2 b3 b4 (plusPtr inp 1)

        -- Second escape of a surrogate pair fully read: @lo@ must be a low
        -- surrogate (56320..57343, i.e. 0xDC00..0xDFFF); combine it with
        -- the high surrogate @hi@ into one code point above 65535.
        state_sudone :: Int -> Ptr Word8 -> Word32 -> Word32 -> IO Text
        state_sudone !out !inp !hi !lo
          | 56320 <= lo, lo <= 57343
          = writeCodePoint out inp (65536 + (shiftL (hi - 55296) 10 .|.  (lo - 56320)))
          
          | otherwise
          = throwDecodeError

        -- state_su1 .. state_su4: the four hex digits of the second escape
        -- of a surrogate pair, accumulated into @acc@.  In all hex states
        -- bytes 48..57 are the digits 0-9, 65..70 are A-F and 97..102 are
        -- a-f; subtracting 48, 55 or 87 yields the digit value.  The last
        -- digit state does not advance @inp@ (writeCodePoint does).
        state_su4 :: Int -> Ptr Word8 -> Word32 -> Word32 -> IO Text
        state_su4 !out !inp !hi !acc
          | inp == end = throwDecodeError
          | otherwise = do
            w8 <- peek inp
            if | 48 <= w8, w8 <= 57 ->
                 state_sudone out inp hi (shiftL acc 4 .|. fromIntegral (w8 - 48))
               | 65 <= w8, w8 <= 70 ->
                 state_sudone out inp hi (shiftL acc 4 .|. fromIntegral (w8 - 55))
               | 97 <= w8, w8 <= 102 ->
                 state_sudone out inp hi (shiftL acc 4 .|. fromIntegral (w8 - 87))
               | otherwise ->
                 throwDecodeError

        state_su3 :: Int -> Ptr Word8 -> Word32 -> Word32 -> IO Text
        state_su3 !out !inp !hi !acc
          | inp == end = throwDecodeError
          | otherwise = do
            w8 <- peek inp
            if | 48 <= w8, w8 <= 57 ->
                 state_su4 out (plusPtr inp 1) hi (shiftL acc 4 .|. fromIntegral (w8 - 48))
               | 65 <= w8, w8 <= 70 ->
                 state_su4 out (plusPtr inp 1) hi (shiftL acc 4 .|. fromIntegral (w8 - 55))
               | 97 <= w8, w8 <= 102 ->
                 state_su4 out (plusPtr inp 1) hi (shiftL acc 4 .|. fromIntegral (w8 - 87))
               | otherwise ->
                 throwDecodeError

        state_su2 :: Int -> Ptr Word8 -> Word32 -> Word32 -> IO Text
        state_su2 !out !inp !hi !acc
          | inp == end = throwDecodeError
          | otherwise = do
            w8 <- peek inp
            if | 48 <= w8, w8 <= 57 -> 
                 state_su3 out (plusPtr inp 1) hi (shiftL acc 4 .|. fromIntegral (w8 - 48))
               | 65 <= w8, w8 <= 70 ->
                 state_su3 out (plusPtr inp 1) hi (shiftL acc 4 .|. fromIntegral (w8 - 55))
               | 97 <= w8, w8 <= 102 ->
                 state_su3 out (plusPtr inp 1) hi (shiftL acc 4 .|. fromIntegral (w8 - 87))
               | otherwise ->
                 throwDecodeError

        state_su1 :: Int -> Ptr Word8 -> Word32 -> IO Text
        state_su1 !out !inp !hi
          | inp == end = throwDecodeError
          | otherwise = do
            w8 <- peek inp
            if | 48 <= w8, w8 <= 57 ->
                 state_su2 out (plusPtr inp 1) hi (fromIntegral (w8 - 48))
               | 65 <= w8, w8 <= 70 ->
                 state_su2 out (plusPtr inp 1) hi (fromIntegral (w8 - 55))
               | 97 <= w8, w8 <= 102 ->
                 state_su2 out (plusPtr inp 1) hi (fromIntegral (w8 - 87))
               | otherwise ->
                 throwDecodeError

        -- After the backslash that follows a high surrogate: only the
        -- letter u (117) is accepted.
        state_su :: Int -> Ptr Word8 -> Word32 -> IO Text
        state_su !out !inp !hi
          | inp == end = throwDecodeError
          | otherwise = do
            w8 <- peek inp
            case w8 of
              117 -> state_su1 out (plusPtr inp 1) hi
              _   -> throwDecodeError

        -- Directly after a high surrogate escape: the next byte must be a
        -- backslash (92) starting the second escape.
        state_ss :: Int -> Ptr Word8 -> Word32 -> IO Text
        state_ss !out !inp !hi
          | inp == end = throwDecodeError
          | otherwise = do
            w8 <- peek inp
            case w8 of
              92 -> state_su out (plusPtr inp 1) hi
              _  -> throwDecodeError

        -- First unicode escape fully read.  A non-surrogate value is
        -- written out directly; a high surrogate (55296..56319) requires a
        -- second escape; a lone low surrogate is an error.
        state_udone :: Int -> Ptr Word8 -> Word32 -> IO Text
        state_udone !out !inp !acc
          | acc < 55296 || acc > 57343 =
            writeCodePoint out inp acc

          | acc < 56320 =
            state_ss out (plusPtr inp 1) acc

          | otherwise =
            throwDecodeError

        -- state_u1 .. state_u4: the four hex digits of a unicode escape,
        -- using the same digit decoding as the state_su* states.
        state_u4 :: Int -> Ptr Word8 -> Word32 -> IO Text
        state_u4 !out !inp !acc
          | inp == end = throwDecodeError
          | otherwise = do
            w8 <- peek inp
            if | 48 <= w8, w8 <= 57 ->
                 state_udone out inp (shiftL acc 4 .|. fromIntegral (w8 - 48))
               | 65 <= w8, w8 <= 70 ->
                 state_udone out inp (shiftL acc 4 .|. fromIntegral (w8 - 55))
               | 97 <= w8, w8 <= 102 ->
                 state_udone out inp (shiftL acc 4 .|. fromIntegral (w8 - 87))
               | otherwise ->
                 throwDecodeError

        state_u3 :: Int -> Ptr Word8 -> Word32 -> IO Text
        state_u3 !out !inp !acc
          | inp == end = throwDecodeError
          | otherwise = do
            w8 <- peek inp
            if | 48 <= w8, w8 <= 57 ->
                 state_u4 out (plusPtr inp 1) (shiftL acc 4 .|. fromIntegral (w8 - 48))
               | 65 <= w8, w8 <= 70 ->
                 state_u4 out (plusPtr inp 1) (shiftL acc 4 .|. fromIntegral (w8 - 55))
               | 97 <= w8, w8 <= 102 ->
                 state_u4 out (plusPtr inp 1) (shiftL acc 4 .|. fromIntegral (w8 - 87))
               | otherwise ->
                 throwDecodeError

        state_u2 :: Int -> Ptr Word8 -> Word32 -> IO Text
        state_u2 !out !inp !acc
          | inp == end = throwDecodeError
          | otherwise = do
            w8 <- peek inp
            if | 48 <= w8, w8 <= 57 ->
                 state_u3 out (plusPtr inp 1) (shiftL acc 4 .|. fromIntegral (w8 - 48))
               | 65 <= w8, w8 <= 70 ->
                 state_u3 out (plusPtr inp 1) (shiftL acc 4 .|. fromIntegral (w8 - 55))
               | 97 <= w8, w8 <= 102 ->
                 state_u3 out (plusPtr inp 1) (shiftL acc 4 .|. fromIntegral (w8 - 87))
               | otherwise ->
                 throwDecodeError

        state_u1 :: Int -> Ptr Word8 -> IO Text
        state_u1 !out !inp
          | inp == end = throwDecodeError
          | otherwise = do
            w8 <- peek inp
            if | 48 <= w8, w8 <= 57 ->
                 state_u2 out (plusPtr inp 1) (fromIntegral (w8 - 48))
               | 65 <= w8, w8 <= 70 ->
                 state_u2 out (plusPtr inp 1) (fromIntegral (w8 - 55))
               | 97 <= w8, w8 <= 102 ->
                 state_u2 out (plusPtr inp 1) (fromIntegral (w8 - 87))
               | otherwise ->
                 throwDecodeError

        -- Byte following a backslash.  Quote (34), backslash (92) and
        -- slash (47) stand for themselves; b, f, n, r, t (98, 102, 110,
        -- 114, 116) map to the control bytes 8, 12, 10, 13, 9; u (117)
        -- starts a unicode escape.  Anything else is rejected.
        state_escape :: Int -> Ptr Word8 -> IO Text
        state_escape !out !inp
          | inp == end = throwDecodeError
          | otherwise  = do
            w8 <- peek inp
            case w8 of
              34 -> do
                P.writePrimArray arr out 34
                state_start (out + 1) (plusPtr inp 1)

              92 -> do
                P.writePrimArray arr out 92
                state_start (out + 1) (plusPtr inp 1)

              47 -> do
                P.writePrimArray arr out 47
                state_start (out + 1) (plusPtr inp 1)

              98 -> do
                P.writePrimArray arr out 8
                state_start (out + 1) (plusPtr inp 1)

              102 -> do
                P.writePrimArray arr out 12
                state_start (out + 1) (plusPtr inp 1)

              110 -> do
                P.writePrimArray arr out 10
                state_start (out + 1) (plusPtr inp 1)

              114 -> do
                P.writePrimArray arr out 13
                state_start (out + 1) (plusPtr inp 1)

              116 -> do
                P.writePrimArray arr out 9
                state_start (out + 1) (plusPtr inp 1)

              117 ->
                state_u1 out (plusPtr inp 1)

              _ -> throwDecodeError

        -- state_input4 / 4b / 4c: the three continuation bytes of a raw
        -- four byte UTF-8 sequence.  Each must match 10xxxxxx
        -- (w8 .&. 192 == 128).  Once complete, the decoded value must lie
        -- in 65536..1114111 (rejecting overlong and out-of-range forms);
        -- the original bytes are then copied through unchanged.
        state_input4c :: Int -> Ptr Word8 -> Word8 -> Word8 -> Word8 -> IO Text
        state_input4c !out !inp !b1 !b2 !b3
          | inp == end = throwDecodeError
          | otherwise  = do
            w8 <- peek inp
            if | (w8 .&. 192) == 128
               , let acc    = shiftL (fromIntegral (b1 .&. 7)) 18
               , let acc'   = acc .|. shiftL (fromIntegral (b2 .&. 63)) 12
               , let acc''  = acc' .|. shiftL (fromIntegral (b3 .&. 63)) 6
               , let acc''' = acc'' .|. fromIntegral (w8 .&. 63) :: Word32
               , acc''' >= 65536 && acc''' < 1114112 -> do
                 P.writePrimArray arr out b1
                 write3bytes (out + 1) b2 b3 w8 (plusPtr inp 1)

               | otherwise ->
                 throwDecodeError

        state_input4b :: Int -> Ptr Word8 -> Word8 -> Word8 -> IO Text
        state_input4b !out !inp !b1 !b2
          | inp == end = throwDecodeError
          | otherwise  = do
            w8 <- peek inp
            if | (w8 .&. 192) == 128 ->
                 state_input4c out (plusPtr inp 1) b1 b2 w8

               | otherwise ->
                 throwDecodeError

        state_input4 :: Int -> Ptr Word8 -> Word8 -> IO Text
        state_input4 !out !inp !b1
          | inp == end = throwDecodeError
          | otherwise  = do
            w8 <- peek inp
            if | (w8 .&. 192) == 128 ->
                 state_input4b out (plusPtr inp 1) b1 w8

               | otherwise ->
                 throwDecodeError

        -- state_input3 / 3b: the two continuation bytes of a raw three
        -- byte sequence.  The decoded value must be at least 2048 (no
        -- overlong form) and must not be a surrogate (55296..57343).
        state_input3b :: Int -> Ptr Word8 -> Word8 -> Word8 -> IO Text
        state_input3b !out !inp !b1 !b2
          | inp == end = throwDecodeError
          | otherwise  = do
            w8 <- peek inp
            if | (w8 .&. 192) == 128
               , let acc   = shiftL (fromIntegral (b1 .&. 15)) 12
               , let acc'  = acc .|.  shiftL (fromIntegral (b2 .&. 63)) 6
               , let acc'' = acc' .|. fromIntegral (w8 .&. 63) :: Word32
               , (acc'' >= 2048 && acc'' < 55296) || acc'' > 57343 -> do
                 P.writePrimArray arr out b1
                 write2bytes (out + 1) b2 w8 (plusPtr inp 1)

               | otherwise ->
                 throwDecodeError

        state_input3 :: Int -> Ptr Word8 -> Word8 -> IO Text
        state_input3 !out !inp !b1
          | inp == end = throwDecodeError
          | otherwise  = do
            w8 <- peek inp
            if | (w8 .&. 192) == 128 ->
                 state_input3b out (plusPtr inp 1) b1 w8

               | otherwise ->
                 throwDecodeError

        -- state_input2: the continuation byte of a raw two byte sequence.
        -- The decoded value must be at least 128 (no overlong form).
        state_input2 :: Int -> Ptr Word8 -> Word8 -> IO Text
        state_input2 !out !inp !b1
          | inp == end = throwDecodeError
          | otherwise  = do
            w8 <- peek inp
            if | (w8 .&. 192) == 128,
                 let acc = shiftL (fromIntegral (b1 .&. 63)) 6 :: Word32
                     acc' = acc .|. fromIntegral (w8 .&. 63) :: Word32
               , acc' >= 128 -> do
                 P.writePrimArray arr out b1
                 write1byte (out + 1) w8 (plusPtr inp 1)

               | otherwise ->
                 throwDecodeError

        -- Main state.  At end of input the buffer is shrunk to the @out@
        -- bytes actually written, frozen, and wrapped as a 'Text' with
        -- offset 0 and length @out@.  Otherwise dispatch on the next byte:
        -- backslash starts an escape, bytes below 128 are copied as is,
        -- 128..191 (a continuation byte with no lead byte) and 248..255
        -- are rejected, the remaining ranges are lead bytes of two, three
        -- and four byte sequences.
        state_start :: Int -> Ptr Word8 -> IO Text
        state_start !out !inp
          | inp == end = do
            P.shrinkMutablePrimArray arr out
            frozenArr <- P.unsafeFreezePrimArray arr
            return $ case frozenArr of
              P.PrimArray ba -> T.Text (T.ByteArray ba) 0 out

          | otherwise = do
            w8 <- peek inp
            if | w8 == 92 -> state_escape out (plusPtr inp 1)
               | w8 < 128 -> do
                 P.writePrimArray arr out w8
                 state_start (out + 1) (plusPtr inp 1)

               | w8 < 192 -> throwDecodeError
               | w8 < 224 -> state_input2 out (plusPtr inp 1) w8
               | w8 < 240 -> state_input3 out (plusPtr inp 1) w8
               | w8 < 248 -> state_input4 out (plusPtr inp 1) w8

               | otherwise -> throwDecodeError

    -- start the state machine
    state_start (0 :: Int) begin

#endif