{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE PackageImports #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE DeriveDataTypeable #-}
-- | Raw SQL access on top of a 'Connection': the 'SqlPersist' reader
-- transformer, the 'MonadSqlPersist' class used to reach the connection from
-- inside a transformer stack, and helpers that run SQL text through the
-- per-connection cache of prepared statements.
module Database.Persist.GenericSql.Raw
    ( withStmt
    , execute
    , executeCount
    , SqlPersist (..)
    , getStmt'
    , getStmt
    , SqlBackend
    , MonadSqlPersist (..)
    , StatementAlreadyFinalized (..)
    ) where

import qualified Database.Persist.GenericSql.Internal as I
import Database.Persist.GenericSql.Internal hiding (execute, withStmt)
import Database.Persist.Store (PersistValue)
import Data.IORef
import Control.Monad.IO.Class
import Control.Monad.Logger (logDebugS)
import Control.Monad.Trans.Reader
import qualified Data.Map as Map
import Control.Applicative (Applicative)
import Control.Monad.Trans.Class (MonadTrans (..))
import Control.Monad.Base (MonadBase (liftBase))
import Control.Monad.Trans.Control (MonadBaseControl (..), ComposeSt, defaultLiftBaseWith, defaultRestoreM, MonadTransControl (..))
import Control.Monad (liftM, when)
import Control.Exception (throwIO, Exception, finally)
#define MBCIO MonadBaseControl IO
import Data.Text (Text, pack)
import Control.Monad (MonadPlus)
import Control.Monad.Trans.Resource (MonadResource (..))
import Data.Conduit
#if MIN_VERSION_conduit(1, 0, 0)
import Data.Conduit.Internal (Pipe, ConduitM)
#endif
import Control.Monad.Logger (MonadLogger (..))
import Data.Monoid (Monoid)
import Data.Typeable (Typeable)
import Data.Int (Int64)
import Control.Monad.Logger (LoggingT)
import Control.Monad.Trans.Identity ( IdentityT)
import Control.Monad.Trans.List     ( ListT    )
import Control.Monad.Trans.Maybe    ( MaybeT   )
import Control.Monad.Trans.Error    ( ErrorT, Error)
import Control.Monad.Trans.Cont     ( ContT  )
import Control.Monad.Trans.State    ( StateT   )
import Control.Monad.Trans.Writer   ( WriterT  )
import Control.Monad.Trans.RWS      ( RWST     )

import qualified Control.Monad.Trans.RWS.Strict    as Strict ( RWST   )
import qualified Control.Monad.Trans.State.Strict  as Strict ( StateT )
import qualified Control.Monad.Trans.Writer.Strict as Strict ( WriterT )

-- | A type with no constructors. It is exported but not otherwise used in
-- this module (presumably a backend tag consumed elsewhere; confirm against
-- the importing modules).
data SqlBackend

-- | A computation that has read access to a database 'Connection'; a newtype
-- over @ReaderT Connection m@.
newtype SqlPersist m a = SqlPersist { unSqlPersist :: ReaderT Connection m a }
    deriving (Monad, MonadIO, MonadTrans, Functor, Applicative, MonadPlus)

-- | Throwing is delegated to the underlying monad.
instance MonadThrow m => MonadThrow (SqlPersist m) where
    monadThrow = lift . monadThrow

-- | Base-monad actions are lifted through the reader layer.
instance MonadBase backend m => MonadBase backend (SqlPersist m) where
    liftBase = lift . liftBase

-- | Built from the 'MonadTransControl' instance below via the default
-- helpers of @monad-control@.
instance MonadBaseControl backend m => MonadBaseControl backend (SqlPersist m) where
     newtype StM (SqlPersist m) a = StMSP {unStMSP :: ComposeSt SqlPersist m a}
     liftBaseWith = defaultLiftBaseWith StMSP
     restoreM     = defaultRestoreM   unStMSP

-- | The transformer state is just the result value wrapped in 'StReader':
-- running an inner computation only needs the connection captured from the
-- enclosing reader.
instance MonadTransControl SqlPersist where
    newtype StT SqlPersist a = StReader {unStReader :: a}
    liftWith f = SqlPersist $ ReaderT $ \r -> f $ \t -> liftM StReader $ runReaderT (unSqlPersist t) r
    restoreT = SqlPersist . ReaderT . const . liftM unStReader

-- | Resource operations are delegated to the underlying monad. Which class
-- methods exist depends on the installed @resourcet@ version.
instance MonadResource m => MonadResource (SqlPersist m) where
#if MIN_VERSION_resourcet(0,4,0)
    liftResourceT = lift . liftResourceT
#else
    register = lift . register
    release = lift . release
    allocate a = lift . allocate a
    resourceMask = lift . resourceMask
#endif

-- | Monads from which the current database 'Connection' can be obtained.
class (MonadIO m, MonadLogger m) => MonadSqlPersist m where
    -- | Fetch the connection that SQL should be run against.
    askSqlConn :: m Connection

-- | The base case: read the connection out of the reader environment.
instance (MonadIO m, MonadLogger m) => MonadSqlPersist (SqlPersist m) where
    askSqlConn = SqlPersist ask

-- The two macros defined next stamp out one lifting instance per transformer:
-- each generated instance implements askSqlConn as "lift askSqlConn". The
-- two-argument form additionally takes an extra constraint needed by the
-- transformer (for example a Monoid on the writer output). Both macros are
-- undefined again right after the list of instances.
#define GO(T) instance (MonadSqlPersist m) => MonadSqlPersist (T m) where askSqlConn = lift askSqlConn
#define GOX(X, T) instance (X, MonadSqlPersist m) => MonadSqlPersist (T m) where askSqlConn = lift askSqlConn
GO(LoggingT)
GO(IdentityT)
GO(ListT)
GO(MaybeT)
GOX(Error e, ErrorT e)
GO(ReaderT r)
GO(ContT r)
GO(StateT s)
GO(ResourceT)
GO(Pipe l i o u)
#if MIN_VERSION_conduit(1, 0, 0)
GO(ConduitM i o)
#endif
GOX(Monoid w, WriterT w)
GOX(Monoid w, RWST r w s)
GOX(Monoid w, Strict.RWST r w s)
GO(Strict.StateT s)
GOX(Monoid w, Strict.WriterT w)
#undef GO
#undef GOX

-- | Logging is delegated to the underlying monad. The method set differs
-- between @monad-logger@ versions, hence the conditional compilation.
instance MonadLogger m => MonadLogger (SqlPersist m) where
#if MIN_VERSION_monad_logger(0, 3, 0)
    monadLoggerLog a b c = lift . monadLoggerLog a b c
#else
    monadLoggerLog a b c = lift $ monadLoggerLog a b c
    monadLoggerLogSource a b c = lift . monadLoggerLogSource a b c
#endif

-- | Run a statement and stream what it yields as a 'Source'.
--
-- The SQL text and its parameters are first logged at debug level under the
-- log source @SQL@. The statement comes from the connection's statement cache
-- (see 'getStmt''), and 'bracketP' is given 'I.reset' as the cleanup action
-- for it.
withStmt :: (MonadSqlPersist m, MonadResource m)
         => Text            -- ^ SQL text
         -> [PersistValue]  -- ^ parameter values handed to the statement
         -> Source m [PersistValue]
withStmt sql vals = do
    lift $ $logDebugS (pack "SQL") $ pack $ show sql ++ " " ++ show vals
    conn <- lift askSqlConn
    bracketP
        (getStmt' conn sql)
        I.reset
        (flip I.withStmt vals)

-- | Like 'executeCount', but discards the returned count.
execute :: MonadSqlPersist m => Text -> [PersistValue] -> m ()
execute x y = liftM (const ()) $ executeCount x y

-- | Execute a statement and return the 'Int64' produced by the backend's
-- execute operation.
--
-- The SQL text and its parameters are logged at debug level under the log
-- source @SQL@ before execution.
--
-- The cached statement is reset whether or not execution throws. Without
-- that, a failing statement would stay in the connection's cache un-reset,
-- whereas 'withStmt' already guarantees the reset through 'bracketP'.
executeCount :: MonadSqlPersist m
             => Text            -- ^ SQL text
             -> [PersistValue]  -- ^ parameter values handed to the statement
             -> m Int64
executeCount sql vals = do
    $logDebugS (pack "SQL") $ pack $ show sql ++ " " ++ show vals
    stmt <- getStmt sql
    liftIO $ I.execute stmt vals `finally` reset stmt

-- | Look up (or prepare and cache) the statement for the given SQL text on
-- the current connection.
getStmt :: MonadSqlPersist m => Text -> m Statement
getStmt sql = do
    conn <- askSqlConn
    liftIO $ getStmt' conn sql

-- | Fetch the statement for the given SQL text from the connection's
-- statement map, preparing and inserting it on a miss.
--
-- A freshly prepared statement is wrapped before it is cached. The wrapper
-- tracks in an 'IORef' whether the statement is still active:
--
-- * @finalize@ finalizes the underlying statement at most once;
--
-- * @reset@ does nothing once the statement has been finalized;
--
-- * @execute@ and @withStmt@ throw 'StatementAlreadyFinalized' once the
--   statement has been finalized.
--
-- NOTE(review): the statement map is read at the start and written back
-- after preparing, with no synchronisation in between; this assumes a
-- connection is not shared between concurrently running threads.
getStmt' :: Connection -> Text -> IO Statement
getStmt' conn sql = do
    smap <- readIORef $ stmtMap conn
    case Map.lookup sql smap of
        Just stmt -> return stmt
        Nothing -> do
            stmt' <- prepare conn sql
            iactive <- newIORef True
            let stmt = I.Statement
                    { finalize = do
                        active <- readIORef iactive
                        when active $ do
                            finalize stmt'
                            writeIORef iactive False
                    , reset = do
                        active <- readIORef iactive
                        when active $ reset stmt'
                    , I.execute = \x -> do
                        active <- readIORef iactive
                        if active
                            then I.execute stmt' x
                            else throwIO $ StatementAlreadyFinalized sql
                    , I.withStmt = \x -> do
                        active <- liftIO $ readIORef iactive
                        if active
                            then I.withStmt stmt' x
                            else liftIO $ throwIO $ StatementAlreadyFinalized sql
                    }
            writeIORef (stmtMap conn) $ Map.insert sql stmt smap
            return stmt

-- | Thrown when a cached statement is executed after it has been finalized.
-- The payload is the SQL text of that statement.
data StatementAlreadyFinalized = StatementAlreadyFinalized Text
    deriving (Typeable, Show)
instance Exception StatementAlreadyFinalized