module Database.Persist.Sql.Orphan.PersistQuery
( deleteWhereCount
, updateWhereCount
, decorateSQLWithLimitOffset
) where
import Database.Persist
import Database.Persist.Sql.Types
import Database.Persist.Sql.Class
import Database.Persist.Sql.Raw
import Database.Persist.Sql.Orphan.PersistStore ()
import Database.Persist.Sql.Internal (convertKey)
import qualified Data.Text as T
import Data.Text (Text)
import Data.Monoid (Monoid (..), (<>))
import Data.Int (Int64)
import Control.Monad.Logger
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Resource (MonadResource)
import Control.Exception (throwIO)
import qualified Data.Conduit.List as CL
import Data.Conduit
import Data.ByteString.Char8 (readInteger)
import Data.Maybe (isJust)
import Data.List (transpose, inits, find)
-- | SQL implementation of 'PersistQuery' for 'SqlPersistT'.
--
-- Keys of entities with a composite primary key ('entityPrimary' is 'Just')
-- are represented as a @PersistList@ whose elements follow the order of
-- 'primaryFields'. That is the column order 'update' uses in its @WHERE@
-- clause, and the order in which 'filterClauseHelper' pairs key values with
-- columns. So 'selectSource' and 'selectKeys' build keys in that order too.
instance (MonadResource m, MonadLogger m) => PersistQuery (SqlPersistT m) where
    -- Update the single row identified by the key. An empty update list is a
    -- no-op and issues no SQL.
    update _ [] = return ()
    update k upds = do
        conn <- askSqlConn
        let go'' n Assign = n <> "=?"
            go'' n Add = T.concat [n, "=", n, "+?"]
            go'' n Subtract = T.concat [n, "=", n, "-?"]
            go'' n Multiply = T.concat [n, "=", n, "*?"]
            go'' n Divide = T.concat [n, "=", n, "/?"]
        let go' (x, pu) = go'' (connEscapeName conn x) pu
        let composite = isJust $ entityPrimary t
        -- One "col=?" per key column, in 'primaryFields' order.
        let wher = case entityPrimary t of
                Just pdef -> T.intercalate " AND " $ map (\fld -> connEscapeName conn (snd fld) <> "=? ") $ primaryFields pdef
                Nothing -> connEscapeName conn (entityID t) <> "=?"
        let sql = T.concat
                [ "UPDATE "
                , connEscapeName conn $ entityDB t
                , " SET "
                , T.intercalate "," $ map (go' . go) upds
                , " WHERE "
                , wher
                ]
        -- SET values first, then the key values for the WHERE placeholders.
        rawExecute sql $
            map updatePersistValue upds `mappend` (convertKey composite k)
      where
        t = entityDef $ dummyFromKey k
        go x = (fieldDB $ updateFieldDef x, updateUpdate x)

    -- Count the rows matching the filters. The backend may return the count
    -- as an Int64, as a Double (truncated), or as a decimal ByteString.
    -- Anything else is reported with 'error'.
    count filts = do
        conn <- askSqlConn
        let wher = if null filts
                    then ""
                    else filterClause False conn filts
        let sql = mconcat
                [ "SELECT COUNT(*) FROM "
                , connEscapeName conn $ entityDB t
                , wher
                ]
        rawQuery sql (getFiltsValues conn filts) $$ do
            mm <- CL.head
            case mm of
                Just [PersistInt64 i] -> return $ fromIntegral i
                Just [PersistDouble i] -> return $ fromIntegral (truncate i :: Int64)
                Just [PersistByteString i] -> case readInteger i of
                    Just (ret, "") -> return $ fromIntegral ret
                    xs -> error $ "invalid number i["++show i++"] xs[" ++ show xs ++ "]"
                Just xs -> error $ "count:invalid sql return xs["++show xs++"] sql["++show sql++"]"
                Nothing -> error $ "count:invalid sql returned nothing sql["++show sql++"]"
      where
        t = entityDef $ dummyFromFilts filts

    -- Stream the entities matching the filters. Limit, offset and ordering
    -- come from 'limitOffsetOrder'.
    selectSource filts opts = do
        conn <- lift askSqlConn
        rawQuery (sql conn) (getFiltsValues conn filts) $= CL.mapM parse
      where
        composite = isJust $ entityPrimary t
        (limit, offset, orders) = limitOffsetOrder opts
        parse vals =
            case entityPrimary t of
                Just pdef ->
                    -- For composite keys no id column is selected (see 'cols'),
                    -- so 'vals' lines up one-to-one with 'entityFields'. The key
                    -- values are picked out in 'primaryFields' order rather than
                    -- field order, so the key matches what 'update' and the
                    -- composite filters expect.
                    let namedVals = zip (map fieldHaskell $ entityFields t) vals
                        keyvals = [ v | (pk, _) <- primaryFields pdef, (fld, v) <- namedVals, fld == pk ]
                    in case fromPersistValuesComposite' keyvals vals of
                        Left s -> liftIO $ throwIO $ PersistMarshalError s
                        Right row -> return row
                Nothing ->
                    case fromPersistValues' vals of
                        Left s -> liftIO $ throwIO $ PersistMarshalError s
                        Right row -> return row
        t = entityDef $ dummyFromFilts filts
        -- Non-composite rows: the first column is the integer id.
        fromPersistValues' (PersistInt64 x:xs) =
            case fromPersistValues xs of
                Left e -> Left e
                Right xs' -> Right (Entity (Key $ PersistInt64 x) xs')
        -- An id that comes back as a Double is truncated to an Int64 key.
        fromPersistValues' (PersistDouble x:xs) =
            case fromPersistValues xs of
                Left e -> Left e
                Right xs' -> Right (Entity (Key $ PersistInt64 (truncate x)) xs')
        fromPersistValues' xs = Left $ T.pack ("error in fromPersistValues' xs=" ++ show xs)
        fromPersistValuesComposite' keyvals xs =
            case fromPersistValues xs of
                Left e -> Left e
                Right xs' -> Right (Entity (Key $ PersistList keyvals) xs')
        wher conn = if null filts
                    then ""
                    else filterClause False conn filts
        ord conn =
            case map (orderClause False conn) orders of
                [] -> ""
                ords -> " ORDER BY " <> T.intercalate "," ords
        cols conn = T.intercalate ","
                  $ ((if composite then [] else [connEscapeName conn $ entityID t])
                  <> map (connEscapeName conn . fieldDB) (entityFields t))
        sql conn = connLimitOffset conn (limit,offset) (not (null orders)) $ mconcat
            [ "SELECT "
            , cols conn
            , " FROM "
            , connEscapeName conn $ entityDB t
            , wher conn
            , ord conn
            ]

    -- Stream only the keys of the rows matching the filters.
    selectKeys filts opts = do
        conn <- lift askSqlConn
        rawQuery (sql conn) (getFiltsValues conn filts) $= CL.mapM parse
      where
        t = entityDef $ dummyFromFilts filts
        -- Composite keys select exactly the key columns, in 'primaryFields' order.
        cols conn = case entityPrimary t of
            Just pdef -> T.intercalate "," $ map (connEscapeName conn . snd) $ primaryFields pdef
            Nothing -> connEscapeName conn $ entityID t
        wher conn = if null filts
                    then ""
                    else filterClause False conn filts
        sql conn = connLimitOffset conn (limit,offset) (not (null orders)) $ mconcat
            [ "SELECT "
            , cols conn
            , " FROM "
            , connEscapeName conn $ entityDB t
            , wher conn
            , ord conn
            ]
        (limit, offset, orders) = limitOffsetOrder opts
        ord conn =
            case map (orderClause False conn) orders of
                [] -> ""
                ords -> " ORDER BY " <> T.intercalate "," ords
        parse xs = case entityPrimary t of
            Nothing ->
                case xs of
                    [PersistInt64 x] -> return $ Key $ PersistInt64 x
                    [PersistDouble x] -> return $ Key $ PersistInt64 (truncate x)
                    _ -> liftIO $ throwIO $ PersistMarshalError $ "Unexpected in selectKeys False: " <> T.pack (show xs)
            Just pdef
                -- Only the key columns were selected (see 'cols'), already in
                -- 'primaryFields' order, so the row itself is the key. Zipping
                -- it against all entity fields would misassign the values.
                | length xs == length (primaryFields pdef) -> return $ Key $ PersistList xs
                | otherwise -> liftIO $ throwIO $ PersistMarshalError $ "Unexpected in selectKeys True: " <> T.pack (show xs)

    -- Like 'deleteWhereCount' / 'updateWhereCount', with the row count discarded.
    deleteWhere filts = do
        _ <- deleteWhereCount filts
        return ()
    updateWhere filts upds = do
        _ <- updateWhereCount filts upds
        return ()
-- | Delete every row of the entity's table that matches the filters. Returns
-- the count reported by 'rawExecuteCount'.
--
-- An empty filter list emits no @WHERE@ clause, so every row is deleted.
deleteWhereCount :: (PersistEntity val, MonadSqlPersist m)
                 => [Filter val]
                 -> m Int64
deleteWhereCount filts = do
    conn <- askSqlConn
    let entity = entityDef $ dummyFromFilts filts
        whereClause
            | null filts = ""
            | otherwise = filterClause False conn filts
        statement = T.concat ["DELETE FROM ", connEscapeName conn (entityDB entity), whereClause]
    rawExecuteCount statement (getFiltsValues conn filts)
-- | Apply the updates to every row that matches the filters. Returns the
-- count reported by 'rawExecuteCount'.
--
-- An empty update list issues no SQL and returns 0. Bound values are the SET
-- values first, then the filter values, matching placeholder order.
updateWhereCount :: (PersistEntity val, MonadSqlPersist m)
                 => [Filter val]
                 -> [Update val]
                 -> m Int64
updateWhereCount _ [] = return 0
updateWhereCount filts upds = do
    conn <- askSqlConn
    let whereClause
            | null filts = ""
            | otherwise = filterClause False conn filts
        setClause = T.intercalate ","
            [ assignment (connEscapeName conn (fieldDB (updateFieldDef u))) (updateUpdate u)
            | u <- upds
            ]
        statement = T.concat
            [ "UPDATE "
            , connEscapeName conn (entityDB entity)
            , " SET "
            , setClause
            , whereClause
            ]
        params = map updatePersistValue upds ++ getFiltsValues conn filts
    rawExecuteCount statement params
  where
    entity = entityDef $ dummyFromFilts filts
    -- Render "col=?" or "col=col<op>?" for one already-escaped column.
    assignment col op = case op of
        Assign -> col <> "=?"
        Add -> T.concat [col, "=", col, "+?"]
        Subtract -> T.concat [col, "=", col, "-?"]
        Multiply -> T.concat [col, "=", col, "*?"]
        Divide -> T.concat [col, "=", col, "/?"]
-- | The field definition of the column targeted by an 'Update'.
updateFieldDef :: PersistEntity v => Update v -> FieldDef SqlType
updateFieldDef upd = case upd of
    Update field _ _ -> persistFieldDef field
-- | A type-level proxy. It always yields 'Nothing' and exists only to pin the
-- entity type @v@, e.g. for 'entityDef'.
dummyFromFilts :: [Filter v] -> Maybe v
dummyFromFilts = const Nothing
-- | The values to bind to the placeholders of the filters' SQL rendering, in
-- placeholder order. The SQL text itself is discarded.
getFiltsValues :: forall val. PersistEntity val => Connection -> [Filter val] -> [PersistValue]
getFiltsValues conn filts = snd (filterClauseHelper False False conn OrNullNo filts)
-- | Whether 'filterClauseHelper' should also accept NULL. With 'OrNullYes',
-- some comparisons get @OR <column> IS NULL@ appended.
data OrNull
    = OrNullYes
    | OrNullNo
-- | Render a list of filters (implicitly AND-ed) as SQL. Returns the SQL
-- together with the values to bind to its @?@ placeholders, in placeholder
-- order.
--
-- With 'OrNullYes', some comparisons also get @OR <column> IS NULL@.
--
-- A filter on the @id@ field of an entity with a composite primary key
-- expects @PersistList@ values whose elements follow 'primaryFields' order.
-- It is expanded into one comparison per key column.
filterClauseHelper :: PersistEntity val
                   => Bool -- ^ qualify column names with the table name
                   -> Bool -- ^ prefix a non-empty clause with " WHERE "
                   -> Connection
                   -> OrNull
                   -> [Filter val]
                   -> (Text, [PersistValue])
filterClauseHelper includeTable includeWhere conn orNull filters =
    (if not (T.null sql) && includeWhere
        then " WHERE " <> sql
        else sql, vals)
  where
    (sql, vals) = combineAND filters
    combineAND = combine " AND "

    -- Render each sub-filter, parenthesise it, and join the results with the
    -- separator. Bound values are concatenated in the same order.
    combine s fs =
        (T.intercalate s $ map wrapP a, mconcat b)
      where
        (a, b) = unzip $ map go fs
        wrapP x = T.concat ["(", x, ")"]

    go (BackendFilter _) = error "BackendFilter not expected"
    go (FilterAnd []) = ("1=1", [])
    go (FilterAnd fs) = combineAND fs
    go (FilterOr []) = ("1=0", [])
    go (FilterOr fs) = combine " OR " fs
    go (Filter field value pfilter) =
        case (fieldDB (persistFieldDef field) == DBName "id", entityPrimary t, allVals) of
            (True, Just pdef, PersistList ys : _) ->
                if length (primaryFields pdef) /= length ys
                    then error $ "wrong number of entries in primaryFields vs PersistList allVals=" ++ show allVals
                    else compositeClause pdef
            (True, Just pdef, _) -> error $ "unhandled error for composite/non id primary keys pfilter=" ++ show pfilter ++ " persistList=" ++ show allVals ++ " pdef=" ++ show pdef
            _ -> simpleClause
      where
        t = entityDef $ dummyFromFilts [Filter field value pfilter]

        -- Filter on the id of an entity with a composite primary key. Key
        -- values are matched positionally against 'primaryFields'.
        compositeClause pdef =
            case (allVals, pfilter, isCompFilter pfilter) of
                ([PersistList xs], Eq, _) ->
                    (wrapSql (T.intercalate " and " (map (keyCol pfilter) pks)), xs)
                ([PersistList xs], Ne, _) ->
                    (wrapSql (T.intercalate " or " (map (keyCol pfilter) pks)), xs)
                -- NOTE(review): per-column IN also matches combinations of
                -- column values that were never listed together as one key.
                (_, In, _) ->
                    let xxs = transpose (map fromPersistList allVals)
                    in (wrapSql (T.intercalate " and " (map wrapSql (inLists xxs))), concat xxs)
                (_, NotIn, _) ->
                    let xxs = transpose (map fromPersistList allVals)
                    in (wrapSql (T.intercalate " or " (map wrapSql (inLists xxs))), concat xxs)
                -- Lexicographic comparison:
                --   (k1 op' x1) or (k1 = x1 and k2 op' x2) or ... or (k1 = x1 and ... and kn op xn)
                -- Only the full-length disjunct may use a non-strict operator.
                -- Shorter prefixes must use the strict one (op'), otherwise e.g.
                -- "k1 <= x1" alone accepts rows whose later columns exceed the bound.
                ([PersistList xs], _, True) ->
                    let disjunct b = wrapSql (T.intercalate " and " (zipWith (\i a -> keyCol (opFor (length b) i) a) [1 ..] b))
                    in (wrapSql (T.intercalate " or " (map disjunct prefixes)), concat (tail (inits xs)))
                (_, BackendSpecificFilter _, _) -> error "unhandled type BackendSpecificFilter for composite/non id primary keys"
                _ -> error $ "unhandled type/filter for composite/non id primary keys pfilter=" ++ show pfilter ++ " persistList="++show allVals
          where
            pks = primaryFields pdef
            keyCol f a = connEscapeName conn (snd a) <> showSqlFilter f <> "? "
            inLists xxs = map (\(a, vs) -> connEscapeName conn (snd a) <> showSqlFilter pfilter <> "(" <> T.intercalate "," (replicate (length vs) " ?") <> ") ") (zip pks xxs)
            -- Every non-empty prefix of the key columns: [k1], [k1,k2], ...
            prefixes = tail (inits pks)
            -- Operator for the i-th (1-based) column of a prefix of length len.
            opFor len i
                | i /= len = Eq
                | len == length pks = pfilter
                | otherwise = strictFilter pfilter

        -- Strict counterpart of a comparison operator (Le -> Lt, Ge -> Gt).
        strictFilter Le = Lt
        strictFilter Ge = Gt
        strictFilter other = other

        -- Filter on an ordinary column, with special handling for NULL values.
        simpleClause =
            case (isNull, pfilter, varCount) of
                (True, Eq, _) -> (name <> " IS NULL", [])
                (True, Ne, _) -> (name <> " IS NOT NULL", [])
                (False, Ne, _) -> (T.concat ["(", name, " IS NULL OR ", name, " <> ", qmarks, ")"], notNullVals)
                (_, In, 0) -> ("1=2" <> orNullSuffix, [])
                (False, In, _) -> (name <> " IN " <> qmarks <> orNullSuffix, allVals)
                -- The list holds only NULLs: "IN ()" is invalid SQL on most engines.
                (True, In, _) | null notNullVals -> (name <> " IS NULL", [])
                (True, In, _) -> (T.concat ["(", name, " IS NULL OR ", name, " IN ", qmarks, ")"], notNullVals)
                (_, NotIn, 0) -> ("1=1", [])
                (False, NotIn, _) -> (T.concat ["(", name, " IS NULL OR ", name, " NOT IN ", qmarks, ")"], notNullVals)
                -- The list holds only NULLs: "NOT IN ()" is invalid SQL on most engines.
                (True, NotIn, _) | null notNullVals -> (name <> " IS NOT NULL", [])
                (True, NotIn, _) -> (T.concat ["(", name, " IS NOT NULL AND ", name, " NOT IN ", qmarks, ")"], notNullVals)
                _ -> (name <> showSqlFilter pfilter <> "?" <> orNullSuffix, allVals)

        isCompFilter Lt = True
        isCompFilter Le = True
        isCompFilter Gt = True
        isCompFilter Ge = True
        isCompFilter _ = False
        wrapSql sqlcl = "(" <> sqlcl <> ")"
        fromPersistList (PersistList xs) = xs
        fromPersistList other = error $ "expected PersistList but found " ++ show other
        filterValueToPersistValues :: forall a. PersistField a => Either a [a] -> [PersistValue]
        filterValueToPersistValues v = map toPersistValue $ either return id v
        orNullSuffix =
            case orNull of
                OrNullYes -> mconcat [" OR ", name, " IS NULL"]
                OrNullNo -> ""
        isNull = any (== PersistNull) allVals
        notNullVals = filter (/= PersistNull) allVals
        allVals = filterValueToPersistValues value
        tn = connEscapeName conn $ entityDB
           $ entityDef $ dummyFromFilts [Filter field value pfilter]
        name =
            (if includeTable
                then ((tn <> ".") <>)
                else id)
            $ connEscapeName conn $ fieldDB $ persistFieldDef field
        -- One placeholder per non-NULL value; NULLs are handled via IS [NOT] NULL.
        qmarks = case value of
            Left _ -> "?"
            Right x ->
                let x' = filter (/= PersistNull) $ map toPersistValue x
                in "(" <> T.intercalate "," (map (const "?") x') <> ")"
        varCount = case value of
            Left _ -> 1
            Right x -> length x
        showSqlFilter Eq = "="
        showSqlFilter Ne = "<>"
        showSqlFilter Gt = ">"
        showSqlFilter Lt = "<"
        showSqlFilter Ge = ">="
        showSqlFilter Le = "<="
        showSqlFilter In = " IN "
        showSqlFilter NotIn = " NOT IN "
        showSqlFilter (BackendSpecificFilter s) = s
-- | The value an 'Update' writes, converted for binding to a placeholder.
updatePersistValue :: Update v -> PersistValue
updatePersistValue upd = case upd of
    Update _ newValue _ -> toPersistValue newValue
-- | The SQL text of the filters, prefixed with " WHERE " when non-empty. The
-- 'Bool' selects whether column names are qualified with the table name. The
-- bound values are discarded (see 'getFiltsValues').
filterClause :: PersistEntity val
             => Bool
             -> Connection
             -> [Filter val]
             -> Text
filterClause includeTable conn filts =
    fst (filterClauseHelper includeTable True conn OrNullNo filts)
-- | Render an 'Asc' / 'Desc' option as an ORDER BY term. Any other option
-- (limit/offset) is an 'error'. When the 'Bool' is set, the column is
-- qualified with the escaped table name.
orderClause :: PersistEntity val
            => Bool
            -> Connection
            -> SelectOpt val
            -> Text
orderClause includeTable conn o = case o of
    Asc field -> column (persistFieldDef field)
    Desc field -> column (persistFieldDef field) <> " DESC"
    _ -> error "orderClause: expected Asc or Desc, not limit or offset"
  where
    dummyFromOrder :: SelectOpt a -> Maybe a
    dummyFromOrder _ = Nothing
    tableName = connEscapeName conn $ entityDB $ entityDef $ dummyFromOrder o
    qualify col
        | includeTable = tableName <> "." <> col
        | otherwise = col
    column def = qualify (connEscapeName conn (fieldDB def))
-- | A type-level proxy. It always yields 'Nothing' and exists only to recover
-- the entity type @v@ from a key.
dummyFromKey :: KeyBackend SqlBackend v -> Maybe v
dummyFromKey = const Nothing
-- | Append LIMIT / OFFSET to a query.
--
-- A limit of 0 means "no limit". When the limit is 0 and the offset is
-- non-zero, the caller-supplied @nolimit@ text is emitted in place of a LIMIT
-- (presumably for backends that need a LIMIT before OFFSET; confirm per
-- backend). An offset of 0 emits nothing. The 'Bool' argument is ignored.
decorateSQLWithLimitOffset :: Text -> (Int, Int) -> Bool -> Text -> Text
decorateSQLWithLimitOffset nolimit (limit, offset) _ sql =
    sql <> limitPart <> offsetPart
  where
    limitPart
        | limit /= 0 = " LIMIT " <> T.pack (show limit)
        | offset /= 0 = T.cons ' ' nolimit
        | otherwise = ""
    offsetPart
        | offset == 0 = ""
        | otherwise = " OFFSET " <> T.pack (show offset)