{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts   #-}

{- | Introduces named parameters for the @postgresql-simple@ library.
A named parameter is written as a question mark followed by its name
(e.g. @?foo@); before the query is run, every named parameter is replaced
with the standard positional @?@ placeholder.

Check out the example of usage:

@
'queryNamed' dbConnection [sql|
    __SELECT__ *
    __FROM__ users
    __WHERE__ foo = ?foo
      __AND__ bar = ?bar
      __AND__ baz = ?foo
|] [ "foo" '=?' "fooBar"
   , "bar" '=?' "barVar"
   ]
@
-}

module PgNamed
       ( -- * Named data types and smart constructors
         NamedParam (..)
       , Name (..)
       , (=?)

         -- * Errors
       , PgNamedError (..)
       , WithNamedError

         -- * Functions to deal with named parameters
       , extractNames
       , namesToRow

         -- * Database querying functions with named parameters
       , queryNamed
       , queryWithNamed
       , executeNamed
       , executeNamed_

         -- * Internal utils
       , withNamedArgs
       ) where

import Control.Monad (void)
import Control.Monad.Except (MonadError (throwError))
import Control.Monad.IO.Class (MonadIO (liftIO))
import Data.Bifunctor (bimap)
import Data.ByteString (ByteString)
import Data.Char (isAlphaNum, isAscii)
import Data.Int (Int64)
import Data.List.NonEmpty (NonEmpty (..), toList)
import Data.Text (Text)
import Data.Text.Encoding (decodeUtf8)
import GHC.Exts (IsString)

import qualified Data.ByteString.Char8 as BS
import qualified Data.Text as T
import qualified Database.PostgreSQL.Simple as PG
import qualified Database.PostgreSQL.Simple.FromRow as PG
import qualified Database.PostgreSQL.Simple.Internal as PG
import qualified Database.PostgreSQL.Simple.ToField as PG
import qualified Database.PostgreSQL.Simple.Types as PG


-- | Wrapper over name of the argument.
--
-- Thanks to the 'IsString' instance, names can be written as string
-- literals, e.g. @\"foo\" '=?' x@.
newtype Name = Name
    { unName :: Text
    } deriving newtype (Show, Eq, Ord, IsString)

-- | Data type to represent each named parameter.
data NamedParam = NamedParam
    { namedParamName  :: !Name       -- ^ Name as written after @?@ in the query
    , namedParamParam :: !PG.Action  -- ^ Value produced by 'PG.toField' (see '=?')
    } deriving stock (Show)

-- | @PostgreSQL@ error type for named parameters.
data PgNamedError
    -- | Named parameter is not specified.
    = PgNamedParam Name
    -- | Query has no names inside but was called with named functions.
    | PgNoNames PG.Query
    -- | Query contains an empty name.
    | PgEmptyName PG.Query
    deriving stock (Eq)


-- | Type alias for monads that can throw errors of the 'PgNamedError' type.
--
-- This is a constraint synonym: @WithNamedError m@ means
-- @MonadError PgNamedError m@. It is kept unapplied (no @m@ parameter) so
-- it can also be used partially applied.
type WithNamedError = MonadError PgNamedError

-- | Human-readable rendering of 'PgNamedError'. Every message starts with
-- @PostgreSQL named parameter error: @; query-carrying errors append the
-- raw query text.
instance Show PgNamedError where
    show :: PgNamedError -> String
    show e = "PostgreSQL named parameter error: " ++ case e of
        -- Unwrap the name instead of using 'show', which would add a second
        -- pair of double quotes inside the single quotes ('"foo"').
        PgNamedParam n ->
            "Named parameter '" ++ T.unpack (unName n) ++ "' is not specified"
        PgNoNames (PG.Query q) ->
            "Query has no names but was called with named functions: " ++ BS.unpack q
        PgEmptyName (PG.Query q) ->
            "Query contains an empty name: " ++ BS.unpack q

-- | Checks whether the 'Name' is in the list and returns its parameter.
-- If the same name is given several times, the first occurrence wins.
lookupName :: Name -> [NamedParam] -> Maybe PG.Action
lookupName n = lookup n . map (\p -> (namedParamName p, namedParamParam p))

{- | This function takes query with named parameters specified like this:

@
__SELECT__ name, user __FROM__ users __WHERE__ id = ?id
@

and returns either the error or the query with all names replaced by
question marks @?@ with the list of the names in the order of their appearance.

A name is a non-empty run of ASCII letters, ASCII digits and underscores that
directly follows a single @?@. A literal question mark (for example in the
@?|@ operator) has to be escaped as @??@, as in the second example below.

For example:

>>> extractNames "SELECT * FROM users WHERE foo = ?foo AND bar = ?bar AND baz = ?foo"
Right ("SELECT * FROM users WHERE foo = ? AND bar = ? AND baz = ?","foo" :| ["bar","foo"])

>>> extractNames "SELECT foo FROM my_table WHERE (foo->'bar' ??| ?selectedTags);"
Right ("SELECT foo FROM my_table WHERE (foo->'bar' ?| ?);","selectedTags" :| [])

When the operator is not escaped, its @?@ is treated as the start of a named
parameter. Because it is followed by @|@, the name is empty and an error is
returned:

>>> extractNames "SELECT foo FROM my_table WHERE (foo->'bar' ?| ?selectedTags);"
Left PostgreSQL named parameter error: Query contains an empty name: SELECT foo FROM my_table WHERE (foo->'bar' ?| ?selectedTags);
-}
extractNames
    :: PG.Query
    -> Either PgNamedError (PG.Query, NonEmpty Name)
extractNames qr = go (PG.fromQuery qr) >>= \case
    (_, [])         -> Left $ PgNoNames qr
    (q, name:names) -> Right (PG.Query q, name :| names)
  where
    -- Walks the query, replacing every @?name@ with a bare @?@ and collecting
    -- the names in order of appearance. Fails on the first empty name.
    --
    -- NOTE(review): the second doctest shows @??@ collapsed to a single @?@ in
    -- the returned query, which is later handed to postgresql-simple; confirm
    -- that this @?@ is not re-read there as a positional placeholder.
    go :: ByteString -> Either PgNamedError (ByteString, [Name])
    go str
        | BS.null str = Right ("", [])
        | otherwise   =
            let (before, after) = PG.breakOnSingleQuestionMark str
            in case BS.uncons after of
                -- No single question mark is left: the rest is plain SQL.
                Nothing -> Right (before, [])
                Just ('?', nameStart) ->
                    let (name, remainingQuery) = BS.span isNameChar nameStart
                    in if BS.null name
                        then Left $ PgEmptyName qr
                        -- Emit a positional placeholder instead of @?name@ and
                        -- record the name. 'decodeUtf8' is total here because
                        -- 'isNameChar' only admits ASCII bytes.
                        else bimap ((before <> "?") <>) (Name (decodeUtf8 name) :)
                            <$> go remainingQuery
                Just _ -> error
                    "PgNamed.extractNames: breakOnSingleQuestionMark returned a suffix not starting with '?'"

    -- The query is scanned byte by byte ('BS' is "Data.ByteString.Char8").
    -- Accepting non-ASCII bytes that merely look alphanumeric in Latin-1
    -- could cut a multi-byte UTF-8 character in half and make 'decodeUtf8'
    -- throw, so names are restricted to ASCII.
    isNameChar :: Char -> Bool
    isNameChar c = (isAscii c && isAlphaNum c) || c == '_'

{- | Returns the list of values to use in query by given list of 'Name's.
The values are returned in the same order as the names.
Throws 'PgNamedParam' for a name that has no matching 'NamedParam'.
-}
namesToRow
    :: forall m . WithNamedError m
    => NonEmpty Name  -- ^ List of the names used in query
    -> [NamedParam]   -- ^ List of the named parameters
    -> m (NonEmpty PG.Action)
namesToRow names params = traverse magicLookup names
  where
    -- Resolves one name or throws 'PgNamedParam' naming the missing parameter.
    magicLookup :: Name -> m PG.Action
    magicLookup n = case lookupName n params of
        Just x  -> pure x
        Nothing -> throwError $ PgNamedParam n

{- | Operator to create 'NamedParam's. The value is converted with
'PG.toField' immediately.

>>> "foo" =? (1 :: Int)
NamedParam {namedParamName = "foo", namedParamParam = Plain "1"}

So it can be used in creating the list of the named arguments:

@
'queryNamed' dbConnection [sql|
    __SELECT__ *
    __FROM__ users
    __WHERE__ foo = ?foo
      __AND__ bar = ?bar
      __AND__ baz = ?foo
|] [ "foo" '=?' "fooBar"
   , "bar" '=?' "barVar"
   ]
@
-}
infix 1 =?
(=?) :: (PG.ToField a) => Name -> a -> NamedParam
n =? a = NamedParam n $ PG.toField a
{-# INLINE (=?) #-}

{- | Queries the database with a given query and named parameters
and expects a list of rows in return.

Named-parameter problems are reported through 'WithNamedError' before the
query is sent.

@
'queryNamed' dbConnection [sql|
    __SELECT__ id
    __FROM__ table
    __WHERE__ foo = ?foo
|] [ "foo" '=?' "bar" ]
@
-}
queryNamed
    :: (MonadIO m, WithNamedError m, PG.FromRow res)
    => PG.Connection  -- ^ Database connection
    -> PG.Query       -- ^ Query with named parameters inside
    -> [NamedParam]   -- ^ The list of named parameters to be used in the query
    -> m [res]        -- ^ Resulting rows
queryNamed conn qNamed params =
    withNamedArgs qNamed params >>= \(q, actions) ->
        liftIO $ PG.query conn q (toList actions)

{- | Queries the database with a given row parser, 'PG.Query', and named parameters
and expects a list of rows in return.

Sometimes there are multiple ways to parse tuples returned by PostgreSQL into
the same data type. However, it's not possible to implement multiple instances of
the 'PG.FromRow' typeclass (or any other typeclass).

Consider the following data type:

@
__data__ Person = Person
    { personName :: !Text
    , personAge  :: !(Maybe Int)
    }
@

We might want to parse values of the @Person@ data type in two ways:

1. Default by parsing all fields.
2. Parse only the name and set @age@ to 'Nothing'.

If you want to have multiple instances, you need to create @newtype@ for each
case. However, in some cases it might not be convenient to deal with newtypes
around large data types. So you can implement custom 'PG.RowParser' and use it
with 'queryWithNamed'.

@
'queryWithNamed' rowParser dbConnection [sql|
    __SELECT__ id
    __FROM__ table
    __WHERE__ foo = ?foo
|] [ "foo" '=?' "bar" ]
@
-}
queryWithNamed
    :: (MonadIO m, WithNamedError m)
    => PG.RowParser res -- ^ Custom defined row parser
    -> PG.Connection    -- ^ Database connection
    -> PG.Query         -- ^ Query with named parameters inside
    -> [NamedParam]     -- ^ The list of named parameters to be used in the query
    -> m [res]          -- ^ Resulting rows
queryWithNamed rowParser conn qNamed params =
    withNamedArgs qNamed params >>= \(q, actions) ->
        liftIO $ PG.queryWith rowParser conn q (toList actions)

{- | Modifies the database with a given query and named parameters
and expects a number of the rows affected.

@
'executeNamed' dbConnection [sql|
    __UPDATE__ table
    __SET__ foo = \'bar\'
    __WHERE__ id = ?id
|] [ "id" '=?' someId ]
@
-}
executeNamed
    :: (MonadIO m, WithNamedError m)
    => PG.Connection  -- ^ Database connection
    -> PG.Query       -- ^ Query with named parameters inside
    -> [NamedParam]   -- ^ The list of named parameters to be used in the query
    -> m Int64        -- ^ Number of the rows affected by the given query
executeNamed conn qNamed params =
    withNamedArgs qNamed params >>= \(q, actions) ->
        liftIO $ PG.execute conn q (toList actions)

{- | Same as 'executeNamed' but discards the number of rows affected by the
given query. This function is useful when you're not interested in this number.
-}
executeNamed_
    :: (MonadIO m, WithNamedError m)
    => PG.Connection  -- ^ Database connection
    -> PG.Query       -- ^ Query with named parameters inside
    -> [NamedParam]   -- ^ The list of named parameters to be used in the query
    -> m ()
executeNamed_ conn qNamed = void . executeNamed conn qNamed
{-# INLINE executeNamed_ #-}

{- | Helper to use named parameters. Use it to implement named wrappers around
functions from @postgresql-simple@ library. If you think that the function is
useful, consider opening feature request to the @postgresql-simple-named@
library:

* https://github.com/Holmusk/postgresql-simple-named/issues

It rewrites the query with 'extractNames' and resolves the names with
'namesToRow'. Errors from either step are rethrown via 'WithNamedError'.
-}
withNamedArgs
    :: WithNamedError m
    => PG.Query
    -> [NamedParam]
    -> m (PG.Query, NonEmpty PG.Action)
withNamedArgs qNamed namedArgs = do
    -- Lift the pure parsing result into the caller's error monad.
    (q, names) <- either throwError pure (extractNames qNamed)
    args <- namesToRow names namedArgs
    pure (q, args)