{-# LANGUAGE PackageImports #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
-- | Raw SQL helpers that run inside the 'SqlPersist' monad.
--
-- Statements are looked up by their SQL text in the connection's
-- statement map ('stmtMap') and prepared on first use, so repeated
-- calls with the same SQL string reuse one prepared 'Statement'.
module Database.Persist.GenericSql.Raw
    ( withStmt
    , execute
    , SqlPersist (..)
    , getStmt'
    , getStmt
    ) where

import qualified Database.Persist.GenericSql.Internal as I
import Database.Persist.GenericSql.Internal hiding (execute, withStmt)
import Database.Persist.Base (PersistValue)
import Data.IORef
import Control.Monad.IO.Class
import Control.Monad.Trans.Reader
import qualified Data.Map as Map
import Control.Applicative (Applicative)
import Control.Monad.Trans.Class (MonadTrans (..))
import Control.Monad.Invert (MonadInvertIO (..))
import Control.Monad (liftM)
import qualified Control.Exception as E

-- | A computation with read-only access to a database 'Connection',
-- implemented as a 'ReaderT' over an arbitrary base monad @m@.
newtype SqlPersist m a = SqlPersist { unSqlPersist :: ReaderT Connection m a }
    deriving (Monad, MonadIO, MonadTrans, Functor, Applicative)

-- | Delegates to the 'MonadInvertIO' instance of the underlying
-- @ReaderT Connection m@, wrapping and unwrapping the newtypes.
instance MonadInvertIO m => MonadInvertIO (SqlPersist m) where
    newtype InvertedIO (SqlPersist m) a =
        InvSqlPersistIO
            { runInvSqlPersistIO :: InvertedIO (ReaderT Connection m) a
            }
    type InvertedArg (SqlPersist m) = (Connection, InvertedArg m)
    invertIO = liftM (fmap InvSqlPersistIO) . invertIO . unSqlPersist
    revertIO f = SqlPersist $ revertIO $ liftM runInvSqlPersistIO . f

-- | Run the statement for the given SQL text with the given parameter
-- values, handing a 'RowPopper' to the supplied callback. The statement
-- is reset once the callback has returned, and the callback's result is
-- returned.
--
-- NOTE(review): the reset only happens on the success path here; if the
-- callback throws, the cached statement is left un-reset. Making this
-- exception-safe needs an exception combinator that works in
-- 'MonadInvertIO' monads - TODO confirm which one the project provides.
withStmt :: MonadInvertIO m
         => String                                       -- ^ SQL text
         -> [PersistValue]                               -- ^ parameter values
         -> (RowPopper (SqlPersist m) -> SqlPersist m a) -- ^ row consumer
         -> SqlPersist m a
withStmt sql vals pop = do
    stmt <- getStmt sql
    ret <- I.withStmt stmt vals pop
    liftIO $ reset stmt
    return ret

-- | Execute the statement for the given SQL text with the given
-- parameter values, discarding any result.
--
-- The statement is reset even when execution throws: statements are
-- cached per connection (see 'getStmt''), so one left un-reset after a
-- failure would be handed out again by the next call with the same SQL.
execute :: MonadIO m
        => String          -- ^ SQL text
        -> [PersistValue]  -- ^ parameter values
        -> SqlPersist m ()
execute sql vals = do
    stmt <- getStmt sql
    liftIO $ do
        -- 'E.finally' runs the reset on both the normal and the
        -- exceptional path.
        _ <- I.execute stmt vals `E.finally` reset stmt
        return ()

-- | Fetch the prepared 'Statement' for the given SQL text, using the
-- 'Connection' from the reader environment.
getStmt :: MonadIO m => String -> SqlPersist m Statement
getStmt sql = do
    conn <- SqlPersist ask
    liftIO $ getStmt' conn sql

-- | Fetch the prepared 'Statement' for the given SQL text from the
-- connection's statement map, preparing it and storing it in the map
-- when it is not there yet.
--
-- The map is read and written back in two separate steps, so the update
-- is not atomic.
getStmt' :: Connection -> String -> IO Statement
getStmt' conn sql = do
    smap <- readIORef $ stmtMap conn
    case Map.lookup sql smap of
        Just stmt -> return stmt
        Nothing -> do
            stmt <- prepare conn sql
            writeIORef (stmtMap conn) $ Map.insert sql stmt smap
            return stmt