{-# LANGUAGE BangPatterns #-} {-| Module : Database.MySQL.Protocol.Escape Description : Pure haskell mysql escape Copyright : (c) Winterland, 2016 License : BSD Maintainer : drkoster@qq.com Stability : experimental Portability : PORTABLE This module provide escape machinery for bytes and text types. reference: * Escape Sequence Character Represented by Sequence * \0 An ASCII NUL (X'00') character * \' A single quote (“'”) character * \" A double quote (“"”) character * \b A backspace character * \n A newline (linefeed) character * \r A carriage return character * \t A tab character * \Z ASCII 26 (Control+Z); see note following the table * \\ A backslash (“\”) character * \% A “%” character; see note following the table * \_ A “_” character; see note following the table The @\%@ and @\_@ sequences are used to search for literal instances of @%@ and @_@ in pattern-matching contexts where they would otherwise be interpreted as wildcard characters, so we won't auto escape @%@ or @_@ here. -} module Database.MySQL.Protocol.Escape where import Data.ByteString (ByteString) import qualified Data.ByteString.Internal as B import Data.Text (Text) import qualified Data.Text.Array as TA import qualified Data.Text.Internal as T import Data.Word import Foreign.ForeignPtr (withForeignPtr) import Foreign.Ptr (Ptr, minusPtr, plusPtr) import Foreign.Storable (peek, poke, pokeByteOff) import GHC.IO (unsafeDupablePerformIO) escapeText :: Text -> Text escapeText (T.Text arr off len) | len <= 0 = T.empty | otherwise = let (arr', len') = TA.run2 $ do marr <- TA.new (len * 2) loop arr (off + len) marr off 0 in T.Text arr' 0 len' where escape c marr ix = do TA.unsafeWrite marr ix 92 TA.unsafeWrite marr (ix+1) c loop oarr oend marr !ix !ix' | ix == oend = return (marr, ix') | otherwise = do let c = TA.unsafeIndex oarr ix go1 = loop oarr oend marr (ix+1) (ix'+1) go2 = loop oarr oend marr (ix+1) (ix'+2) if | c >= 0xD800 && c <= 0xDBFF -> do let c2 = TA.unsafeIndex oarr (ix+1) TA.unsafeWrite marr ix' c TA.unsafeWrite marr (ix'+1) c2 loop oarr oend marr (ix+2) (ix'+2) | c == 0 || c == 39 || c == 34 -> escape c marr ix' >> go2 -- \0 \' \" | c == 8 -> escape 98 marr ix' >> go2 -- \b | c == 10 -> escape 110 marr ix' >> go2 -- \n | c == 13 -> escape 114 marr ix' >> go2 -- \r | c == 9 -> escape 116 marr ix' >> go2 -- \t | c == 26 -> escape 90 marr ix' >> go2 -- \Z | c == 92 -> escape 92 marr ix' >> go2 -- \\ | otherwise -> TA.unsafeWrite marr ix' c >> go1 escapeBytes :: ByteString -> ByteString escapeBytes (B.PS fp s len) = unsafeDupablePerformIO $ withForeignPtr fp $ \ a -> B.createUptoN (len * 2) $ \ b -> do b' <- loop (a `plusPtr` s) (a `plusPtr` s `plusPtr` len) b return (b' `minusPtr` b) where escape :: Word8 -> Ptr Word8 -> IO (Ptr Word8) escape c p = do poke p 92 pokeByteOff p 1 c return (p `plusPtr` 2) loop !a aend !b | a == aend = return b | otherwise = do c <- peek a if | c == 0 || c == 39 || c == 34 -> escape c b >>= loop (a `plusPtr` 1) aend -- \0 \' \" | c == 8 -> escape 98 b >>= loop (a `plusPtr` 1) aend -- \b | c == 10 -> escape 110 b >>= loop (a `plusPtr` 1) aend -- \n | c == 13 -> escape 114 b >>= loop (a `plusPtr` 1) aend -- \r | c == 9 -> escape 116 b >>= loop (a `plusPtr` 1) aend -- \t | c == 26 -> escape 90 b >>= loop (a `plusPtr` 1) aend -- \Z | c == 92 -> escape 92 b >>= loop (a `plusPtr` 1) aend -- \\ | otherwise -> poke b c >> loop (a `plusPtr` 1) aend (b `plusPtr` 1)