{-# LANGUAGE TypeApplications #-}
module Database.PostgreSQL.PQTypes.SQL (
    SQL
  , mkSQL
  , sqlParam
  , (<?>)
  , isSqlEmpty
  ) where

import Control.Concurrent.MVar
import Data.Monoid
import Data.String
import Foreign.Marshal.Alloc
import TextShow
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Unsafe as BS
import qualified Data.Foldable as F
import qualified Data.Semigroup as SG
import qualified Data.Sequence as S
import qualified Data.Text as T
import qualified Data.Text.Encoding as T

import Data.Monoid.Utils
import Database.PostgreSQL.PQTypes.Format
import Database.PostgreSQL.PQTypes.Internal.C.Put
import Database.PostgreSQL.PQTypes.Internal.Utils
import Database.PostgreSQL.PQTypes.SQL.Class
import Database.PostgreSQL.PQTypes.ToSQL

-- | One piece of an 'SQL' value: either a literal fragment of query text or
-- a parameter value that is kept out of the query text itself.
data SqlChunk where
  -- | Literal SQL text, concatenated verbatim into the rendered query.
  SqlString :: !T.Text -> SqlChunk
  -- | A parameter of any type with 'Show' and 'ToSQL' instances. The
  -- concrete type is existentially hidden inside the chunk; 'Show' is used
  -- for debug output and 'ToSQL' to serialize the value at execution time.
  SqlParam  :: forall t. (Show t, ToSQL t) => !t -> SqlChunk

-- | Primary SQL type that supports efficient
-- concatenation and variable number of parameters.
--
-- Internally this is a 'S.Seq' of 'SqlChunk's, so concatenation is a
-- sequence append. When the query is executed, literal chunks are copied
-- into the query text and every parameter chunk is replaced by a numbered
-- positional placeholder (@$1@, @$2@, ...).
newtype SQL = SQL (S.Seq SqlChunk)

-- | Flatten the chunks of an 'SQL' into a list, preserving their order.
unSQL :: SQL -> [SqlChunk]
unSQL (SQL chunks) = F.toList chunks

----------------------------------------

-- | Construct 'SQL' from 'String'.
--
-- The string is packed to 'T.Text' and becomes a single literal chunk,
-- exactly as if it had been passed to 'mkSQL'.
instance IsString SQL where
  fromString = mkSQL . T.pack

-- | Render the 'SQL' into one query string and pass it to the callback.
--
-- Literal chunks are concatenated verbatim. Each parameter chunk is
-- serialized with 'toSQL', stored into the @PGparam@ obtained from the
-- 'ParamAllocator' via 'c_PQputf1', and replaced in the query text by the
-- next positional placeholder (@$1@, @$2@, ...). The callback @execute@
-- then receives that @PGparam@ together with the UTF-8 encoded query as a
-- 'CString'. Failures reported by 'c_PQputf1' are checked with
-- 'verifyPQTRes'.
instance IsSQL SQL where
  withSQL sql pa@(ParamAllocator allocParam) execute =
    alloca $ \err -> allocParam $ \param -> do
      -- Number of the next placeholder; PostgreSQL placeholders start at $1.
      nums <- newMVar (1 :: Int)
      query <- T.concat <$> mapM (f param err nums) (unSQL sql)
      BS.useAsCString (T.encodeUtf8 query) (execute param)
    where
      -- Turn one chunk into its piece of query text, pushing parameter
      -- values into @param@ as a side effect.
      f param err nums chunk = case chunk of
        SqlString s -> return s
        SqlParam (v :: t) -> toSQL v pa $ \base ->
          -- Format string for the parameter's type (from 'PQFormat'),
          -- passed to 'c_PQputf1' alongside the pointer to the value.
          BS.unsafeUseAsCString (pqFormat0 @t) $ \fmt -> do
            verifyPQTRes err "withSQL (SQL)" =<< c_PQputf1 param err fmt base
            -- Emit "$n" for this parameter and bump the counter; '$!' keeps
            -- the Int stored in the MVar evaluated.
            modifyMVar nums $ \n -> return . (, "$" <> showt n) $! n + 1

-- | Concatenation appends the two chunk sequences with 'S.><', which is
-- logarithmic in the size of the smaller sequence.
instance SG.Semigroup SQL where
  SQL a <> SQL b = SQL (a S.>< b)

-- | 'mempty' is the 'SQL' with no chunks at all, so appending it leaves the
-- chunk sequence of the other operand unchanged (a structural identity, not
-- merely one that renders to the same text).
--
-- 'mappend' is defined explicitly as '(SG.<>)' so both operations agree,
-- which matters on older @base@ versions where 'mappend' has no default.
instance Monoid SQL where
  mempty = SQL S.empty
  mappend = (SG.<>)

-- | Debug rendering: @SQL "<text>"@, where literal chunks appear verbatim
-- and each parameter is shown as @\<value\>@ using its own 'Show' instance.
--
-- Like derived instances, the output is parenthesized when shown at
-- precedence above 10 (e.g. inside @Just@), so nested output stays
-- well-formed; top-level 'show' output is unaffected.
instance Show SQL where
  showsPrec n sql = showParen (n > 10) $
    showString "SQL " . showsPrec 11 (concatMap conv (unSQL sql))
    where
      conv :: SqlChunk -> String
      conv (SqlString s) = T.unpack s
      conv (SqlParam v) = "<" ++ show v ++ ">"

----------------------------------------

-- | Convert a 'Text' SQL string to the 'SQL' type.
--
-- The text becomes a single literal chunk and is used verbatim in the
-- rendered query (it is not escaped or parameterized).
mkSQL :: T.Text -> SQL
mkSQL = SQL . S.singleton . SqlString

-- | Embed parameter value inside 'SQL'.
--
-- The value is not spliced into the query text: it becomes a single
-- parameter chunk, rendered as a positional placeholder (@$n@) when the
-- query is executed, with the value itself passed separately.
sqlParam :: (Show t, ToSQL t) => t -> SQL
sqlParam = SQL . S.singleton . SqlParam

-- | Embed parameter value inside existing 'SQL'. Example:
--
-- > f :: Int32 -> String -> SQL
-- > f idx name = "SELECT foo FROM bar WHERE id =" <?> idx <+> "AND name =" <?> name
--
-- Equivalent to @s '<+>' 'sqlParam' v@.
(<?>) :: (Show t, ToSQL t) => SQL -> t -> SQL
s <?> v = s <+> sqlParam v
infixr 7 <?>

----------------------------------------

-- | Test whether an 'SQL' is empty.
--
-- An 'SQL' is empty when every chunk is an empty text fragment, which is
-- vacuously true when there are no chunks at all. Any parameter chunk makes
-- it non-empty.
isSqlEmpty :: SQL -> Bool
isSqlEmpty (SQL chunks) = F.all isEmptyChunk chunks
  where
    isEmptyChunk :: SqlChunk -> Bool
    isEmptyChunk (SqlString s) = T.null s
    isEmptyChunk (SqlParam _)  = False