{-# LANGUAGE FlexibleContexts, GeneralizedNewtypeDeriving, RecordWildCards, OverloadedStrings #-}
module Database.PostgreSQL.Transact where
import Control.Monad.Trans.Reader
import Database.PostgreSQL.Simple as Simple
import Database.PostgreSQL.Simple.Types as Simple
import Database.PostgreSQL.Simple.Transaction
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Control
import Control.Monad.Catch
import Data.Int
import Control.Monad
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BSC
import qualified Control.Monad.Fail as Fail
-- | A computation that runs inside a database transaction over base monad @m@.
--
-- Under the hood this is a 'ReaderT' whose environment is the 'Connection'
-- every statement in the transaction is issued on. The 'MonadThrow' and
-- 'MonadTrans' instances are taken directly from 'ReaderT' via
-- @GeneralizedNewtypeDeriving@.
newtype DBT m a = DBT { unDBT :: ReaderT Connection m a }
  deriving (MonadThrow, MonadTrans)

-- | 'DBT' with 'IO' as the base monad.
type DB = DBT IO
-- | Maps over the result by delegating to the wrapped 'ReaderT'.
instance Functor m => Functor (DBT m) where
  fmap f (DBT inner) = DBT (fmap f inner)
-- | Both 'pure' and '<*>' delegate to the wrapped 'ReaderT', so the two
-- sides of an application observe the same 'Connection'.
instance Applicative m => Applicative (DBT m) where
  pure x = DBT (pure x)
  DBT mf <*> DBT mx = DBT (mf <*> mx)
-- | Lifts an 'IO' action into the base monad @m@, then into 'DBT'.
instance MonadIO m => MonadIO (DBT m) where
  liftIO io = lift (liftIO io)
-- | Sequencing delegates to the wrapped 'ReaderT', so the continuation runs
-- against the same 'Connection' as the first action.
--
-- 'return' is deliberately left undefined. The class default
-- (@return = pure@) is the canonical definition, and a hand-written
-- @return@ triggers @-Wnoncanonical-monad-instances@ on modern GHC.
instance Monad m => Monad (DBT m) where
  DBT m >>= k = DBT $ m >>= unDBT . k
-- | Pattern-match failure is delegated to the base monad's 'Fail.fail'.
instance Fail.MonadFail m => Fail.MonadFail (DBT m) where
  fail msg = lift (Fail.fail msg)
-- | True when the error's SQLSTATE code starts with @"25"@, i.e. SQLSTATE
-- class 25 (which PostgreSQL's error-code table labels "Invalid Transaction
-- State").
isClass25 :: SqlError -> Bool
isClass25 = BS.isPrefixOf "25" . sqlState
-- | Exceptions caught inside a transaction are scoped by a savepoint. A
-- handled exception therefore undoes only the protected action's work, not
-- the whole transaction.
--
-- The sequence, all inside 'mask':
--
--   1. create a savepoint on the current connection;
--   2. run the action with the caller's masking state restored;
--   3. if the action throws an exception the handler accepts, roll back to
--      the savepoint and then run the handler;
--   4. whatever the outcome, release the savepoint afterwards.
instance (MonadIO m, MonadMask m) => MonadCatch (DBT m) where
  catch (DBT act) handler = DBT $ mask $ \restore -> do
    conn <- ask
    sp <- liftIO $ newSavepoint conn
    -- Only the protected action gets the caller's masking state back. The
    -- rollback and the user handler run with async exceptions masked.
    let setup = catch (restore act) $ \e -> do
          liftIO $ rollbackToSavepoint conn sp
          unDBT $ handler e
    -- A release failure whose SQLSTATE is class 25 (see 'isClass25') is
    -- discarded so it cannot override the real outcome. Presumably this
    -- covers an already-aborted transaction (NOTE(review): confirm). Any
    -- other error raised by the release still propagates.
    setup `finally` liftIO (tryJust (guard . isClass25) (releaseSavepoint conn sp))
-- | The 'Connection' the current transaction is running on.
getConnection :: Monad m => DBT m Connection
getConnection = DBT (asks id)
-- | Run a 'DBT' computation on @conn@ inside a transaction at isolation
-- level @level@.
--
-- 'control' lets the base monad @m@ be threaded through the 'IO'-only
-- 'withTransactionLevel' wrapper.
runDBT :: MonadBaseControl IO m => DBT m a -> IsolationLevel -> Connection -> m a
runDBT action level conn = control $ \runInIO ->
    withTransactionLevel level conn (runInIO transaction)
  where
    -- The transaction body: the computation with its connection supplied.
    transaction = runReaderT (unDBT action) conn
-- | Run a 'DBT' computation on @conn@ via 'withTransactionSerializable'.
--
-- NOTE(review): postgresql-simple documents that
-- 'withTransactionSerializable' retries on serialization failure. That
-- would re-run @action@, including any effects in @m@. Confirm this is
-- acceptable for callers.
runDBTSerializable :: MonadBaseControl IO m => DBT m a -> Connection -> m a
runDBTSerializable action conn = control $ \runInIO ->
    withTransactionSerializable conn (runInIO transaction)
  where
    -- The transaction body: the computation with its connection supplied.
    transaction = runReaderT (unDBT action) conn
-- | Run a parameterised query on the transaction's connection and decode
-- the returned rows.
query :: (ToRow a, FromRow b, MonadIO m) => Query -> a -> DBT m [b]
query q params = do
  conn <- getConnection
  liftIO (Simple.query conn q params)
-- | Run a parameterless query on the transaction's connection and decode
-- the returned rows.
query_ :: (FromRow b, MonadIO m) => Query -> DBT m [b]
query_ q = do
  conn <- getConnection
  liftIO (Simple.query_ conn q)
-- | Execute a parameterised statement on the transaction's connection,
-- returning the count reported by 'Simple.execute'.
execute :: (ToRow q, MonadIO m) => Query -> q -> DBT m Int64
execute q params = do
  conn <- getConnection
  liftIO (Simple.execute conn q params)
-- | Execute a parameterless statement on the transaction's connection,
-- returning the count reported by 'Simple.execute_'.
execute_ :: MonadIO m => Query -> DBT m Int64
execute_ q = do
  conn <- getConnection
  liftIO (Simple.execute_ conn q)
-- | Execute a statement once per parameter row via 'Simple.executeMany' on
-- the transaction's connection.
executeMany :: (ToRow q, MonadIO m) => Query -> [q] -> DBT m Int64
executeMany q paramRows = do
  conn <- getConnection
  liftIO (Simple.executeMany conn q paramRows)
-- | Like 'executeMany', but decodes the rows produced by the statement
-- (delegates to 'Simple.returning').
returning :: (ToRow q, FromRow r, MonadIO m) => Query -> [q] -> DBT m [r]
returning q paramRows = do
  conn <- getConnection
  liftIO (Simple.returning conn q paramRows)
-- | Render the SQL that would be sent for @q@ with @params@ substituted,
-- using the transaction's connection (delegates to 'Simple.formatQuery').
formatQuery :: (ToRow q, MonadIO m) => Query -> q -> DBT m BS.ByteString
formatQuery q params = do
  conn <- getConnection
  liftIO (Simple.formatQuery conn q params)
-- | Thrown by 'queryOne' and 'queryOne_' when a query expected to return at
-- most one row returns several. The payload is the query's SQL text.
newtype TooManyRows = TooManyRows String
  deriving (Eq, Show)

instance Exception TooManyRows
-- | Run a parameterised query that is expected to return at most one row.
--
-- Returns 'Nothing' for zero rows and 'Just' the row for exactly one.
--
-- Throws 'TooManyRows' carrying the query's SQL text when more than one row
-- comes back.
--
-- The base monad is polymorphic. Any @m@ with 'MonadIO' (to run the query)
-- and 'MonadThrow' (to raise 'TooManyRows') works. @'DB' = 'DBT' 'IO'@
-- satisfies both, so existing 'DB' callers are unaffected.
queryOne
  :: (ToRow a, FromRow b, MonadIO m, MonadThrow m)
  => Query -> a -> DBT m (Maybe b)
queryOne q x = do
  rows <- Database.PostgreSQL.Transact.query q x
  case rows of
    []    -> return Nothing
    [row] -> return (Just row)
    _     -> throwM $ TooManyRows $ BSC.unpack (fromQuery q)
-- | Run a parameterless query that is expected to return at most one row.
--
-- Returns 'Nothing' for zero rows and 'Just' the row for exactly one.
--
-- Throws 'TooManyRows' carrying the query's SQL text when more than one row
-- comes back.
--
-- The base monad is polymorphic. Any @m@ with 'MonadIO' and 'MonadThrow'
-- works, and @'DB' = 'DBT' 'IO'@ satisfies both, so existing 'DB' callers
-- are unaffected.
queryOne_
  :: (FromRow b, MonadIO m, MonadThrow m)
  => Query -> DBT m (Maybe b)
queryOne_ q = do
  rows <- Database.PostgreSQL.Transact.query_ q
  case rows of
    []    -> return Nothing
    [row] -> return (Just row)
    _     -> throwM $ TooManyRows $ BSC.unpack (fromQuery q)