{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}

{-|
Module      : Database.MySQL.Protocol.Escape
Description : Pure haskell mysql escape
Copyright   : (c) Winterland, 2016
License     : BSD
Maintainer  : drkoster@qq.com
Stability   : experimental
Portability : PORTABLE

This module provides escape machinery for bytes and text types.

reference: <http://dev.mysql.com/doc/refman/5.7/en/string-literals.html>

    * Escape Sequence	Character Represented by Sequence
    * \0              	An ASCII NUL (X'00') character
    * \'              	A single quote (“'”) character
    * \"              	A double quote (“"”) character
    * \b              	A backspace character
    * \n              	A newline (linefeed) character
    * \r              	A carriage return character
    * \t              	A tab character
    * \Z              	ASCII 26 (Control+Z); see note following the table
    * \\              	A backslash (“\”) character
    * \%              	A “%” character; see note following the table
    * \_              	A “_” character; see note following the table

The @\%@ and @\_@ sequences are used to search for literal instances of @%@ and @_@ in pattern-matching contexts where they would otherwise be interpreted as wildcard characters, so we won't auto escape @%@ or @_@ here.

-}

module Database.MySQL.Protocol.Escape where

import           Control.Monad            (forM_)
import           Data.ByteString          (ByteString)
import qualified Data.ByteString.Internal as B
import           Data.Text                (Text)
import qualified Data.Text.Array          as TA
import qualified Data.Text.Internal       as T
import           Data.Word
import           Foreign.ForeignPtr       (withForeignPtr)
import           Foreign.Ptr              (Ptr, minusPtr, plusPtr)
import           Foreign.Storable         (peek, poke, pokeByteOff)
import           GHC.IO                   (unsafeDupablePerformIO)

-- | Escape a 'Text' so that it can be embedded in a MySQL string literal.
--
-- NUL, backspace, tab, newline, carriage return, Control+Z, double quote,
-- single quote and backslash are each replaced by a two-character backslash
-- sequence; every other character is copied unchanged. The wildcard
-- characters @%@ and @_@ are deliberately left alone.
--
-- The loop works directly on the code units of the underlying array. Every
-- escaped character is ASCII, and no code unit that is part of a multi-unit
-- character (UTF-8 lead or continuation byte, UTF-16 surrogate) can be equal
-- to an ASCII value, so copying one unit at a time is correct for both the
-- UTF-16 (text-1.x) and the UTF-8 (text-2.x) representation. No
-- version-specific code is needed, and the loop never reads past the end of
-- the slice.
escapeText :: Text -> Text
escapeText (T.Text arr off len)
    | len <= 0  = T.empty
    | otherwise =
        let (arr', len') = TA.run2 $ do
                -- Worst case: every code unit is escaped, doubling the size.
                marr <- TA.new (len * 2)
                loop arr (off + len) marr off 0
        in T.Text arr' 0 len'
  where
    -- Write a backslash followed by code unit c, starting at output index ix.
    -- Kept free of outer variables so it stays fully polymorphic in the
    -- state thread required by TA.run2.
    escape marr ix c = do
        TA.unsafeWrite marr ix 92
        TA.unsafeWrite marr (ix + 1) c

    -- Copy code units from oarr (input index ix, stopping at oend) into marr
    -- (output index ix'), returning the array and the number of units written.
    loop oarr oend marr !ix !ix'
        | ix == oend = return (marr, ix')
        | otherwise  = step (TA.unsafeIndex oarr ix)
      where
        -- Emit a two-unit escape sequence and continue.
        escaped e = escape marr ix' e >> loop oarr oend marr (ix + 1) (ix' + 2)
        -- Copy the unit through unchanged and continue.
        plain c   = TA.unsafeWrite marr ix' c >> loop oarr oend marr (ix + 1) (ix' + 1)

        step c
            | c == 0    = escaped 48   -- NUL becomes backslash, ASCII zero
            | c == 39   = escaped c    -- single quote
            | c == 34   = escaped c    -- double quote
            | c == 8    = escaped 98   -- backspace becomes backslash, b
            | c == 10   = escaped 110  -- newline becomes backslash, n
            | c == 13   = escaped 114  -- carriage return becomes backslash, r
            | c == 9    = escaped 116  -- tab becomes backslash, t
            | c == 26   = escaped 90   -- Control+Z becomes backslash, Z
            | c == 92   = escaped 92   -- backslash is doubled
            | otherwise = plain c

-- | Escape a 'ByteString' so that it can be embedded in a MySQL string
-- literal.
--
-- NUL, backspace, tab, newline, carriage return, Control+Z, double quote,
-- single quote and backslash are each replaced by a two-byte backslash
-- sequence; every other byte is copied unchanged. The wildcard characters
-- @%@ and @_@ are deliberately left alone.
--
-- The output buffer is allocated at twice the input length (the worst case,
-- where every byte is escaped) and trimmed to the number of bytes written.
escapeBytes :: ByteString -> ByteString
escapeBytes (B.PS fp s len) = unsafeDupablePerformIO $ withForeignPtr fp $ \ a ->
    B.createUptoN (len * 2) $ \ b -> do
        b' <- loop (a `plusPtr` s) (a `plusPtr` s `plusPtr` len) b
        return (b' `minusPtr` b)
  where
    -- Write a backslash followed by c at p; return the pointer past both.
    escape :: Word8 -> Ptr Word8 -> IO (Ptr Word8)
    escape c p = do
        poke p 92
        pokeByteOff p 1 c
        return (p `plusPtr` 2)

    -- Copy bytes from a (stopping at aend) to b; return the final write
    -- position.
    loop :: Ptr Word8 -> Ptr Word8 -> Ptr Word8 -> IO (Ptr Word8)
    loop !a aend !b
        | a == aend = return b
        | otherwise = peek a >>= step
      where
        next :: Ptr Word8
        next = a `plusPtr` 1

        -- Emit a two-byte escape sequence and continue.
        escaped :: Word8 -> IO (Ptr Word8)
        escaped e = escape e b >>= loop next aend

        step :: Word8 -> IO (Ptr Word8)
        step c
            | c == 0    = escaped 48   -- NUL becomes backslash, ASCII zero
            | c == 39   = escaped c    -- single quote
            | c == 34   = escaped c    -- double quote
            | c == 8    = escaped 98   -- backspace becomes backslash, b
            | c == 10   = escaped 110  -- newline becomes backslash, n
            | c == 13   = escaped 114  -- carriage return becomes backslash, r
            | c == 9    = escaped 116  -- tab becomes backslash, t
            | c == 26   = escaped 90   -- Control+Z becomes backslash, Z
            | c == 92   = escaped 92   -- backslash is doubled
            | otherwise = poke b c >> loop next aend (b `plusPtr` 1)