{-# LANGUAGE Trustworthy #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE CPP #-} #include "MachDeps.h" -- | This module deals with escaping and sanitizing SQL templates. module Database.PostgreSQL.Escape ( fmtSql, quoteIdent , buildSql, buildSqlFromActions , buildAction, buildLiteral, buildByteA, buildIdent ) where import Blaze.ByteString.Builder.Internal.Write import Data.ByteString.Builder import Data.ByteString.Builder.Internal import Data.ByteString.Lazy (toStrict) import qualified Data.ByteString as S import qualified Data.ByteString.Internal as S import qualified Data.ByteString.Unsafe as S import Data.Monoid import Database.PostgreSQL.Simple import Database.PostgreSQL.Simple.ToField import Database.PostgreSQL.Simple.ToRow import Database.PostgreSQL.Simple.Types import Foreign.Marshal.Alloc (mallocBytes) import Foreign.Storable (pokeByteOff) import GHC.Prim (Addr#, and#, geAddr#, geWord#, Int#, int2Word# , minusAddr#, ord# , plusAddr#, readWord8OffAddr# , State# , uncheckedShiftRL#, word2Int#, writeWord8OffAddr# , Word#) import GHC.Ptr (Ptr(Ptr)) import GHC.Types (Char(C#), Int(I#), IO(IO)) import GHC.Word (Word8(W8#)) import System.IO.Unsafe (unsafeDupablePerformIO) {-# INLINE cmpres #-} -- | Newer versions of GHC return an Int# instead of a Bool for -- primitive comparison functions. The @cmpres@ function converts the -- result of such a comparison to a @Bool@. 
#if __GLASGOW_HASKELL__ >= 707
cmpres :: Int# -> Bool
cmpres r = case r of
  0# -> False
  _  -> True
#else /* __GLASGOW_HASKELL__ < 707 */
cmpres :: Bool -> Bool
cmpres = id
#define cmpres(b) b
#endif /* __GLASGOW_HASKELL__ < 707 */

-- | Convert a 'Char' to a byte.  No narrowing is performed, so this
-- is only meaningful for characters below 256.
c2b :: Char -> Word8
c2b ch = W8# (c2b# ch)

-- | Unboxed variant of 'c2b': the code point of the character as a
-- raw 'Word#'.
c2b# :: Char -> Word#
c2b# (C# ch) = int2Word# (ord# ch)

-- | Run an 'IO' action inline, using whichever primitive the
-- installed @bytestring@ version provides for that purpose.
inlinePerformIO :: IO a -> a
#if MIN_VERSION_bytestring(0,10,6)
inlinePerformIO = S.accursedUnutterablePerformIO
#else
inlinePerformIO = S.inlinePerformIO
#endif

-- | Offset of the first byte of the string for which the predicate
-- holds, or 'Nothing' when no byte satisfies it.  Bytes are examined
-- front to back and the scan stops at the first hit.
fastFindIndex :: (Word# -> Bool) -> S.ByteString -> Maybe Int
{-# INLINE fastFindIndex #-}
fastFindIndex test bs = inlinePerformIO $
  S.unsafeUseAsCStringLen bs $ \(Ptr base, I# len) ->
    let limit = base `plusAddr#` len
        -- Read the byte under the cursor and apply the predicate.
        probe cur = IO $ \s0 ->
          case readWord8OffAddr# cur 0# s0 of
            (# s1, w #) -> (# s1, test w #)
        scan cur
          | cmpres(cur `geAddr#` limit) = return Nothing
          | otherwise = do
              hit <- probe cur
              if hit
                then return (Just (I# (cur `minusAddr#` base)))
                else scan (cur `plusAddr#` 1#)
    in scan base

-- | Split a string just before the first byte satisfying the
-- predicate.  When there is no such byte the second component is
-- empty.
fastBreak :: (Word# -> Bool) -> S.ByteString
          -> (S.ByteString, S.ByteString)
{-# INLINE fastBreak #-}
fastBreak test bs =
  case fastFindIndex test bs of
    Just n  -> (S.unsafeTake n bs, S.unsafeDrop n bs)
    Nothing -> (bs, S.empty)

-- | Generic quoting loop.  Emits @start@, then the input with every
-- byte matching @escPred@ replaced by the output of @escFn@ for that
-- byte, then @end@.
quoter :: S.ByteString -> S.ByteString -> (Word# -> Bool) -> (Word8 -> Builder)
       -> S.ByteString -> Builder
{-# INLINE quoter #-}
quoter start end escPred escFn bs0 =
  byteStringCopy start <> escaped bs0 <> byteStringCopy end
  where
    escaped bs
      | S.null rest = byteString clean
      | otherwise   = byteString clean
                      <> escFn (S.unsafeHead rest)
                      <> escaped (S.unsafeTail rest)
      where (clean, rest) = fastBreak escPred bs

-- | Quote an identifier using unicode quoting syntax.  This is
-- necessary for identifiers containing a question mark, as otherwise
-- "PostgreSQL.Simple"'s naive formatting code will attempt to match
-- the question mark to a parameter.
--
-- NUL bytes are rejected here as well: 'buildIdent' stops scanning at
-- the first question mark, so a NUL located after it would otherwise
-- never be checked.
uBuildIdent :: S.ByteString -> Builder
uBuildIdent ident = quoter " U&\"" "\"" isSpecial esc ident
  where
    isSpecial 0##  = error "quoteIdent: illegal NUL character"
    isSpecial 34## = True       -- '"'
    isSpecial 63## = True       -- '?'
    isSpecial 92## = True       -- '\\'
    isSpecial _    = False
    -- Replacement text for each special byte: quotes and backslashes
    -- are doubled, the question mark becomes a unicode escape.
    esc c = byteStringCopy $ case () of
      _ | c == c2b '"'  -> "\"\""
        | c == c2b '?'  -> "\\003f"
        | c == c2b '\\' -> "\\\\"
        | otherwise     -> error "uquoteIdent"

-- | Build a quoted identifier.  Generally you will want to use
-- 'quoteIdent', and for repeated use it will be faster to use
-- @'byteString' . 'quoteIdent'@, but this internal function is
-- exposed in case it is useful.
--
-- Identifiers containing a question mark are emitted with the
-- unicode quoting syntax; all others are wrapped in plain double
-- quotes with embedded double quotes doubled.  An identifier
-- containing a NUL byte raises an 'error'.
buildIdent :: S.ByteString -> Builder
buildIdent ident
  | Just _ <- fastFindIndex isQuestionmark ident = uBuildIdent ident
  | otherwise = quoter "\"" "\"" isDQuote (const $ byteStringCopy "\"\"") ident
  where
    isQuestionmark 63## = True
    -- The scan for '?' doubles as the NUL check for every byte that
    -- precedes the first question mark (or for the whole identifier
    -- when there is none).
    isQuestionmark 0##  = error "quoteIdent: illegal NUL character"
    isQuestionmark _    = False
    isDQuote 34## = True
    isDQuote _    = False

-- | Quote an identifier such as a table or column name using
-- double-quote characters.  Note this has nothing to do with quoting
-- /values/, which must be quoted using single quotes.  (Anyway, all
-- values should be quoted by 'query' or 'fmtSql'.)  This function
-- uses a unicode escape sequence to escape \'?\' characters, which
-- would otherwise be expanded by 'query', 'formatQuery', or 'fmtSql'.
--
-- >>> S8.putStrLn $ quoteIdent "hello \"world\"!"
-- "hello ""world""!"
-- >>> S8.putStrLn $ quoteIdent "hello \"world\"?"
-- U&"hello ""world""\003f"
--
-- Note that this quoting function is correct only if
-- @client_encoding@ is @SQL_ASCII@, @client_encoding@ is @UTF8@, or
-- the identifier contains no multi-byte characters.  For other coding
-- schemes, this function may erroneously duplicate bytes that look
-- like quote characters but are actually part of a multi-byte
-- character code.  In such cases, maliciously crafted identifiers
-- will, even after quoting, allow injection of arbitrary SQL commands
-- to the server.
--
-- The upshot is that it is unwise to use this function on identifiers
-- provided by untrustworthy sources.  Note this is true anyway,
-- regardless of @client_encoding@ setting, because certain \"system
-- column\" names (e.g., @oid@, @tableoid@, @xmin@, @cmin@, @xmax@,
-- @cmax@, @ctid@) are likely to produce unexpected results even when
-- properly quoted.
--
-- See 'Id' for a convenient way to include quoted identifiers in
-- parameter lists.
quoteIdent :: S.ByteString -> S.ByteString
quoteIdent = toStrict . toLazyByteString . buildIdent

-- | A 16-byte table holding the lowercase ASCII hex digits
-- @0123456789abcdef@, indexed by nibble value.  It is allocated on
-- first use and never freed.
hexNibblesPtr :: Ptr Word8
{-# NOINLINE hexNibblesPtr #-}
hexNibblesPtr = unsafeDupablePerformIO $ do
  ptr <- mallocBytes 16
  sequence_ $ zipWith (\o v -> pokeByteOff ptr o $ c2b v) [0..]
    (['0'..'9'] ++ ['a'..'f'])
  return ptr

-- | Write the two hex digits of a byte (high nibble first) at the
-- given address.  Bad things will happen if the argument is greater
-- than 0xff, because the high "nibble" would then index past the end
-- of the 16-entry table.
uncheckedWriteNibbles# :: Addr# -> Word# -> State# d -> State# d
{-# INLINE uncheckedWriteNibbles# #-}
uncheckedWriteNibbles# p w rw0 =
  case (# word2Int# (w `uncheckedShiftRL#` 4# )
        , word2Int# (w `and#` 0xf## ) #) of
  { (# h, l #) ->
  case readWord8OffAddr# nibbles h rw0 of
  { (# rw1, hascii #) ->
  case writeWord8OffAddr# p 0# hascii rw1 of
  { rw2 ->
  case readWord8OffAddr# nibbles l rw2 of
  { (# rw3, lascii #) -> writeWord8OffAddr# p 1# lascii rw3 }}}}
  where !(Ptr nibbles) = hexNibblesPtr

-- | Emit the four-byte escape @\\xNN@ for a byte, where @NN@ are its
-- two lowercase hex digits.
hexCharEscBuilder :: Word8 -> Builder
{-# INLINE hexCharEscBuilder #-}
hexCharEscBuilder (W8# w) = fromWrite $ exactWrite 4 $ \(Ptr p) -> IO $ \rw0 ->
  (# uncheckedWriteNibbles# (p `plusAddr#` 2#) w
       (writeWord8OffAddr# p 1# (c2b# 'x')
         (writeWord8OffAddr# p 0# (c2b# '\\') rw0))
   , () #)

-- | Build a quoted string literal using the @E'...'@ escape-string
-- syntax.  Single quotes and backslashes are doubled; question marks
-- and all bytes of value 128 or above are written as @\\xNN@ hex
-- escapes.
buildLiteral :: S.ByteString -> Builder
buildLiteral = quoter " E'" "'" isSpecial esc
  where
    isSpecial 39## = True       -- '\''
    isSpecial 63## = True       -- '?'
    isSpecial 92## = True       -- '\\'
    isSpecial b    = cmpres(b `geWord#` 128##)
    esc b
      | b == c2b '\'' = byteStringCopy "''"
      | b == c2b '\\' = byteStringCopy "\\\\"
      | otherwise     = hexCharEscBuilder b

-- | Read one byte from @src@ and write its two hex digits at @dst@.
copyByteToNibbles :: Addr# -> Addr# -> IO ()
{-# INLINE copyByteToNibbles #-}
copyByteToNibbles src dst = IO $ \rw0 ->
  case readWord8OffAddr# src 0# rw0 of
    (# rw1, w #) -> (# uncheckedWriteNibbles# dst w rw1, () #)

-- | Build a binary-string literal: the whole input hex-encoded, two
-- output bytes per input byte, wrapped as @E'\\\\x...'@.
--
-- The build step may be suspended when the output buffer fills up.
-- The continuation handed to 'bufferFull' therefore carries a byte
-- /offset/ into the input and re-enters 'S.unsafeUseAsCStringLen' when
-- resumed, rather than holding on to a raw address.  This keeps the
-- source 'S.ByteString' reachable (and hence its memory alive) for as
-- long as any part of it remains to be encoded.
buildByteA :: S.ByteString -> Builder
buildByteA bs = equote $ builder $ \cont ->
  let -- Encode the input starting at byte offset @off@ into the
      -- given output buffer.
      resume off (BufferRange (Ptr bb) (Ptr be)) =
        S.unsafeUseAsCStringLen bs $ \(Ptr inptr0, I# inlen0) -> do
          let ine = plusAddr# inptr0 inlen0
              fill inp outp
                -- All input consumed: hand the rest of the buffer on.
                | cmpres(inp `geAddr#` ine) =
                    cont (BufferRange (Ptr outp) (Ptr be))
                -- Not enough room left.  Ask for one byte more than
                -- the remaining output so the check above cannot
                -- fail again on the last input byte.
                | cmpres(plusAddr# outp 2# `geAddr#` be) =
                    return $ bufferFull
                      (2 * (I# (ine `minusAddr#` inp)) + 1)
                      (Ptr outp)
                      (resume (inp `minusAddr#` inptr0))
                | otherwise = do
                    copyByteToNibbles inp outp
                    fill (inp `plusAddr#` 1#) (outp `plusAddr#` 2#)
          fill (inptr0 `plusAddr#` off) bb
  in resume 0#
  where equote b = mconcat [byteString " E'\\\\x", b, char8 '\'']

-- | Render a single 'Action' as a 'Builder', escaping it according
-- to its constructor.
buildAction :: Action -> Builder
buildAction (Plain b)             = b
buildAction (Escape bs)           = buildLiteral bs
buildAction (EscapeByteA bs)      = buildByteA bs
buildAction (EscapeIdentifier bs) = buildIdent bs
buildAction (Many bs)             = mconcat $ map buildAction bs

-- | A lower-level function used by 'buildSql' and 'fmtSql'.  You
-- probably don't need to call it directly.
buildSqlFromActions :: Query -> [Action] -> Builder
buildSqlFromActions (Query template) actions =
  weave (fragments template) (map buildAction actions)
  where
    -- Alternate template fragments with rendered parameters.  There
    -- must be exactly one more fragment than there are parameters.
    weave (piece:pieces) (param:params) = piece <> param <> weave pieces params
    weave [piece] []                    = piece
    weave _ _ =
      error $ "buildSql: wrong number of parameters for " ++ show template
    -- Cut the template at every '?' byte, dropping the '?' itself.
    fragments s
      | S.null rest = [byteString before]
      | otherwise   = byteString before : fragments (S.unsafeTail rest)
      where (before, rest) = S.break (== c2b '?') s

-- | A builder version of 'fmtSql', possibly useful if you are about
-- to concatenate various individually formatted query fragments and
-- want to save the work of concatenating each individually.
buildSql :: (ToRow p) => Query -> p -> Builder
{-# INLINE buildSql #-}
buildSql q p = buildSqlFromActions q $ toRow p

-- | Take a SQL template containing \'?\' characters and a list of
-- parameters whose length must match the number of \'?\' characters,
-- and format the result as an escaped 'S.ByteString' that can be used
-- as a query.
--
-- Like 'formatQuery', this function is naive about the placement of
-- \'?\' characters and will expand all of them, even ones within
-- quotes.  To avoid this, you must use 'quoteIdent' on identifiers
-- containing question marks.
--
-- Also like 'formatQuery', \'?\' characters touching other \'?\'
-- characters or quoted strings may do the wrong thing, and end up
-- doubling a quote, so avoid substrings such as @\"??\"@ or
-- @\"?'string'\"@, as these could get expanded to, e.g.,
-- @\"\'param''string'\"@, which is a single string containing an
-- apostrophe, when you probably wanted two strings.
fmtSql :: (ToRow p) => Query -> p -> Query
{-# INLINE fmtSql #-}
fmtSql q p = Query (toStrict (toLazyByteString (buildSql q p)))