module Database.MySQL.Query where

import           Data.String               (IsString (..))
import           Control.Exception         (throw, Exception)
import           Data.Typeable
import qualified Data.ByteString.Lazy      as L
import qualified Data.ByteString.Lazy.Char8     as LC
import qualified Data.ByteString.Builder   as BB
import           Control.Arrow             (first)
import           Database.MySQL.Protocol.MySQLValue
import           Data.Binary.Put

-- | A SQL query string. The design is adapted from the @Query@ type of
-- @mysql-simple@.
--
-- Wrapping the text in a dedicated type makes it awkward to build SQL by
-- gluing string fragments together. That practice is one of the most common
-- ways SQL injection vulnerabilities end up in an application.
--
-- 'Query' has an 'IsString' instance. With the @OverloadedStrings@
-- extension enabled, a query can therefore be written directly as a string
-- literal.
--
-- The payload is a lazy 'L.ByteString'. String literals that contain
-- non-ASCII characters are encoded as UTF-8 by the 'IsString' instance.
newtype Query = Query
    { fromQuery :: L.ByteString  -- ^ The raw bytes of the query text.
    }
    deriving (Eq, Ord, Typeable)

-- | Displays a 'Query' exactly as 'show' displays the wrapped 'L.ByteString';
-- the 'Query' constructor itself is not shown.
instance Show Query where
    showsPrec prec (Query bytes) = showsPrec prec bytes

-- | Accepts whatever 'Read' accepts for 'L.ByteString' and wraps each parse
-- result in 'Query', leaving the unconsumed input untouched.
instance Read Query where
    readsPrec prec input = first Query <$> readsPrec prec input

-- | Builds a 'Query' from a 'String' literal, encoding every character as
-- UTF-8 (via 'BB.stringUtf8').
instance IsString Query where
    fromString str = Query (BB.toLazyByteString (BB.stringUtf8 str))

-- | A query parameter that carries either a single value or several values.
--
-- Rendering (see the 'QueryParam' instance for 'Param') works as follows:
--
-- * @'One' v@ renders @v@ with 'putTextField'.
--
-- * @'Many' [v1, v2, ...]@ renders each value with 'putTextField',
--   separated by single commas with no surrounding spaces.
--
-- * @'Many' []@ renders the same as @'One' 'MySQLNull'@.
--
-- A whole list can therefore fill one placeholder. For example, 'Many'
-- can fill the @?@ in @SELECT * FROM test WHERE _id IN (?, 888)@.
-- Prepared statements have no equivalent of this.
data Param
    = One  MySQLValue    -- ^ Exactly one value.
    | Many [MySQLValue]  -- ^ Zero or more values, rendered comma-separated.

-- | Types whose values can be substituted for a single @?@ placeholder in a
-- SQL query. Inspired by @mysql-simple@.
class QueryParam a where
    -- | Serialize a value into the query text at the placeholder's position.
    render :: a -> Put

-- | Renders a 'One' as its value. Renders a 'Many' as its values joined by
-- @,@, or as 'MySQLNull' when the list is empty.
instance QueryParam Param where
    render param = case param of
        One v       -> putTextField v
        -- An empty list has nothing to expand to, so it stands in as NULL.
        Many []     -> putTextField MySQLNull
        -- The first value goes out bare; every later one is prefixed by a
        -- comma. A one-element list therefore renders as just that value.
        Many (v:vs) -> putTextField v >> mapM_ commaThen vs
      where
        commaThen w = putCharUtf8 ',' >> putTextField w

-- | A bare 'MySQLValue' is rendered directly with 'putTextField'.
instance QueryParam MySQLValue where
    render value = putTextField value

-- | Substitute each @?@ placeholder in a 'Query' with a rendered parameter.
--
-- The query text is split on every @?@ byte. The i-th parameter is written
-- (via 'render') where the i-th @?@ was. The split is purely textual, so a
-- @?@ inside a string literal or a comment in the SQL is also treated as a
-- placeholder.
--
-- If the number of parameters differs from the number of placeholders,
-- 'WrongParamsCount' is thrown with 'throw'. The exception is raised when
-- the resulting query bytes are evaluated, not when 'renderParams' is
-- called.
renderParams :: QueryParam p => Query -> [p] -> Query
renderParams (Query qry) params =
    let fragments = LC.split '?' qry
    in Query . runPut $ merge fragments params
  where
    merge :: QueryParam a => [L.ByteString] -> [a] -> Put
    -- 'LC.split' returns @[]@ (not @[""]@) for empty input. An empty query
    -- with no parameters must therefore be accepted here explicitly instead
    -- of falling through to the count-mismatch error below.
    merge []     []     = return ()
    -- Last fragment, and every parameter has been consumed.
    merge [x]    []     = putLazyByteString x
    -- Emit the text before the next placeholder, then the parameter.
    merge (x:xs) (y:ys) = putLazyByteString x >> render y >> merge xs ys
    -- Leftover parameters or leftover placeholders.
    merge _      _      = throw WrongParamsCount

-- | Raised by 'renderParams' when the number of supplied parameters does not
-- match the number of @?@ placeholders in the query.
data WrongParamsCount = WrongParamsCount
    deriving (Show, Typeable)

instance Exception WrongParamsCount