{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RankNTypes #-} {-# OPTIONS_GHC -fno-warn-orphans #-} {-# LANGUAGE CPP #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE UndecidableInstances #-} -- FIXME module Database.Persist.Query.GenericSql ( PersistQuery (..) , SqlPersist (..) , filterClauseNoWhere , filterClauseNoWhereOrNull , getFiltsValues , selectSourceConn , dummyFromFilts , orderClause , deleteWhereCount , updateWhereCount ) where import qualified Prelude import Prelude hiding ((++), unlines, concat, show) import Data.Text (Text, pack, concat) import Database.Persist.Store import Database.Persist.Query.Internal import Database.Persist.GenericSql import Database.Persist.GenericSql.Internal import qualified Database.Persist.GenericSql.Raw as R import Control.Monad.IO.Class import Control.Monad.Trans.Class import Control.Monad.Trans.Reader import Data.Conduit import qualified Data.Conduit.List as CL import Data.Int (Int64) import Control.Exception (throwIO) import qualified Data.Text as T import Database.Persist.EntityDef import Data.Monoid (Monoid, mappend, mconcat) import Control.Monad.Logger (MonadLogger) -- orphaned instance for convenience of modularity instance (MonadResource m, MonadLogger m) => PersistQuery (SqlPersist m) where update _ [] = return () update k upds = do conn <- SqlPersist ask let go'' n Assign = n ++ "=?" go'' n Add = concat [n, "=", n, "+?"] go'' n Subtract = concat [n, "=", n, "-?"] go'' n Multiply = concat [n, "=", n, "*?"] go'' n Divide = concat [n, "=", n, "/?"] let go' (x, pu) = go'' (escapeName conn x) pu let sql = concat [ "UPDATE " , escapeName conn $ entityDB t , " SET " , T.intercalate "," $ map (go' . go) upds , " WHERE " , escapeName conn $ entityID t , "=?" ] execute' sql $ map updatePersistValue upds `mappend` [unKey k] where t = entityDef $ dummyFromKey k go x = (fieldDB $ updateFieldDef x, updateUpdate x) count filts = do conn <- SqlPersist ask let wher = if null filts then "" else filterClause False conn filts let sql = concat [ "SELECT COUNT(*) FROM " , escapeName conn $ entityDB t , wher ] R.withStmt sql (getFiltsValues conn filts) $$ do Just [PersistInt64 i] <- CL.head return $ fromIntegral i where t = entityDef $ dummyFromFilts filts selectSource filts opts = do conn <- lift $ SqlPersist ask R.withStmt (sql conn) (getFiltsValues conn filts) $= CL.mapM parse where (limit, offset, orders) = limitOffsetOrder opts parse vals = case fromPersistValues' vals of Left s -> liftIO $ throwIO $ PersistMarshalError s Right row -> return row t = entityDef $ dummyFromFilts filts fromPersistValues' (PersistInt64 x:xs) = do case fromPersistValues xs of Left e -> Left e Right xs' -> Right (Entity (Key $ PersistInt64 x) xs') fromPersistValues' _ = Left "error in fromPersistValues'" wher conn = if null filts then "" else filterClause False conn filts ord conn = case map (orderClause False conn) orders of [] -> "" ords -> " ORDER BY " ++ T.intercalate "," ords lim conn = case (limit, offset) of (0, 0) -> "" (0, _) -> T.cons ' ' $ noLimit conn (_, _) -> " LIMIT " ++ show limit off = if offset == 0 then "" else " OFFSET " ++ show offset cols conn = T.intercalate "," $ (escapeName conn $ entityID t) : map (escapeName conn . fieldDB) (entityFields t) sql conn = concat [ "SELECT " , cols conn , " FROM " , escapeName conn $ entityDB t , wher conn , ord conn , lim conn , off ] selectKeys filts opts = do conn <- lift $ SqlPersist ask R.withStmt (sql conn) (getFiltsValues conn filts) $= CL.mapM parse where parse [PersistInt64 i] = return $ Key $ PersistInt64 i parse y = liftIO $ throwIO $ PersistMarshalError $ "Unexpected in selectKeys: " ++ show y t = entityDef $ dummyFromFilts filts wher conn = if null filts then "" else filterClause False conn filts sql conn = concat [ "SELECT " , escapeName conn $ entityID t , " FROM " , escapeName conn $ entityDB t , wher conn , ord conn , lim conn , off ] (limit, offset, orders) = limitOffsetOrder opts ord conn = case map (orderClause False conn) orders of [] -> "" ords -> " ORDER BY " ++ T.intercalate "," ords lim conn = case (limit, offset) of (0, 0) -> "" (0, _) -> T.cons ' ' $ noLimit conn (_, _) -> " LIMIT " ++ show limit off = if offset == 0 then "" else " OFFSET " ++ show offset deleteWhere filts = do _ <- deleteWhereCount filts return () updateWhere filts upds = do _ <- updateWhereCount filts upds return () -- | Same as 'deleteWhere', but returns the number of rows affected. -- -- Since 1.1.5 deleteWhereCount :: (PersistEntity val, MonadIO m, MonadLogger m) => [Filter val] -> SqlPersist m Int64 deleteWhereCount filts = do conn <- SqlPersist ask let t = entityDef $ dummyFromFilts filts let wher = if null filts then "" else filterClause False conn filts sql = concat [ "DELETE FROM " , escapeName conn $ entityDB t , wher ] R.executeCount sql $ getFiltsValues conn filts -- | Same as 'updateWhere', but returns the number of rows affected. -- -- Since 1.1.5 updateWhereCount :: (PersistEntity val, MonadIO m, MonadLogger m) => [Filter val] -> [Update val] -> SqlPersist m Int64 updateWhereCount _ [] = return 0 updateWhereCount filts upds = do conn <- SqlPersist ask let wher = if null filts then "" else filterClause False conn filts let sql = concat [ "UPDATE " , escapeName conn $ entityDB t , " SET " , T.intercalate "," $ map (go' conn . go) upds , wher ] let dat = map updatePersistValue upds `mappend` getFiltsValues conn filts R.executeCount sql dat where t = entityDef $ dummyFromFilts filts go'' n Assign = n ++ "=?" go'' n Add = concat [n, "=", n, "+?"] go'' n Subtract = concat [n, "=", n, "-?"] go'' n Multiply = concat [n, "=", n, "*?"] go'' n Divide = concat [n, "=", n, "/?"] go' conn (x, pu) = go'' (escapeName conn x) pu go x = (fieldDB $ updateFieldDef x, updateUpdate x) updatePersistValue :: Update v -> PersistValue updatePersistValue (Update _ v _) = toPersistValue v dummyFromKey :: KeyBackend R.SqlBackend v -> v dummyFromKey _ = error "dummyFromKey" execute' :: (MonadLogger m, MonadIO m) => Text -> [PersistValue] -> SqlPersist m () execute' = R.execute getFiltsValues :: forall val. PersistEntity val => Connection -> [Filter val] -> [PersistValue] getFiltsValues conn = snd . filterClauseHelper False False conn OrNullNo filterClause :: PersistEntity val => Bool -- ^ include table name? -> Connection -> [Filter val] -> Text filterClause b c = fst . filterClauseHelper b True c OrNullNo data OrNull = OrNullYes | OrNullNo filterClauseNoWhere :: PersistEntity val => Bool -- ^ include table name? -> Connection -> [Filter val] -> Text filterClauseNoWhere b c = fst . filterClauseHelper b False c OrNullNo filterClauseNoWhereOrNull :: PersistEntity val => Bool -- ^ include table name? -> Connection -> [Filter val] -> Text filterClauseNoWhereOrNull b c = fst . filterClauseHelper b False c OrNullYes filterClauseHelper :: PersistEntity val => Bool -- ^ include table name? -> Bool -- ^ include WHERE? -> Connection -> OrNull -> [Filter val] -> (Text, [PersistValue]) filterClauseHelper includeTable includeWhere conn orNull filters = (if not (T.null sql) && includeWhere then " WHERE " ++ sql else sql, vals) where (sql, vals) = combineAND filters combineAND = combine " AND " combine s fs = (T.intercalate s $ map wrapP a, mconcat b) where (a, b) = unzip $ map go fs wrapP x = T.concat ["(", x, ")"] go (BackendFilter _) = error "BackendFilter not expected" go (FilterAnd []) = ("1=1", []) go (FilterAnd fs) = combineAND fs go (FilterOr []) = ("1=0", []) go (FilterOr fs) = combine " OR " fs go (Filter field value pfilter) = case (isNull, pfilter, varCount) of (True, Eq, _) -> (name ++ " IS NULL", []) (True, Ne, _) -> (name ++ " IS NOT NULL", []) (False, Ne, _) -> (T.concat [ "(" , name , " IS NULL OR " , name , " <> " , qmarks , ")" ], notNullVals) -- We use 1=2 (and below 1=1) to avoid using TRUE and FALSE, since -- not all databases support those words directly. (_, In, 0) -> ("1=2" ++ orNullSuffix, []) (False, In, _) -> (name ++ " IN " ++ qmarks ++ orNullSuffix, allVals) (True, In, _) -> (T.concat [ "(" , name , " IS NULL OR " , name , " IN " , qmarks , ")" ], notNullVals) (_, NotIn, 0) -> ("1=1", []) (False, NotIn, _) -> (T.concat [ "(" , name , " IS NULL OR " , name , " NOT IN " , qmarks , ")" ], notNullVals) (True, NotIn, _) -> (T.concat [ "(" , name , " IS NOT NULL AND " , name , " NOT IN " , qmarks , ")" ], notNullVals) _ -> (name ++ showSqlFilter pfilter ++ "?" ++ orNullSuffix, allVals) where filterValueToPersistValues :: forall a. PersistField a => Either a [a] -> [PersistValue] filterValueToPersistValues v = map toPersistValue $ either return id v orNullSuffix = case orNull of OrNullYes -> concat [" OR ", name, " IS NULL"] OrNullNo -> "" isNull = any (== PersistNull) allVals notNullVals = filter (/= PersistNull) allVals allVals = filterValueToPersistValues value tn = escapeName conn $ entityDB $ entityDef $ dummyFromFilts [Filter field value pfilter] name = (if includeTable then ((tn ++ ".") ++) else id) $ escapeName conn $ fieldDB $ persistFieldDef field qmarks = case value of Left _ -> "?" Right x -> let x' = filter (/= PersistNull) $ map toPersistValue x in "(" ++ T.intercalate "," (map (const "?") x') ++ ")" varCount = case value of Left _ -> 1 Right x -> length x showSqlFilter Eq = "=" showSqlFilter Ne = "++" showSqlFilter Gt = ">" showSqlFilter Lt = "<" showSqlFilter Ge = ">=" showSqlFilter Le = "<=" showSqlFilter In = " IN " showSqlFilter NotIn = " NOT IN " showSqlFilter (BackendSpecificFilter s) = s infixr 5 ++ (++) :: Text -> Text -> Text (++) = mappend show :: Show a => a -> Text show = pack . Prelude.show -- | Equivalent to 'selectSource', but instead of getting the connection from -- the environment inside a 'SqlPersist' monad, provide an explicit -- 'Connection'. This can allow you to use the returned 'Source' in an -- arbitrary monad. selectSourceConn :: (PersistEntity val, MonadResource m, MonadLogger m, PersistEntityBackend val ~ R.SqlBackend, MonadBaseControl IO m) => Connection -> [Filter val] -> [SelectOpt val] -> Source m (Entity val) selectSourceConn conn fs opts = transPipe (flip runSqlConn conn) (selectSource fs opts) dummyFromFilts :: [Filter v] -> v dummyFromFilts _ = error "dummyFromFilts" orderClause :: PersistEntity val => Bool -- ^ include the table name -> Connection -> SelectOpt val -> Text orderClause includeTable conn o = case o of Asc x -> name $ persistFieldDef x Desc x -> name (persistFieldDef x) ++ " DESC" _ -> error $ "orderClause: expected Asc or Desc, not limit or offset" where dummyFromOrder :: SelectOpt a -> a dummyFromOrder _ = undefined tn = escapeName conn $ entityDB $ entityDef $ dummyFromOrder o name x = (if includeTable then ((tn ++ ".") ++) else id) $ escapeName conn $ fieldDB x