{-# LANGUAGE BangPatterns #-}

module Codec.Serialise.Internal.GeneralisedUTF8
    ( encodeGenUTF8
    , UTF8Encoding(..)
    , decodeGenUTF8
      -- * Utilities
    , isSurrogate
    , isValid
    ) where

import Control.Monad.ST
import Data.Bits
import Data.Char
import Data.Word
import qualified Codec.CBOR.ByteArray.Sliced as BAS
import Data.Primitive.ByteArray

-- | Which flavour of UTF-8 a byte sequence is written in.
data UTF8Encoding
    = ConformantUTF8
      -- ^ Standard UTF-8: no surrogate code points are encoded.
    | GeneralisedUTF8
      -- ^ UTF-8 in which surrogate code points (see 'isSurrogate') may be
      -- encoded as ordinary three-byte sequences.
    deriving (Show, Eq)

-- | Is a 'Char' a UTF-16 surrogate, i.e. does it lie in the range
-- @U+D800@ to @U+DFFF@ (both ends inclusive)?
isSurrogate :: Char -> Bool
isSurrogate c = c >= '\xd800' && c <= '\xdfff'

-- | Encode a string as (generalized) UTF-8. In addition to the encoding, we
-- return a flag indicating whether the encoded string contained any surrogate
-- characters, in which case the output is generalized UTF-8.
--
-- The returned slice starts at offset 0 of a freshly allocated array and its
-- length is the number of bytes actually written; the underlying array may be
-- larger than that.
encodeGenUTF8 :: String -> (BAS.SlicedByteArray, UTF8Encoding)
encodeGenUTF8 st = runST $ do
    -- We slightly over-allocate such that we won't need to copy in the
    -- ASCII-only case.
    ba <- newByteArray (length st + 4)
    go ba ConformantUTF8 0 st
  where
    -- Write the remaining characters into the buffer starting at byte offset
    -- @off@, tracking whether a surrogate has been seen so far.
    go :: MutableByteArray s -> UTF8Encoding
       -> Int -> [Char]
       -> ST s (BAS.SlicedByteArray, UTF8Encoding)
    go ba !enc !off [] = do
        -- Nothing mutates the buffer after this point, so freezing it in
        -- place is fine.
        ba' <- unsafeFreezeByteArray ba
        return (BAS.SBA ba' 0 off, enc)
    go ba enc off (c:cs)
      | off + 4 >= cap = do
        -- We ran out of room; reallocate and copy
        ba' <- newByteArray (cap + cap `div` 2 + 1)
        copyMutableByteArray ba' 0 ba 0 off
        -- Retry the same character: the capacity check is repeated against
        -- the new buffer.
        go ba' enc off (c:cs)

      -- Four-byte sequence: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
      | c >= '\x10000' = do
        writeByteArray ba (off + 0) (0xf0 .|. (0x07 .&. shiftedByte 18))
        writeByteArray ba (off + 1) (0x80 .|. (0x3f .&. shiftedByte 12))
        writeByteArray ba (off + 2) (0x80 .|. (0x3f .&. shiftedByte  6))
        writeByteArray ba (off + 3) (0x80 .|. (0x3f .&. shiftedByte  0))
        go ba enc (off + 4) cs

      -- Three-byte sequence: 1110xxxx 10xxxxxx 10xxxxxx
      | c >= '\x0800' = do
        writeByteArray ba (off + 0) (0xe0 .|. (0x0f .&. shiftedByte 12))
        writeByteArray ba (off + 1) (0x80 .|. (0x3f .&. shiftedByte  6))
        writeByteArray ba (off + 2) (0x80 .|. (0x3f .&. shiftedByte  0))

        -- Is this a surrogate character? Surrogates all fall in the
        -- three-byte range, so this is the only branch that can switch the
        -- flag.
        let enc'
              | isSurrogate c = GeneralisedUTF8
              | otherwise     = enc
        go ba enc' (off + 3) cs

      -- Two-byte sequence: 110xxxxx 10xxxxxx
      | c >= '\x0080' = do
        writeByteArray ba (off + 0) (0xc0 .|. (0x1f .&. shiftedByte 6))
        writeByteArray ba (off + 1) (0x80 .|. (0x3f .&. shiftedByte 0))
        go ba enc (off + 2) cs

      -- ASCII: a single byte
      | c <= '\x007f' = do
        writeByteArray ba off (fromIntegral n :: Word8)
        go ba enc (off + 1) cs

      | otherwise = error "encodeGenUTF8: Impossible"
      where
        -- Current capacity of the buffer in bytes.
        cap = sizeofMutableByteArray ba
        -- Code point of the character being encoded.
        n = ord c
        -- Low eight bits of the code point after shifting right by @shft@;
        -- callers mask off the bits they need.
        shiftedByte :: Int -> Word8
        shiftedByte shft = fromIntegral $ n `shiftR` shft

-- | Decode a (generalised) UTF-8 byte array into a 'String'.
--
-- The input is not validated: the shape of each sequence is taken from its
-- lead byte alone, and continuation bytes are simply masked to their low six
-- bits. A byte that is not a recognised lead byte is decoded as a single
-- character from its low seven bits. Use 'isValid' first if the bytes are
-- untrusted.
--
-- No byte outside the array is ever read. If the array ends in the middle of
-- a multi-byte sequence, the final character of the result is a call to
-- 'error'.
decodeGenUTF8 :: ByteArray -> String
decodeGenUTF8 ba = go 0
  where
    !len = sizeofByteArray ba

    -- Byte at offset @i@, widened to 'Int'. This performs no bounds check, so
    -- callers must ensure @i < len@.
    index :: Int -> Int
    index i = fromIntegral (ba `indexByteArray` i :: Word8)

    -- Six payload bits of the continuation byte at offset @i@, failing
    -- cleanly instead of reading past the end of the array.
    cont :: Int -> Int
    cont i
      | i < len   = index i .&. 0x3f
      | otherwise = error "decodeGenUTF8: truncated multi-byte sequence"

    go :: Int -> String
    go !off
      -- '>=' rather than '==': after a truncated trailing sequence the
      -- offset may already lie beyond the end.
      | off >= len = []

      -- Four-byte sequence: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
      | n0 .&. 0xf8 == 0xf0 =
        let c = chr $  (n0 .&. 0x07) `shiftL` 18
                   .|. cont (off + 1) `shiftL` 12
                   .|. cont (off + 2) `shiftL`  6
                   .|. cont (off + 3)
        in c : go (off + 4)

      -- Three-byte sequence: 1110xxxx 10xxxxxx 10xxxxxx
      | n0 .&. 0xf0 == 0xe0 =
        let c = chr $  (n0 .&. 0x0f) `shiftL` 12
                   .|. cont (off + 1) `shiftL`  6
                   .|. cont (off + 2)
        in c : go (off + 3)

      -- Two-byte sequence: 110xxxxx 10xxxxxx
      | n0 .&. 0xe0 == 0xc0 =
        let c = chr $  (n0 .&. 0x1f) `shiftL` 6
                   .|. cont (off + 1)
        in c : go (off + 2)

      -- Anything else is treated as a single byte.
      | otherwise =
        let c = chr (n0 .&. 0x7f)
        in c : go (off + 1)
      where
        -- Deliberately lazy: a strict binding here would be forced before
        -- the end-of-input guard above and read the byte at offset @len@.
        n0 = index off

-- | Is the given byte sequence valid under the given encoding?
--
-- The checks follow the well-formed UTF-8 byte sequence table: overlong
-- forms, stray continuation bytes, truncated sequences and code points above
-- @U+10FFFF@ are all rejected. Encoded surrogates (@0xed@ followed by
-- @0xa0..0xbf@) are rejected only under 'ConformantUTF8'; under
-- 'GeneralisedUTF8' they are accepted like any other three-byte sequence.
isValid :: UTF8Encoding -> [Word8] -> Bool
isValid encoding = go
  where
    -- Equations are tried top to bottom; when an equation's guard fails,
    -- matching falls through to the next one, and the final catch-all
    -- rejects whatever no rule accepted.
    go :: [Word8] -> Bool
    go [] = True
    -- U+0000..U+007F: one byte
    go (b0:bs)
      | inRange 0x00 0x7f b0 = go bs
    -- U+0080..U+07FF: two bytes (lead bytes 0xc0 and 0xc1 would be overlong)
    go (b0:b1:bs)
      | inRange 0xc2 0xdf b0
      , inRange 0x80 0xbf b1 = go bs
    -- U+0800..U+0FFF: second byte restricted to exclude overlong forms
    go (0xe0:b1:b2:bs)
      | inRange 0xa0 0xbf b1
      , inRange 0x80 0xbf b2 = go bs
    go (0xed:b1:_)
      -- surrogate range
      | encoding == ConformantUTF8
      , inRange 0xa0 0xbf b1
      = False
    -- U+1000..U+FFFF: remaining three-byte sequences
    go (b0:b1:b2:bs)
      | inRange 0xe1 0xef b0
      , inRange 0x80 0xbf b1
      , inRange 0x80 0xbf b2 = go bs
    -- U+10000..U+3FFFF: second byte restricted to exclude overlong forms
    go (0xf0:b1:b2:b3:bs)
      | inRange 0x90 0xbf b1
      , inRange 0x80 0xbf b2
      , inRange 0x80 0xbf b3 = go bs
    -- U+40000..U+FFFFF
    go (b0:b1:b2:b3:bs)
      | inRange 0xf1 0xf3 b0
      , inRange 0x80 0xbf b1
      , inRange 0x80 0xbf b2
      , inRange 0x80 0xbf b3 = go bs
    -- U+100000..U+10FFFF: second byte capped so the result stays in range
    go (0xf4:b1:b2:b3:bs)
      | inRange 0x80 0x8f b1
      , inRange 0x80 0xbf b2
      , inRange 0x80 0xbf b3 = go bs
    go _ = False

-- | @inRange lower upper x@ tests whether @x@ lies in the closed interval
-- from @lower@ to @upper@ (both bounds inclusive).
inRange :: Ord a => a -> a -> a -> Bool
inRange lower upper x = lower <= x && x <= upper