module Database.Persist.Sql.Orphan.PersistQuery
( deleteWhereCount
, updateWhereCount
) where
import Database.Persist
import Database.Persist.Sql.Types
import Database.Persist.Sql.Class
import Database.Persist.Sql.Raw
import Database.Persist.Sql.Orphan.PersistStore ()
import qualified Data.Text as T
import Data.Text (Text)
import Data.Monoid (Monoid (..), (<>))
import Data.Int (Int64)
import Control.Monad.Logger
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Exception (throwIO)
import qualified Data.Conduit.List as CL
import Data.Conduit
-- | SQL implementation of 'PersistQuery' for 'SqlPersistT'.
--
-- Every method renders its statement as 'Text' from the entity definition
-- together with the supplied filters, updates and select options, and hands
-- the accompanying values to 'rawExecute' / 'rawQuery' separately so they
-- are bound to the @?@ placeholders in the statement.
instance (MonadResource m, MonadLogger m) => PersistQuery (SqlPersistT m) where
    -- Update the single row identified by the key.  With an empty update
    -- list no statement is executed at all.
    update _ [] = return ()
    update k upds = do
        conn <- askSqlConn
        let sql = T.concat
                [ "UPDATE "
                , connEscapeName conn $ entityDB t
                , " SET "
                , T.intercalate "," $ map (setClause conn) upds
                , " WHERE "
                , connEscapeName conn $ entityID t
                , "=?"
                ]
        -- Placeholder order: one value per SET clause, then the key.
        rawExecute sql $
            map updatePersistValue upds `mappend` [unKey k]
      where
        t = entityDef $ dummyFromKey k

        -- One "col=..." fragment for a single update.
        setClause conn upd =
            render (connEscapeName conn $ fieldDB $ updateFieldDef upd)
                   (updateUpdate upd)

        render n Assign = n <> "=?"
        render n Add = T.concat [n, "=", n, "+?"]
        render n Subtract = T.concat [n, "=", n, "-?"]
        render n Multiply = T.concat [n, "=", n, "*?"]
        render n Divide = T.concat [n, "=", n, "/?"]

    -- Run SELECT COUNT(*) restricted by the filters.
    count filts = do
        conn <- askSqlConn
        let wher = if null filts
                then ""
                else filterClause False conn filts
        let sql = mconcat
                [ "SELECT COUNT(*) FROM "
                , connEscapeName conn $ entityDB t
                , wher
                ]
        rawQuery sql (getFiltsValues conn filts) $$ do
            mrow <- CL.head
            -- Anything other than a first row holding exactly one
            -- PersistInt64 is reported explicitly, in the same way
            -- selectKeys reports an unexpected row, instead of relying on
            -- a failing pattern match in the bind.
            case mrow of
                Just [PersistInt64 i] -> return $ fromIntegral i
                other -> liftIO $ throwIO $ PersistMarshalError $
                    "Unexpected in count: " <> T.pack (show other)
      where
        t = entityDef $ dummyFromFilts filts

    -- Stream the entities matching the filters.  The id column is selected
    -- first, followed by the entity's fields in definition order.
    selectSource filts opts = do
        conn <- lift askSqlConn
        rawQuery (sql conn) (getFiltsValues conn filts) $= CL.mapM parse
      where
        (limit, offset, orders) = limitOffsetOrder opts

        -- A row that cannot be decoded aborts the stream with
        -- PersistMarshalError.
        parse vals =
            case fromPersistValues' vals of
                Left s -> liftIO $ throwIO $ PersistMarshalError s
                Right row -> return row

        t = entityDef $ dummyFromFilts filts

        -- The first column must be a PersistInt64 key; the remaining
        -- columns are decoded as the entity value.
        fromPersistValues' (PersistInt64 x:xs) =
            case fromPersistValues xs of
                Left e -> Left e
                Right xs' -> Right (Entity (Key $ PersistInt64 x) xs')
        fromPersistValues' _ = Left "error in fromPersistValues'"

        wher conn = if null filts
            then ""
            else filterClause False conn filts

        ord conn =
            case map (orderClause False conn) orders of
                [] -> ""
                ords -> " ORDER BY " <> T.intercalate "," ords

        -- An offset without a limit uses the connection's connNoLimit
        -- text in place of a LIMIT clause.
        lim conn = case (limit, offset) of
            (0, 0) -> ""
            (0, _) -> T.cons ' ' $ connNoLimit conn
            (_, _) -> " LIMIT " <> T.pack (show limit)

        off = if offset == 0
            then ""
            else " OFFSET " <> T.pack (show offset)

        cols conn = T.intercalate ","
            $ (connEscapeName conn $ entityID t)
            : map (connEscapeName conn . fieldDB) (entityFields t)

        sql conn = mconcat
            [ "SELECT "
            , cols conn
            , " FROM "
            , connEscapeName conn $ entityDB t
            , wher conn
            , ord conn
            , lim conn
            , off
            ]

    -- Stream only the keys of the rows matching the filters.
    selectKeys filts opts = do
        conn <- lift askSqlConn
        rawQuery (sql conn) (getFiltsValues conn filts) $= CL.mapM parse
      where
        -- Each row must be exactly one PersistInt64.
        parse [PersistInt64 i] = return $ Key $ PersistInt64 i
        parse y = liftIO $ throwIO $ PersistMarshalError $ "Unexpected in selectKeys: " <> T.pack (show y)

        t = entityDef $ dummyFromFilts filts

        wher conn = if null filts
            then ""
            else filterClause False conn filts

        sql conn = mconcat
            [ "SELECT "
            , connEscapeName conn $ entityID t
            , " FROM "
            , connEscapeName conn $ entityDB t
            , wher conn
            , ord conn
            , lim conn
            , off
            ]

        (limit, offset, orders) = limitOffsetOrder opts

        ord conn =
            case map (orderClause False conn) orders of
                [] -> ""
                ords -> " ORDER BY " <> T.intercalate "," ords

        -- An offset without a limit uses the connection's connNoLimit
        -- text in place of a LIMIT clause.
        lim conn = case (limit, offset) of
            (0, 0) -> ""
            (0, _) -> T.cons ' ' $ connNoLimit conn
            (_, _) -> " LIMIT " <> T.pack (show limit)

        off = if offset == 0
            then ""
            else " OFFSET " <> T.pack (show offset)

    -- deleteWhereCount with its result discarded.
    deleteWhere filts = deleteWhereCount filts >> return ()

    -- updateWhereCount with its result discarded.
    updateWhere filts upds = updateWhereCount filts upds >> return ()
-- | Delete every row matching the filters (all rows when the filter list is
-- empty) and return the 'Int64' reported by 'rawExecuteCount' -- presumably
-- the number of affected rows.
deleteWhereCount :: (PersistEntity val, MonadSqlPersist m)
                 => [Filter val]
                 -> m Int64
deleteWhereCount filts = do
    conn <- askSqlConn
    rawExecuteCount (statement conn) (getFiltsValues conn filts)
  where
    table = entityDef $ dummyFromFilts filts

    -- No WHERE clause at all when there is nothing to filter on.
    whereClause conn
        | null filts = ""
        | otherwise = filterClause False conn filts

    statement conn = mconcat
        [ "DELETE FROM "
        , connEscapeName conn $ entityDB table
        , whereClause conn
        ]
-- | Apply the updates to every row matching the filters and return the
-- 'Int64' reported by 'rawExecuteCount' -- presumably the number of
-- affected rows.  With an empty update list no statement is executed and
-- the result is 0.
updateWhereCount :: (PersistEntity val, MonadSqlPersist m)
                 => [Filter val]
                 -> [Update val]
                 -> m Int64
updateWhereCount _ [] = return 0
updateWhereCount filts upds = do
    conn <- askSqlConn
    let setClauses = T.intercalate "," $ map (renderUpdate conn) upds
    let whereClause
            | null filts = ""
            | otherwise = filterClause False conn filts
    let statement = mconcat
            [ "UPDATE "
            , connEscapeName conn $ entityDB table
            , " SET "
            , setClauses
            , whereClause
            ]
    -- Placeholder order: SET values first, then the filter values.
    let params = map updatePersistValue upds `mappend`
            getFiltsValues conn filts
    rawExecuteCount statement params
  where
    table = entityDef $ dummyFromFilts filts

    -- One "col=..." fragment for a single update.
    renderUpdate conn upd =
        assignment (connEscapeName conn $ fieldDB $ updateFieldDef upd)
                   (updateUpdate upd)

    assignment col Assign = col <> "=?"
    assignment col Add = mconcat [col, "=", col, "+?"]
    assignment col Subtract = mconcat [col, "=", col, "-?"]
    assignment col Multiply = mconcat [col, "=", col, "*?"]
    assignment col Divide = mconcat [col, "=", col, "/?"]
-- | The definition of the field that an 'Update' targets.
updateFieldDef :: PersistEntity v => Update v -> FieldDef SqlType
updateFieldDef upd =
    case upd of
        Update field _ _ -> persistFieldDef field
-- | Always 'Nothing'; the argument is ignored and only serves to fix the
-- entity type @v@ of the result (which is then passed to 'entityDef').
dummyFromFilts :: [Filter v] -> Maybe v
dummyFromFilts = const Nothing
-- | The values to bind for a filter list: the second component of
-- 'filterClauseHelper', rendered without table prefix and with 'OrNullNo'.
getFiltsValues :: forall val. PersistEntity val => Connection -> [Filter val] -> [PersistValue]
getFiltsValues conn filts =
    snd (filterClauseHelper False False conn OrNullNo filts)
-- | Option for 'filterClauseHelper': with 'OrNullYes' an extra
-- @OR column IS NULL@ is appended to some of the rendered filters; with
-- 'OrNullNo' nothing is appended.
data OrNull = OrNullYes | OrNullNo
-- | Render a list of filters as a SQL boolean expression, together with the
-- values to bind to its @?@ placeholders, in order.
--
-- Arguments, in order:
--
-- * @includeTable@ -- qualify every column as @table.column@.
--
-- * @includeWhere@ -- prefix a non-empty expression with @" WHERE "@.
--
-- * @conn@ -- connection whose 'connEscapeName' escapes the names.
--
-- * @orNull@ -- with 'OrNullYes', @OR column IS NULL@ is appended to plain
--   comparisons, to @IN@ over a list without nulls, and to the empty @IN@.
--
-- * @filters@ -- the filters, combined with @AND@; each rendered filter is
--   wrapped in parentheses.
--
-- Calls 'error' when it meets a 'BackendFilter'.
filterClauseHelper :: PersistEntity val
                   => Bool
                   -> Bool
                   -> Connection
                   -> OrNull
                   -> [Filter val]
                   -> (Text, [PersistValue])
filterClauseHelper includeTable includeWhere conn orNull filters =
    (if not (T.null sql) && includeWhere
        then " WHERE " <> sql
        else sql, vals)
  where
    (sql, vals) = combineAND filters
    combineAND = combine " AND "

    -- Join the rendered filters with the separator, parenthesising each
    -- one, and concatenate their bound values in the same order.
    combine s fs =
        (T.intercalate s $ map wrapP a, mconcat b)
      where
        (a, b) = unzip $ map go fs
        wrapP x = T.concat ["(", x, ")"]

    go (BackendFilter _) = error "BackendFilter not expected"
    -- Empty conjunction is true, empty disjunction is false.
    go (FilterAnd []) = ("1=1", [])
    go (FilterAnd fs) = combineAND fs
    go (FilterOr []) = ("1=0", [])
    go (FilterOr fs) = combine " OR " fs
    go (Filter field value pfilter) =
        case (isNull, pfilter, varCount) of
            -- Comparing against NULL with = / <> never matches in SQL, so
            -- these are rendered as IS NULL / IS NOT NULL.
            (True, Eq, _) -> (name <> " IS NULL", [])
            (True, Ne, _) -> (name <> " IS NOT NULL", [])
            -- "<> value" alone would drop rows where the column is NULL.
            (False, Ne, _) -> (T.concat
                [ "("
                , name
                , " IS NULL OR "
                , name
                , " <> "
                , qmarks
                , ")"
                ], notNullVals)
            -- Empty IN list: matches nothing.
            (_, In, 0) -> ("1=2" <> orNullSuffix, [])
            (False, In, _) -> (name <> " IN " <> qmarks <> orNullSuffix, allVals)
            (True, In, _)
                -- Every listed value is NULL: an empty "IN ()" list is not
                -- valid SQL on several backends, so only the NULL test is
                -- emitted.
                | null notNullVals -> (name <> " IS NULL", [])
                | otherwise -> (T.concat
                    [ "("
                    , name
                    , " IS NULL OR "
                    , name
                    , " IN "
                    , qmarks
                    , ")"
                    ], notNullVals)
            -- Empty NOT IN list: matches everything.
            (_, NotIn, 0) -> ("1=1", [])
            (False, NotIn, _) -> (T.concat
                [ "("
                , name
                , " IS NULL OR "
                , name
                , " NOT IN "
                , qmarks
                , ")"
                ], notNullVals)
            (True, NotIn, _)
                -- Every listed value is NULL: same reasoning as for IN.
                | null notNullVals -> (name <> " IS NOT NULL", [])
                | otherwise -> (T.concat
                    [ "("
                    , name
                    , " IS NOT NULL AND "
                    , name
                    , " NOT IN "
                    , qmarks
                    , ")"
                    ], notNullVals)
            _ -> (name <> showSqlFilter pfilter <> "?" <> orNullSuffix, allVals)
      where
        filterValueToPersistValues :: forall a. PersistField a => Either a [a] -> [PersistValue]
        filterValueToPersistValues v = map toPersistValue $ either return id v

        orNullSuffix =
            case orNull of
                OrNullYes -> mconcat [" OR ", name, " IS NULL"]
                OrNullNo -> ""

        isNull = any (== PersistNull) allVals
        notNullVals = filter (/= PersistNull) allVals
        allVals = filterValueToPersistValues value

        tn = connEscapeName conn $ entityDB
           $ entityDef $ dummyFromFilts [Filter field value pfilter]

        name =
            (if includeTable
                then ((tn <> ".") <>)
                else id)
            $ connEscapeName conn $ fieldDB $ persistFieldDef field

        -- One placeholder per non-null value.  For a list value
        -- 'notNullVals' is exactly the converted list with nulls removed.
        qmarks = case value of
            Left _ -> "?"
            Right _ ->
                "(" <> T.intercalate "," (map (const "?") notNullVals) <> ")"

        -- Number of values supplied, nulls included.
        varCount = case value of
            Left _ -> 1
            Right x -> length x

        showSqlFilter Eq = "="
        showSqlFilter Ne = "<>"
        showSqlFilter Gt = ">"
        showSqlFilter Lt = "<"
        showSqlFilter Ge = ">="
        showSqlFilter Le = "<="
        showSqlFilter In = " IN "
        showSqlFilter NotIn = " NOT IN "
        showSqlFilter (BackendSpecificFilter s) = s
-- | The value carried by an 'Update', converted with 'toPersistValue'.
updatePersistValue :: Update v -> PersistValue
updatePersistValue upd =
    case upd of
        Update _ val _ -> toPersistValue val
-- | The SQL text of a filter list, including the leading @" WHERE "@ when
-- the rendered expression is non-empty.  The 'Bool' selects table-qualified
-- column names.  This is the first component of 'filterClauseHelper' called
-- with 'OrNullNo'.
filterClause :: PersistEntity val
             => Bool
             -> Connection
             -> [Filter val]
             -> Text
filterClause includeTable conn filts =
    fst (filterClauseHelper includeTable True conn OrNullNo filts)
-- | Render a single ordering option as an @ORDER BY@ item: the escaped
-- column name, followed by @" DESC"@ for 'Desc'.  The 'Bool' selects
-- table-qualified column names.  Calls 'error' for any option other than
-- 'Asc' or 'Desc'.
orderClause :: PersistEntity val
            => Bool
            -> Connection
            -> SelectOpt val
            -> Text
orderClause includeTable conn opt =
    case opt of
        Asc field -> column (persistFieldDef field)
        Desc field -> column (persistFieldDef field) <> " DESC"
        _ -> error "orderClause: expected Asc or Desc, not limit or offset"
  where
    -- Always Nothing; only fixes the entity type for 'entityDef'.
    entityProxy :: SelectOpt a -> Maybe a
    entityProxy _ = Nothing

    tablePrefix =
        connEscapeName conn (entityDB (entityDef (entityProxy opt))) <> "."

    column def
        | includeTable = tablePrefix <> escaped
        | otherwise = escaped
      where
        escaped = connEscapeName conn (fieldDB def)
-- | Always 'Nothing'; the key is ignored and only serves to fix the entity
-- type @v@ of the result (which is then passed to 'entityDef').
dummyFromKey :: KeyBackend SqlBackend v -> Maybe v
dummyFromKey = const Nothing