module Database.Persist.Sql.Util (
    parseEntityValues
  , entityColumnNames
  , keyAndEntityColumnNames
  , entityColumnCount
  , isIdField
  , hasCompositeKey
  , dbIdColumns
  , dbIdColumnsEsc
  , dbColumns
  , updateFieldDef
  , updatePersistValue
  , mkUpdateText
  , mkUpdateText'
  , commaSeparated
  , parenWrapped
) where

import Data.Maybe (isJust)
import Data.Monoid ((<>))
import Data.Text (Text, pack)
import qualified Data.Text as T

import Database.Persist (
    Entity(Entity), EntityDef, EntityField, HaskellName(HaskellName)
  , PersistEntity, PersistValue
  , keyFromValues, fromPersistValues, fieldDB, entityId, entityPrimary
  , entityFields, entityKeyFields, fieldHaskell, compositeFields, persistFieldDef
  , keyAndEntityFields, toPersistValue, DBName, Update(..), PersistUpdate(..)
  , FieldDef
  )
import Database.Persist.Sql.Types (Sql, SqlBackend, connEscapeName)

-- | Escaped column names of an entity, in select order.
--
-- The id column (taken from 'entityId') comes first, followed by one column
-- per entry of 'entityFields'. When the entity has a composite primary key
-- ('hasCompositeKey') the separate id column is omitted.
-- Every name is escaped with the backend's 'connEscapeName'.
entityColumnNames :: EntityDef -> SqlBackend -> [Sql]
entityColumnNames ent conn =
    idColumn <> map escape (entityFields ent)
  where
    -- Escape the database name of a single field for this backend.
    escape = connEscapeName conn . fieldDB

    idColumn
      | hasCompositeKey ent = []
      | otherwise           = [escape (entityId ent)]

-- | Escaped column names for every field returned by 'keyAndEntityFields',
-- in the order that function yields them.
--
-- Every name is escaped with the backend's 'connEscapeName'.
keyAndEntityColumnNames :: EntityDef -> SqlBackend -> [Sql]
keyAndEntityColumnNames ent conn =
    [ connEscapeName conn (fieldDB fd) | fd <- keyAndEntityFields ent ]

-- | Number of columns an entity occupies: one per entry of 'entityFields',
-- plus one for the id column unless the entity has a composite primary key
-- ('hasCompositeKey').
entityColumnCount :: EntityDef -> Int
entityColumnCount e = length (entityFields e) + idColumnCount
  where
    idColumnCount :: Int
    idColumnCount
      | hasCompositeKey e = 0
      | otherwise         = 1

-- | 'True' when the entity definition carries a composite primary key,
-- i.e. when 'entityPrimary' is a 'Just'.
hasCompositeKey :: EntityDef -> Bool
hasCompositeKey = isJust . entityPrimary

-- | Escaped names of the entity's key columns, using the backend's own
-- 'connEscapeName' as the escaping function.
--
-- See 'dbIdColumnsEsc' for the variant that takes the escaping function
-- directly.
dbIdColumns :: SqlBackend -> EntityDef -> [Text]
dbIdColumns conn = dbIdColumnsEsc (connEscapeName conn)

-- | Names of the entity's key columns ('entityKeyFields'), each passed
-- through the supplied escaping function.
dbIdColumnsEsc :: (DBName -> Text) -> EntityDef -> [Text]
dbIdColumnsEsc esc t = map (esc . fieldDB) $ entityKeyFields t

-- | Escaped column names of an entity.
--
-- With a composite primary key ('entityPrimary' is a 'Just') only the
-- columns of 'entityFields' are returned; otherwise the id column from
-- 'entityId' is prepended to them.
dbColumns :: SqlBackend -> EntityDef -> [Text]
dbColumns conn t = case entityPrimary t of
    Just _  -> flds
    Nothing -> escapeDB (entityId t) : flds
  where
    -- Escape the database name of a single field for this backend.
    escapeDB = connEscapeName conn . fieldDB
    flds = map escapeDB (entityFields t)

-- | Build an 'Entity' from a row of 'PersistValue's, guided by the entity
-- definition.
--
-- * Composite primary key ('entityPrimary' is a 'Just'): all of @vals@ are
--   parsed as the record, and the key is rebuilt from the values whose
--   position in 'entityFields' corresponds to one of the 'compositeFields'.
--
-- * Otherwise: the first value is parsed as the key and the remaining values
--   as the record.
--
-- Returns 'Left' with a message when the record cannot be parsed, when the
-- key cannot be parsed, or (non-composite case) when the row is empty.
-- Key-parse failures are reported through 'Left' rather than 'error', and the
-- underlying 'keyFromValues' message is appended instead of being discarded.
parseEntityValues :: PersistEntity record
                  => EntityDef -> [PersistValue] -> Either Text (Entity record)
parseEntityValues t vals =
    case entityPrimary t of
      Just pdef ->
            let pks = map fieldHaskell $ compositeFields pdef
                -- Pair each value with its field name and keep the ones
                -- that belong to the composite key.
                keyvals = map snd . filter ((`elem` pks) . fst)
                        $ zip (map fieldHaskell $ entityFields t) vals
            in fromPersistValuesComposite' keyvals vals
      Nothing -> fromPersistValues' vals
  where
    -- Non-composite case: head of the row is the key, tail is the record.
    fromPersistValues' (kpv:xs) = -- oracle returns Double
        case fromPersistValues xs of
            Left e -> Left e
            Right xs' ->
                case keyFromValues [kpv] of
                    Left err -> Left $ pack
                        ( "fromPersistValues': keyFromValues failed on "
                       ++ show kpv ++ ": " ++ T.unpack err )
                    Right k -> Right (Entity k xs')

    -- Only reachable with an empty row.
    fromPersistValues' xs = Left $ pack ("error in fromPersistValues' xs=" ++ show xs)

    -- Composite case: the whole row is the record; the key values were
    -- selected from it by the caller.
    fromPersistValuesComposite' keyvals xs =
        case fromPersistValues xs of
            Left e -> Left e
            Right xs' -> case keyFromValues keyvals of
                Left err -> Left $ pack
                    ( "fromPersistValuesComposite': keyFromValues failed"
                   ++ ": " ++ T.unpack err )
                Right key -> Right (Entity key xs')


-- | 'True' when the field's Haskell name (from 'persistFieldDef') is
-- exactly @Id@.
isIdField :: PersistEntity record => EntityField record typ -> Bool
isIdField f = fieldHaskell (persistFieldDef f) == HaskellName "Id"

-- | Gets the 'FieldDef' for an 'Update'.
--
-- Only the 'Update' constructor carries a field. Calling this on a
-- 'BackendUpdate' is treated as a programming error and raises via 'error'.
updateFieldDef :: PersistEntity v => Update v -> FieldDef
updateFieldDef (Update f _ _) = persistFieldDef f
updateFieldDef BackendUpdate {} = error "updateFieldDef: did not expect BackendUpdate"

-- | Gets the value carried by an 'Update', converted with 'toPersistValue'.
--
-- Calling this on a 'BackendUpdate' is treated as a programming error and
-- raises via 'error'.
updatePersistValue :: Update v -> PersistValue
updatePersistValue (Update _ v _) = toPersistValue v
updatePersistValue (BackendUpdate{}) =
    error "updatePersistValue: did not expect BackendUpdate"

-- | Join the pieces with a comma followed by a single space.
--
-- An empty list yields the empty text; a singleton list is returned
-- unchanged.
commaSeparated :: [Text] -> Text
commaSeparated = T.intercalate (pack ", ")

-- | Render the SQL fragment for one 'Update', escaping the column name with
-- the backend's 'connEscapeName' and referring to the column as-is on the
-- right-hand side.
--
-- This is 'mkUpdateText'' specialised to that escaping function and 'id'.
mkUpdateText :: PersistEntity record => SqlBackend -> Update record -> Text
mkUpdateText conn = mkUpdateText' (connEscapeName conn) id

-- | Render the SQL fragment for one 'Update'.
--
-- The first argument escapes the column name; the second transforms the
-- escaped name where the column is referenced on the right-hand side of an
-- arithmetic update.
--
-- 'Assign' renders as @col=?@; 'Add', 'Subtract', 'Multiply' and 'Divide'
-- render as @col=ref+?@, @col=ref-?@, @col=ref*?@ and @col=ref/?@, where
-- @ref@ is the second argument applied to the escaped column name.
--
-- A 'BackendSpecificUpdate' is not supported here and raises via 'error'.
mkUpdateText' :: PersistEntity record => (DBName -> Text) -> (Text -> Text) -> Update record -> Text
mkUpdateText' escapeName refColumn x =
  case updateUpdate x of
    Assign   -> n <> "=?"
    Add      -> arithmetic "+?"
    Subtract -> arithmetic "-?"
    Multiply -> arithmetic "*?"
    Divide   -> arithmetic "/?"
    BackendSpecificUpdate up ->
      error . T.unpack $ "mkUpdateText: BackendSpecificUpdate " <> up <> " not supported"
  where
    -- Escaped name of the column being updated.
    n = escapeName . fieldDB . updateFieldDef $ x

    -- Shared shape of the four arithmetic updates; only the operator
    -- suffix differs between them.
    arithmetic opSuffix = T.concat [n, "=", refColumn n, opSuffix]

-- | Surround the given text with a pair of parentheses.
parenWrapped :: Text -> Text
parenWrapped t = T.concat [T.singleton '(', t, T.singleton ')']