{-# LANGUAGE ScopedTypeVariables #-}

module Database.Persist.Sql.Util
    ( parseEntityValues
    , keyAndEntityColumnNames
    , entityColumnCount
    , isIdField
    , hasNaturalKey
    , hasCompositePrimaryKey
    , dbIdColumns
    , dbIdColumnsEsc
    , dbColumns
    , updateFieldDef
    , updatePersistValue
    , mkUpdateText
    , mkUpdateText'
    , commaSeparated
    , parenWrapped
    , mkInsertValues
    , mkInsertPlaceholders
    ) where

import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.Maybe as Maybe
import Data.Text (Text, pack)
import qualified Data.Text as T

import Database.Persist
       ( Entity(Entity)
       , EntityDef
       , EntityField
       , FieldDef(..)
       , FieldNameDB
       , FieldNameHS(FieldNameHS)
       , PersistEntity(..)
       , PersistUpdate(..)
       , PersistValue
       , Update(..)
       , compositeFields
       , entityPrimary
       , fieldDB
       , fieldHaskell
       , fromPersistValues
       , getEntityFields
       , getEntityKeyFields
       , keyAndEntityFields
       , keyFromValues
       , persistFieldDef
       , toPersistValue
       )

import Database.Persist.Sql.Types (Sql)
import Database.Persist.SqlBackend.Internal (SqlBackend(..))

-- | Escaped database column names for every field returned by
-- 'keyAndEntityFields' for the given entity.
--
-- Each field's database name ('fieldDB') is passed through the backend's
-- 'connEscapeFieldName'.
keyAndEntityColumnNames :: EntityDef -> SqlBackend -> NonEmpty Sql
keyAndEntityColumnNames ent conn =
    fmap (connEscapeFieldName conn . fieldDB) (keyAndEntityFields ent)

-- | Number of columns for an entity: the count of 'getEntityFields', plus one
-- when the entity has no natural key ('hasNaturalKey' is 'False').
--
-- NOTE(review): the extra column is presumably the surrogate ID column, which
-- is not among the record's own fields -- confirm.
entityColumnCount :: EntityDef -> Int
entityColumnCount e =
    length (getEntityFields e) + extraKeyColumns
  where
    extraKeyColumns :: Int
    extraKeyColumns
        | hasNaturalKey e = 0
        | otherwise = 1

-- | Returns 'True' if the entity has a natural key defined with the
-- Primary keyword.
--
-- A natural key is a key that is inherent to the record, and is part of
-- the actual Haskell record. The opposite of a natural key is a "surrogate
-- key", which is not part of the normal domain object. Automatically
-- generated ID columns are the most common surrogate ID, while an email
-- address is a common natural key.
--
-- @
-- User
--     email String
--     name String
--     Primary email
--
-- Person
--     Id   UUID
--     name String
--
-- Follower
--     name String
-- @
--
-- Given these entity definitions, @User@ would return 'True', because the
-- @Primary@ keyword sets the @email@ column to be the primary key. The
-- generated Haskell type would look like this:
--
-- @
-- data User = User
--     { userEmail :: String
--     , userName :: String
--     }
-- @
--
-- @Person@ would return 'False'. While the @Id@ syntax allows you to define
-- a custom ID type for an entity, the @Id@ column is a surrogate key.
--
-- The same is true for @Follower@. The automatically generated
-- autoincremented integer primary key is a surrogate key.
--
-- There's nothing preventing you from defining a @Primary@ definition that
-- refers to a surrogate key. This is totally fine.
--
-- @since 2.11.0
hasNaturalKey :: EntityDef -> Bool
hasNaturalKey =
    -- A natural key exists exactly when 'entityPrimary' yields a definition.
    Maybe.isJust . entityPrimary

-- | Returns 'True' if the provided entity has a custom composite primary
-- key. Composite keys have multiple fields in them.
--
-- @
-- User
--     email String
--     name String
--     Primary email
--
-- Profile
--     personId PersonId
--     email    String
--     Primary personId email
--
-- Person
--     Id   UUID
--     name String
--
-- Follower
--     name String
-- @
--
-- Given these entity definitions, only @Profile@ would return 'True',
-- because it is the only entity with multiple columns in the primary key.
-- @User@ has a single column natural key. @Person@ has a custom single
-- column surrogate key defined with @Id@. And @Follower@ has a default
-- single column surrogate key.
--
-- @since 2.11.0
hasCompositePrimaryKey :: EntityDef -> Bool
hasCompositePrimaryKey ed =
    case entityPrimary ed of
        Just cdef ->
            case compositeFields cdef of
                -- Two or more fields in the primary key: composite.
                (_ :| _ : _) ->
                    True
                -- Exactly one field: a single-column primary key.
                _ ->
                    False
        Nothing ->
            False

-- | Escaped names of the entity's key columns, using the backend's
-- 'connEscapeFieldName' as the escaping function.
--
-- See 'dbIdColumnsEsc' for the variant taking an explicit escaping function.
dbIdColumns :: SqlBackend -> EntityDef -> NonEmpty Text
dbIdColumns conn = dbIdColumnsEsc (connEscapeFieldName conn)

-- | Names of the entity's key columns, as returned by 'getEntityKeyFields',
-- with each database field name passed through the supplied escaping
-- function.
dbIdColumnsEsc :: (FieldNameDB -> Text) -> EntityDef -> NonEmpty Text
dbIdColumnsEsc esc t = fmap (esc . fieldDB) $ getEntityKeyFields t

-- | Escaped database column names for every field returned by
-- 'keyAndEntityFields', escaped with the backend's 'connEscapeFieldName'.
dbColumns :: SqlBackend -> EntityDef -> NonEmpty Text
dbColumns conn =
    fmap escapeColumn . keyAndEntityFields
  where
    escapeColumn :: FieldDef -> Text
    escapeColumn = connEscapeFieldName conn . fieldDB

-- | Parse a row of 'PersistValue's into an 'Entity'.
--
-- * When the entity has an 'entityPrimary' definition, the key values are
--   picked out of @vals@: the Haskell names of 'getEntityFields' are zipped
--   with @vals@ and only the values whose field name is among the primary
--   key's 'compositeFields' are kept. The record itself is parsed from the
--   whole of @vals@.
--
-- * Otherwise the first value is taken as the key and the remaining values
--   are parsed as the record. An empty list yields a 'Left'.
--
-- A failure to parse the record is returned as 'Left'. A failure of
-- 'keyFromValues' is /not/ returned as 'Left': it is raised with 'error'.
parseEntityValues :: PersistEntity record
                  => EntityDef -> [PersistValue] -> Either Text (Entity record)
parseEntityValues t vals =
    case entityPrimary t of
      Just pdef ->
            let pks = fmap fieldHaskell $ compositeFields pdef
                keyvals = map snd . filter ((`elem` pks) . fst)
                        $ zip (map fieldHaskell $ getEntityFields t) vals
            in fromPersistValuesComposite' keyvals vals
      Nothing -> fromPersistValues' vals
  where
    -- Leading value is the key, the rest are the record's fields.
    fromPersistValues' (kpv:xs) = -- oracle returns Double
        case fromPersistValues xs of
            Left e -> Left e
            Right xs' ->
                case keyFromValues [kpv] of
                    Left _ -> error $ "fromPersistValues': keyFromValues failed on " ++ show kpv
                    Right k -> Right (Entity k xs')

    -- Reached only for the empty list: there is no key value to parse.
    fromPersistValues' xs = Left $ pack ("error in fromPersistValues' xs=" ++ show xs)

    -- Key values were extracted by the caller; the record uses every value.
    fromPersistValuesComposite' keyvals xs =
        case fromPersistValues xs of
            Left e -> Left e
            Right xs' -> case keyFromValues keyvals of
                Left err -> error $ "fromPersistValuesComposite': keyFromValues failed with error: "
                    <> T.unpack err
                Right key -> Right (Entity key xs')


-- | 'True' when the Haskell name of the given field (taken from its
-- 'persistFieldDef') is exactly @Id@.
isIdField
    :: forall record typ. (PersistEntity record)
    => EntityField record typ
    -> Bool
isIdField f = fieldHaskell (persistFieldDef f) == FieldNameHS "Id"

-- | Gets the 'FieldDef' for an 'Update'.
--
-- Only the 'Update' constructor is supported; a 'BackendUpdate' raises an
-- 'error'.
updateFieldDef :: PersistEntity v => Update v -> FieldDef
updateFieldDef (Update f _ _) = persistFieldDef f
updateFieldDef BackendUpdate {} = error "updateFieldDef: did not expect BackendUpdate"

-- | Gets the value carried by an 'Update', converted with 'toPersistValue'.
--
-- Only the 'Update' constructor is supported; a 'BackendUpdate' raises an
-- 'error'.
updatePersistValue :: Update v -> PersistValue
updatePersistValue (Update _ v _) = toPersistValue v
updatePersistValue (BackendUpdate{}) =
    error "updatePersistValue: did not expect BackendUpdate"

-- | Join the given pieces with @", "@ (comma followed by a space).
--
-- An empty list produces the empty text; a singleton list is returned as is.
commaSeparated :: [Text] -> Text
commaSeparated = T.intercalate (T.pack ", ")

-- | Render the SQL fragment for a single 'Update'.
--
-- This is 'mkUpdateText'' with the backend's 'connEscapeFieldName' as the
-- name-escaping function and 'id' as the column-reference function.
mkUpdateText :: PersistEntity record => SqlBackend -> Update record -> Text
mkUpdateText conn = mkUpdateText' (connEscapeFieldName conn) id

-- TODO: incorporate the table names into a sum type

-- | Render the SQL fragment for a single 'Update'.
--
-- With @n@ being the escaped database name of the updated field
-- (@escapeName@ applied to the 'fieldDB' of 'updateFieldDef'):
--
-- * 'Assign' renders @n=?@
-- * 'Add', 'Subtract', 'Multiply' and 'Divide' render @n=\<ref\>+?@,
--   @n=\<ref\>-?@, @n=\<ref\>*?@ and @n=\<ref\>/?@ respectively, where
--   @\<ref\>@ is @refColumn n@.
--
-- A 'BackendSpecificUpdate' is not supported and raises an 'error'.
mkUpdateText' :: PersistEntity record => (FieldNameDB -> Text) -> (Text -> Text) -> Update record -> Text
mkUpdateText' escapeName refColumn x =
  case updateUpdate x of
    Assign -> n <> "=?"
    Add -> T.concat [n, "=", refColumn n, "+?"]
    Subtract -> T.concat [n, "=", refColumn n, "-?"]
    Multiply -> T.concat [n, "=", refColumn n, "*?"]
    Divide -> T.concat [n, "=", refColumn n, "/?"]
    BackendSpecificUpdate up ->
      error . T.unpack $ "mkUpdateText: BackendSpecificUpdate " <> up <> " not supported"
  where
    -- Escaped database column name of the field being updated.
    n = escapeName . fieldDB . updateFieldDef $ x

-- | Surround the given text with a pair of parentheses.
parenWrapped :: Text -> Text
parenWrapped t = T.cons '(' (T.snoc t ')')

-- | Make a list 'PersistValue' suitable for database inserts. Pairs nicely
-- with the function 'mkInsertPlaceholders'.
--
-- Does not include generated columns.
--
-- @since 2.11.0.0
mkInsertValues
    :: PersistEntity rec
    => rec
    -> [PersistValue]
mkInsertValues entity =
    Maybe.catMaybes
        -- 'Just entity' only serves as the proxy argument of 'entityDef'.
        . zipWith redactGeneratedCol (getEntityFields . entityDef $ Just entity)
        . map toPersistValue
        $ toPersistFields entity
  where
    -- Drop the value of any field that has a 'fieldGenerated' expression.
    redactGeneratedCol :: FieldDef -> a -> Maybe a
    redactGeneratedCol fd pv = case fieldGenerated fd of
        Nothing ->
            Just pv
        Just _ ->
            Nothing

-- | Returns a list of escaped field names and @"?"@ placeholder values for
-- performing inserts. Pairs nicely with the function 'mkInsertValues'.
--
-- Does not include generated columns.
--
-- @since 2.11.0.0
mkInsertPlaceholders
    :: EntityDef
    -> (FieldNameDB -> Text)
    -- ^ An `escape` function
    -> [(Text, Text)]
mkInsertPlaceholders ed escape =
    Maybe.mapMaybe redactGeneratedCol (getEntityFields ed)
  where
    -- Fields with a 'fieldGenerated' expression get no placeholder at all.
    redactGeneratedCol :: FieldDef -> Maybe (Text, Text)
    redactGeneratedCol fd = case fieldGenerated fd of
        Nothing ->
            Just (escape (fieldDB fd), "?")
        Just _ ->
            Nothing