{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Database.Persist.Sql.Orphan.PersistQuery
    ( deleteWhereCount
    , updateWhereCount
    , decorateSQLWithLimitOffset
    ) where

import Database.Persist
import Database.Persist.Sql.Types
import Database.Persist.Sql.Class
import Database.Persist.Sql.Raw
import Database.Persist.Sql.Orphan.PersistStore ()
import Database.Persist.Sql.Internal (convertKey)
import qualified Data.Text as T
import Data.Text (Text)
import Data.Monoid (Monoid (..), (<>))
import Data.Int (Int64)
import Control.Monad.Logger
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Resource (MonadResource)
import Control.Exception (throwIO)
import qualified Data.Conduit.List as CL
import Data.Conduit
import Data.ByteString.Char8 (readInteger)
import Data.Maybe (isJust)
import Data.List (transpose, inits, find)

-- orphaned instance for convenience of modularity
instance (MonadResource m, MonadLogger m) => PersistQuery (SqlPersistT m) where
    -- Apply a list of updates to the single row identified by the key.
    -- An empty update list issues no SQL at all.
    update _ [] = return ()
    update k upds = do
        conn <- askSqlConn
        let go'' n Assign = n <> "=?"
            go'' n Add = T.concat [n, "=", n, "+?"]
            go'' n Subtract = T.concat [n, "=", n, "-?"]
            go'' n Multiply = T.concat [n, "=", n, "*?"]
            go'' n Divide = T.concat [n, "=", n, "/?"]
        let go' (x, pu) = go'' (connEscapeName conn x) pu
        let composite = isJust $ entityPrimary t
        -- With a declared primary key, match on every key column;
        -- otherwise match on the entity's id column.
        let wher = case entityPrimary t of
                Just pdef ->
                    T.intercalate " AND "
                        $ map (\fld -> connEscapeName conn (snd fld) <> "=? ")
                        $ primaryFields pdef
                Nothing -> connEscapeName conn (entityID t) <> "=?"
        let sql = T.concat
                [ "UPDATE "
                , connEscapeName conn $ entityDB t
                , " SET "
                , T.intercalate "," $ map (go' . go) upds
                , " WHERE "
                , wher
                ]
        -- Parameters: one per SET clause, followed by the key value(s).
        rawExecute sql $
            map updatePersistValue upds `mappend` (convertKey composite k)
      where
        t = entityDef $ dummyFromKey k
        go x = (fieldDB $ updateFieldDef x, updateUpdate x)

    -- Count the rows matching the filters with SELECT COUNT(*).
    -- The single result cell is accepted as an Int64, a Double or a
    -- ByteString holding a decimal integer; any other shape (or no row)
    -- is reported through 'error'.
    count filts = do
        conn <- askSqlConn
        let wher = if null filts
                    then ""
                    else filterClause False conn filts
        let sql = mconcat
                [ "SELECT COUNT(*) FROM "
                , connEscapeName conn $ entityDB t
                , wher
                ]
        rawQuery sql (getFiltsValues conn filts) $$ do
            mm <- CL.head
            case mm of
                Just [PersistInt64 i] -> return $ fromIntegral i
                Just [PersistDouble i] ->
                    return $ fromIntegral (truncate i :: Int64) -- gb oracle
                Just [PersistByteString i] ->
                    case readInteger i of -- gb mssql
                        Just (ret,"") -> return $ fromIntegral ret
                        xs -> error $ "invalid number i["++show i++"] xs[" ++ show xs ++ "]"
                Just xs -> error $ "count:invalid sql return xs["++show xs++"] sql["++show sql++"]"
                Nothing -> error $ "count:invalid sql returned nothing sql["++show sql++"]"
      where
        t = entityDef $ dummyFromFilts filts

    -- Stream the entities matching the filters, honouring the
    -- limit/offset/order options. Rows that cannot be converted raise
    -- 'PersistMarshalError'.
    selectSource filts opts = do
        conn <- lift askSqlConn
        rawQuery (sql conn) (getFiltsValues conn filts) $= CL.mapM parse
      where
        composite = isJust $ entityPrimary t
        (limit, offset, orders) = limitOffsetOrder opts

        parse vals =
            case entityPrimary t of
                Just pdef ->
                    -- The row holds exactly the entity fields (no id
                    -- column), so the key is read out of the row itself:
                    -- the values of the fields that belong to the primary
                    -- key, in entity-field order.
                    let pks = map fst $ primaryFields pdef
                        keyvals = map snd
                            $ filter (\(a, _) -> a `elem` pks)
                            $ zip (map fieldHaskell $ entityFields t) vals
                    in case fromPersistValuesComposite' keyvals vals of
                        Left s -> liftIO $ throwIO $ PersistMarshalError s
                        Right row -> return row
                Nothing ->
                    case fromPersistValues' vals of
                        Left s -> liftIO $ throwIO $ PersistMarshalError s
                        Right row -> return row

        t = entityDef $ dummyFromFilts filts

        -- Non-composite rows: the first column is the id, the rest are
        -- the entity fields.
        fromPersistValues' (PersistInt64 x:xs) =
            case fromPersistValues xs of
                Left e -> Left e
                Right xs' -> Right (Entity (Key $ PersistInt64 x) xs')
        fromPersistValues' (PersistDouble x:xs) = -- oracle returns Double
            case fromPersistValues xs of
                Left e -> Left e
                Right xs' -> Right (Entity (Key $ PersistInt64 (truncate x)) xs') -- convert back to int64
        fromPersistValues' xs = Left $ T.pack ("error in fromPersistValues' xs=" ++ show xs)

        fromPersistValuesComposite' keyvals xs =
            case fromPersistValues xs of
                Left e -> Left e
                Right xs' -> Right (Entity (Key $ PersistList keyvals) xs')

        wher conn = if null filts
                    then ""
                    else filterClause False conn filts
        ord conn =
            case map (orderClause False conn) orders of
                [] -> ""
                ords -> " ORDER BY " <> T.intercalate "," ords
        -- The id column is only selected when no primary key is declared.
        cols conn = T.intercalate ","
                  $ ((if composite then [] else [connEscapeName conn $ entityID t])
                  <> map (connEscapeName conn . fieldDB) (entityFields t))
        sql conn = connLimitOffset conn (limit,offset) (not (null orders)) $ mconcat
            [ "SELECT "
            , cols conn
            , " FROM "
            , connEscapeName conn $ entityDB t
            , wher conn
            , ord conn
            ]

    -- Stream only the keys of the entities matching the filters.
    selectKeys filts opts = do
        conn <- lift askSqlConn
        rawQuery (sql conn) (getFiltsValues conn filts) $= CL.mapM parse
      where
        t = entityDef $ dummyFromFilts filts
        -- Only the key columns are fetched: the declared primary-key
        -- columns (in 'primaryFields' order) or the id column.
        cols conn = case entityPrimary t of
            Just pdef -> T.intercalate "," $ map (connEscapeName conn . snd) $ primaryFields pdef
            Nothing -> connEscapeName conn $ entityID t

        wher conn = if null filts
                    then ""
                    else filterClause False conn filts
        sql conn = connLimitOffset conn (limit,offset) (not (null orders)) $ mconcat
            [ "SELECT "
            , cols conn
            , " FROM "
            , connEscapeName conn $ entityDB t
            , wher conn
            , ord conn
            ]

        (limit, offset, orders) = limitOffsetOrder opts

        ord conn =
            case map (orderClause False conn) orders of
                [] -> ""
                ords -> " ORDER BY " <> T.intercalate "," ords

        parse xs =
            case entityPrimary t of
                Nothing ->
                    case xs of
                        [PersistInt64 x] -> return $ Key $ PersistInt64 x
                        [PersistDouble x] -> return $ Key $ PersistInt64 (truncate x) -- oracle returns Double
                        _ -> liftIO $ throwIO $ PersistMarshalError $ "Unexpected in selectKeys False: " <> T.pack (show xs)
                Just pdef ->
                    -- The row contains ONLY the primary-key columns, in
                    -- 'primaryFields' order, so each value is paired with
                    -- the key field it was selected for. (Zipping against
                    -- all entity fields would misalign the values whenever
                    -- the key columns are not the leading entity fields.)
                    -- The key is then emitted in entity-field order, which
                    -- is the order selectSource uses for the same row.
                    let pks = map fst $ primaryFields pdef
                        fetched = zip pks xs
                        keyvals =
                            [ v
                            | fld <- map fieldHaskell $ entityFields t
                            , Just v <- [lookup fld fetched]
                            ]
                    in if length xs == length pks
                        then return $ Key $ PersistList keyvals
                        else liftIO $ throwIO $ PersistMarshalError $
                            "Unexpected in selectKeys composite: " <> T.pack (show xs)

    -- Delete every row matching the filters, discarding the row count.
    deleteWhere filts = do
        _ <- deleteWhereCount filts
        return ()

    -- Update every row matching the filters, discarding the row count.
    updateWhere filts upds = do
        _ <- updateWhereCount filts upds
        return ()

-- | Same as 'deleteWhere', but returns the number of rows affected.
--
-- Since 1.1.5
deleteWhereCount :: (PersistEntity val, MonadSqlPersist m)
                 => [Filter val]
                 -> m Int64
deleteWhereCount filts = do
    conn <- askSqlConn
    let t = entityDef $ dummyFromFilts filts
    -- No filters means no WHERE clause: every row of the table is deleted.
    let wher = if null filts
                then ""
                else filterClause False conn filts
        sql = mconcat
            [ "DELETE FROM "
            , connEscapeName conn $ entityDB t
            , wher
            ]
    rawExecuteCount sql $ getFiltsValues conn filts

-- | Same as 'updateWhere', but returns the number of rows affected.
--
-- An empty update list is a no-op that reports 0 affected rows without
-- issuing any SQL.
--
-- Since 1.1.5
updateWhereCount :: (PersistEntity val, MonadSqlPersist m)
                 => [Filter val]
                 -> [Update val]
                 -> m Int64
updateWhereCount _ [] = return 0
updateWhereCount filts upds = do
    conn <- askSqlConn
    let wher = if null filts
                then ""
                else filterClause False conn filts
    let sql = mconcat
            [ "UPDATE "
            , connEscapeName conn $ entityDB t
            , " SET "
            , T.intercalate "," $ map (go' conn . go) upds
            , wher
            ]
    -- Parameters: one per SET clause, followed by the filter values.
    let dat = map updatePersistValue upds `mappend`
              getFiltsValues conn filts
    rawExecuteCount sql dat
  where
    t = entityDef $ dummyFromFilts filts
    go'' n Assign = n <> "=?"
    go'' n Add = mconcat [n, "=", n, "+?"]
    go'' n Subtract = mconcat [n, "=", n, "-?"]
    go'' n Multiply = mconcat [n, "=", n, "*?"]
    go'' n Divide = mconcat [n, "=", n, "/?"]
    go' conn (x, pu) = go'' (connEscapeName conn x) pu
    go x = (fieldDB $ updateFieldDef x, updateUpdate x)

-- | Field definition of the column an 'Update' targets.
updateFieldDef :: PersistEntity v => Update v -> FieldDef SqlType
updateFieldDef (Update f _ _) = persistFieldDef f

-- | Type-level helper: always 'Nothing'; it only carries the entity type
-- of the filters so that 'entityDef' can be applied.
dummyFromFilts :: [Filter v] -> Maybe v
dummyFromFilts _ = Nothing

-- | The positional parameter values that go with the SQL text produced by
-- 'filterClause' for the same filters.
getFiltsValues :: forall val.  PersistEntity val => Connection -> [Filter val] -> [PersistValue]
getFiltsValues conn = snd . filterClauseHelper False False conn OrNullNo

-- | Whether a rendered condition gets an extra @OR \<column\> IS NULL@
-- suffix (see @orNullSuffix@ in 'filterClauseHelper').
data OrNull = OrNullYes | OrNullNo

-- | Render a list of filters (implicitly AND-ed) as SQL text plus the
-- matching positional parameter values.
--
-- The text is prefixed with @" WHERE "@ only when requested and non-empty.
-- A 'BackendFilter', or a filter shape on a declared primary key that is
-- not handled below, is reported through 'error'.
filterClauseHelper :: PersistEntity val
             => Bool -- ^ include table name?
             -> Bool -- ^ include WHERE?
             -> Connection
             -> OrNull
             -> [Filter val]
             -> (Text, [PersistValue])
filterClauseHelper includeTable includeWhere conn orNull filters =
    (if not (T.null sql) && includeWhere
        then " WHERE " <> sql
        else sql, vals)
  where
    (sql, vals) = combineAND filters
    combineAND = combine " AND "

    -- Every sub-clause is parenthesised before being joined.
    combine s fs =
        (T.intercalate s $ map wrapP a, mconcat b)
      where
        (a, b) = unzip $ map go fs
        wrapP x = T.concat ["(", x, ")"]

    go (BackendFilter _) = error "BackendFilter not expected"
    go (FilterAnd []) = ("1=1", [])
    go (FilterAnd fs) = combineAND fs
    go (FilterOr []) = ("1=0", [])
    go (FilterOr fs) = combine " OR " fs
    go (Filter field value pfilter) =
        let t = entityDef $ dummyFromFilts [Filter field value pfilter]
        in case (fieldDB (persistFieldDef field) == DBName "id", entityPrimary t, allVals) of -- need to check the id field in a safer way: entityId?
            -- Filter on the key of an entity with a declared primary key:
            -- each value is a PersistList with one entry per key column.
            (True, Just pdef, (PersistList ys:_)) ->
                if length (primaryFields pdef) /= length ys
                    then error $ "wrong number of entries in primaryFields vs PersistList allVals=" ++ show allVals
                    else
                        let pkCols = map (connEscapeName conn . snd) (primaryFields pdef)
                            -- "col<op>? " for every key column, joined.
                            eachColumn joiner =
                                wrapSql $ T.intercalate joiner
                                    [ col <> showSqlFilter pfilter <> "? "
                                    | col <- pkCols
                                    ]
                            -- One "col IN ( ?, ?)" list per key column,
                            -- fed with that column's slice of the values.
                            columnLists joiner =
                                let xxs = transpose (map fromPersistList allVals)
                                    sqls =
                                        [ col <> showSqlFilter pfilter <> "("
                                            <> T.intercalate "," (replicate (length xs) " ?")
                                            <> ") "
                                        | (col, xs) <- zip pkCols xxs
                                        ]
                                in (wrapSql (T.intercalate joiner (map wrapSql sqls)), concat xxs)
                        in case (allVals, pfilter, isCompFilter pfilter) of
                            ([PersistList xs], Eq, _) -> (eachColumn " and ", xs)
                            ([PersistList xs], Ne, _) -> (eachColumn " or ", xs)
                            (_, In, _) -> columnLists " and "
                            (_, NotIn, _) -> columnLists " or "
                            ([PersistList xs], _, True) ->
                                -- Ordering comparison: for every non-empty
                                -- prefix of the key columns, equality on all
                                -- but the last column and the requested
                                -- operator on the last one.
                                let prefixes = drop 1 (inits pkCols)
                                    clause cols =
                                        wrapSql $ T.intercalate " and "
                                            [ cmp (i == length cols) col
                                            | (i, col) <- zip [1 ..] cols
                                            ]
                                    cmp isLast col =
                                        col
                                            <> (if isLast then showSqlFilter pfilter else showSqlFilter Eq)
                                            <> "? "
                                in ( wrapSql (T.intercalate " or " (map clause prefixes))
                                   , concat (drop 1 (inits xs))
                                   )
                            (_, BackendSpecificFilter _, _) -> error "unhandled type BackendSpecificFilter for composite/non id primary keys"
                            _ -> error $ "unhandled type/filter for composite/non id primary keys pfilter=" ++ show pfilter ++ " persistList="++show allVals
            (True, Just pdef, _) -> error $ "unhandled error for composite/non id primary keys pfilter=" ++ show pfilter ++ " persistList=" ++ show allVals ++ " pdef=" ++ show pdef

            -- Ordinary single-column filter.
            _ -> case (isNull, pfilter, varCount) of
                (True, Eq, _) -> (name <> " IS NULL", [])
                (True, Ne, _) -> (name <> " IS NOT NULL", [])
                (False, Ne, _) -> (T.concat
                    [ "("
                    , name
                    , " IS NULL OR "
                    , name
                    , " <> "
                    , qmarks
                    , ")"
                    ], notNullVals)
                -- We use 1=2 (and below 1=1) to avoid using TRUE and FALSE, since
                -- not all databases support those words directly.
                (_, In, 0) -> ("1=2" <> orNullSuffix, [])
                (False, In, _) -> (name <> " IN " <> qmarks <> orNullSuffix, allVals)
                (True, In, _)
                    -- Every listed value is NULL: qmarks would be "()",
                    -- and "IN ()" is rejected by several databases.
                    | null notNullVals -> (name <> " IS NULL", [])
                    | otherwise -> (T.concat
                        [ "("
                        , name
                        , " IS NULL OR "
                        , name
                        , " IN "
                        , qmarks
                        , ")"
                        ], notNullVals)
                (_, NotIn, 0) -> ("1=1", [])
                (False, NotIn, _) -> (T.concat
                    [ "("
                    , name
                    , " IS NULL OR "
                    , name
                    , " NOT IN "
                    , qmarks
                    , ")"
                    ], notNullVals)
                (True, NotIn, _)
                    -- Same empty-list guard as for In above.
                    | null notNullVals -> (name <> " IS NOT NULL", [])
                    | otherwise -> (T.concat
                        [ "("
                        , name
                        , " IS NOT NULL AND "
                        , name
                        , " NOT IN "
                        , qmarks
                        , ")"
                        ], notNullVals)
                _ -> (name <> showSqlFilter pfilter <> "?" <> orNullSuffix, allVals)
      where
        isCompFilter Lt = True
        isCompFilter Le = True
        isCompFilter Gt = True
        isCompFilter Ge = True
        isCompFilter _ = False

        wrapSql sqlcl = "(" <> sqlcl <> ")"
        fromPersistList (PersistList xs) = xs
        fromPersistList other = error $ "expected PersistList but found " ++ show other

        filterValueToPersistValues :: forall a.  PersistField a => Either a [a] -> [PersistValue]
        filterValueToPersistValues v = map toPersistValue $ either return id v

        orNullSuffix =
            case orNull of
                OrNullYes -> mconcat [" OR ", name, " IS NULL"]
                OrNullNo -> ""

        isNull = PersistNull `elem` allVals
        notNullVals = filter (/= PersistNull) allVals
        allVals = filterValueToPersistValues value
        tn = connEscapeName conn $ entityDB
           $ entityDef $ dummyFromFilts [Filter field value pfilter]
        name =
            (if includeTable
                then ((tn <> ".") <>)
                else id)
            $ connEscapeName conn $ fieldDB $ persistFieldDef field
        -- One placeholder per non-NULL value; a single value is a bare "?".
        qmarks = case value of
                    Left _ -> "?"
                    Right x ->
                        let x' = filter (/= PersistNull) $ map toPersistValue x
                         in "(" <> T.intercalate "," (map (const "?") x') <> ")"
        -- Number of values supplied, NULLs included.
        varCount = case value of
                    Left _ -> 1
                    Right x -> length x
        showSqlFilter Eq = "="
        showSqlFilter Ne = "<>"
        showSqlFilter Gt = ">"
        showSqlFilter Lt = "<"
        showSqlFilter Ge = ">="
        showSqlFilter Le = "<="
        showSqlFilter In = " IN "
        showSqlFilter NotIn = " NOT IN "
        showSqlFilter (BackendSpecificFilter s) = s

-- | The value an 'Update' supplies, as a positional SQL parameter.
updatePersistValue :: Update v -> PersistValue
updatePersistValue (Update _ v _) = toPersistValue v

-- | SQL text for the filters, including the leading @" WHERE "@ when the
-- rendered condition is non-empty.
filterClause :: PersistEntity val
             => Bool -- ^ include table name?
             -> Connection
             -> [Filter val]
             -> Text
filterClause b c = fst . filterClauseHelper b True c OrNullNo

-- | SQL text for a single ordering option: the escaped column name,
-- followed by @" DESC"@ for 'Desc'. Any option other than 'Asc' or 'Desc'
-- is reported through 'error'.
orderClause :: PersistEntity val
            => Bool -- ^ include the table name
            -> Connection
            -> SelectOpt val
            -> Text
orderClause includeTable conn o =
    case o of
        Asc  x -> name $ persistFieldDef x
        Desc x -> name (persistFieldDef x) <> " DESC"
        _ -> error "orderClause: expected Asc or Desc, not limit or offset"
  where
    dummyFromOrder :: SelectOpt a -> Maybe a
    dummyFromOrder _ = Nothing

    tn = connEscapeName conn $ entityDB $ entityDef $ dummyFromOrder o

    name x =
        (if includeTable
            then ((tn <> ".") <>)
            else id)
        $ connEscapeName conn $ fieldDB x

-- | Type-level helper: always 'Nothing'; it only carries the entity type
-- of the key so that 'entityDef' can be applied.
dummyFromKey :: KeyBackend SqlBackend v -> Maybe v
dummyFromKey _ = Nothing

-- | Generates sql for limit and offset for postgres, sqlite and mysql.
--
-- A limit or offset of 0 means "not set". When only an offset is given,
-- the first argument (preceded by a space) is emitted in place of a LIMIT
-- clause. The 'Bool' argument is ignored.
decorateSQLWithLimitOffset :: Text -> (Int,Int) -> Bool -> Text -> Text
decorateSQLWithLimitOffset nolimit (limit,offset) _ sql =
    let
        lim = case (limit, offset) of
                (0, 0) -> ""
                (0, _) -> T.cons ' ' nolimit
                (_, _) -> " LIMIT " <> T.pack (show limit)
        off = if offset == 0
                    then ""
                    else " OFFSET " <> T.pack (show offset)
    in mconcat
            [ sql
            , lim
            , off
            ]