{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
-- | Query and statement helpers for 'Db', together with validated
--   schema / savepoint names and a savepoint-based unique-violation check.
module Traction.Sql (
    Only (..)
  , Binary (..)
  , mandatory
  , mandatory_
  , unique
  , unique_
  , query
  , query_
  , execute
  , execute_
  , value
  , valueWith
  , values
  , valuesWith
  , Schema (..)
  , newSchema
  , Savepoint (..)
  , newSavepoint
  , createSavepoint
  , releaseSavepoint
  , rollbackSavepoint
  , Unique (..)
  , isUnique
  , isDuplicate
  , withUniqueCheck
  , withUniqueCheckSavepoint
  ) where

import           Control.Monad.Trans.Class (MonadTrans (..))
import           Control.Monad.Trans.Reader (ask, runReaderT)
import           Control.Monad.IO.Class (MonadIO (..))

import qualified Data.ByteString.Builder as ByteStringBuilder
import qualified Data.Char as Char
import           Data.Data (Data)
import           Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text
import           Data.Typeable (Typeable)

import           Database.PostgreSQL.Simple.SqlQQ (sql)
import           Database.PostgreSQL.Simple (ToRow, FromRow, Only (..), Binary (..))
import qualified Database.PostgreSQL.Simple as Postgresql
import qualified Database.PostgreSQL.Simple.ToField as Postgresql

import           Traction.Control
import           Traction.Prelude


-- | Run a parameterised query that must yield exactly one row.
--
--   Fails with 'DbNoResults' when no row comes back and with
--   'DbTooManyResults' when more than one does.
mandatory :: (MonadDb m, ToRow a, FromRow b) => Postgresql.Query -> a -> m b
mandatory q =
  liftDb . definitely q . unique q

-- | Parameterless variant of 'mandatory'.
mandatory_ :: (MonadDb m, FromRow a) => Postgresql.Query -> m a
mandatory_ q =
  liftDb . definitely q $ unique_ q

-- | Run a parameterised query that yields at most one row.
--
--   'Nothing' for zero rows, 'Just' for one row; more than one row fails
--   with 'DbTooManyResults'.
unique :: (MonadDb m, ToRow a, FromRow b) => Postgresql.Query -> a -> m (Maybe b)
unique q parameters =
  liftDb . possibly q . withConnection q $ \c ->
    Postgresql.query c q parameters

-- | Parameterless variant of 'unique'.
unique_ :: (MonadDb m, FromRow a) => Postgresql.Query -> m (Maybe a)
unique_ q =
  liftDb . possibly q . withConnection q $ \c ->
    Postgresql.query_ c q

-- | Run a parameterised query and return every row.
query :: (MonadDb m, ToRow a, FromRow b) => Postgresql.Query -> a -> m [b]
query q parameters =
  liftDb . withConnection q $ \c ->
    Postgresql.query c q parameters

-- | Parameterless variant of 'query'.
query_ :: (MonadDb m, FromRow a) => Postgresql.Query -> m [a]
query_ q =
  liftDb . withConnection q $ \c ->
    Postgresql.query_ c q

-- | Run a parameterised statement, returning the count reported by
--   'Postgresql.execute'.
execute :: (MonadDb m, ToRow a) => Postgresql.Query -> a -> m Int64
execute q parameters =
  liftDb . withConnection q $ \c ->
    Postgresql.execute c q parameters

-- | Parameterless variant of 'execute'.
execute_ :: MonadDb m => Postgresql.Query -> m Int64
execute_ q =
  liftDb . withConnection q $ \c ->
    Postgresql.execute_ c q

-- | Collapse a list of rows to at most one.
--
--   The query is only used to label the 'DbTooManyResults' failure raised
--   when the list has two or more elements.
possibly :: Postgresql.Query -> Db [a] -> Db (Maybe a)
possibly q db =
  db >>= \as ->
    case as of
      [] ->
        pure Nothing
      [x] ->
        pure . Just $ x
      (_x:_xs) ->
        failWith $ DbTooManyResults q (length as)

-- | Turn an absent result into a 'DbNoResults' failure labelled with the
--   query.
definitely :: Postgresql.Query -> Db (Maybe a) -> Db a
definitely q db =
  db >>= fromMaybeM (failWith $ DbNoResults q)

-- | Unwrap a single 'Only' column inside one functor layer,
--   e.g. @Maybe (Only a)@ to @Maybe a@.
value :: Functor f => f (Only a) -> f a
value =
  fmap fromOnly

-- | Like 'value', additionally applying a function to the unwrapped column.
valueWith :: Functor f => (a -> b) -> f (Only a) -> f b
valueWith f =
  fmap (f . fromOnly)

-- | Unwrap 'Only' columns inside two functor layers,
--   e.g. @m [Only a]@ to @m [a]@.
values :: (Functor f, Functor g) => g (f (Only a)) -> g (f a)
values =
  (fmap . fmap) fromOnly

-- | Like 'values', additionally applying a function to each unwrapped column.
valuesWith :: (Functor f, Functor g) => (a -> b) -> g (f (Only a)) -> g (f b)
valuesWith f =
  (fmap . fmap) (f . fromOnly)

-- | A schema name. Prefer building it with 'newSchema'.
newtype Schema =
  Schema {
      renderSchema :: Text
    } deriving (Eq, Show, Data, Typeable)

-- | Rendered as an escaped SQL identifier.
instance Postgresql.ToField Schema where
  toField =
    Postgresql.EscapeIdentifier . Text.encodeUtf8 . renderSchema

-- | Validate a schema name.
--
--   Accepts a non-empty name consisting only of alphabetic characters and
--   underscores; anything else (including the empty string) is 'Nothing'.
newSchema :: Text -> Maybe Schema
newSchema t =
  if isValidName t then
    Just (Schema t)
  else
    Nothing

-- | A savepoint name. Prefer building it with 'newSavepoint'.
newtype Savepoint =
  Savepoint {
      renderSavepoint :: Text
    } deriving (Eq, Show, Data, Typeable)

-- | Rendered with 'Postgresql.Plain', i.e. spliced into the statement
--   without escaping, so the name has to be validated beforehand
--   (see 'newSavepoint').
instance Postgresql.ToField Savepoint where
  toField =
    Postgresql.Plain . ByteStringBuilder.byteString . Text.encodeUtf8 . renderSavepoint

-- | Validate a savepoint name.
--
--   Accepts a non-empty name consisting only of alphabetic characters and
--   underscores; anything else (including the empty string) is 'Nothing'.
newSavepoint :: Text -> Maybe Savepoint
newSavepoint t =
  if isValidName t then
    Just (Savepoint t)
  else
    Nothing

-- | Shared validation for 'newSchema' and 'newSavepoint'.
--
--   The explicit emptiness check is required because 'Text.all' is
--   vacuously true for the empty string, and an empty name does not
--   produce a usable identifier.
isValidName :: Text -> Bool
isValidName t =
  not (Text.null t || not (Text.all (\c -> Char.isAlpha c || c == '_') t))

-- | Issue @SAVEPOINT name@.
createSavepoint :: Savepoint -> Db ()
createSavepoint s =
  void $ execute [sql| SAVEPOINT ? |] (Only s)

-- | Issue @RELEASE SAVEPOINT name@.
releaseSavepoint :: Savepoint -> Db ()
releaseSavepoint s =
  void $ execute [sql| RELEASE SAVEPOINT ? |] (Only s)

-- | Issue @ROLLBACK TO SAVEPOINT name@.
rollbackSavepoint :: Savepoint -> Db ()
rollbackSavepoint s =
  void $ execute [sql| ROLLBACK TO SAVEPOINT ? |] (Only s)

-- | Run an action between creating a savepoint and either releasing it
--   (the action succeeded) or rolling back to it (the action failed).
--
--   The outcome of the action, success or failure, is passed on to the
--   caller after the clean-up statement has run. On the failure path only
--   the rollback is issued; the savepoint is not released here.
bracketSavepoint :: Savepoint -> Db a -> Db a
bracketSavepoint savepoint db =
  Db $ ask >>= \c -> lift $ do
    -- Run the action against the same connection, capturing any failure as
    -- a value so the clean-up below can still be executed.
    r <- liftIO . runEitherT $ flip runReaderT c $ _runDb (createSavepoint savepoint >> db)
    case r of
      Left _ -> do
        flip runReaderT c $ _runDb (rollbackSavepoint savepoint)
        newEitherT $ pure r
      Right _ -> do
        flip runReaderT c $ _runDb (releaseSavepoint savepoint)
        newEitherT $ pure r

-- | Outcome of an action guarded by 'withUniqueCheck': either its result,
--   or the query and SQL error of a unique-constraint violation.
data Unique a =
    Unique a
  | Duplicate Postgresql.Query Postgresql.SqlError
    deriving (Show, Functor)

-- | 'True' for 'Unique', 'False' for 'Duplicate'.
isUnique :: Unique a -> Bool
isUnique u =
  case u of
    Unique _ ->
      True
    Duplicate _ _ ->
      False

-- | Negation of 'isUnique'.
isDuplicate :: Unique a -> Bool
isDuplicate =
  not . isUnique

-- | 'withUniqueCheckSavepoint' using the fixed savepoint name
--   @duplicate_key_savepoint@.
withUniqueCheck :: MonadDb m => Db a -> m (Unique a)
withUniqueCheck =
  withUniqueCheckSavepoint (Savepoint "duplicate_key_savepoint")

-- | Run an action inside the given savepoint and report a
--   unique-constraint violation as a 'Duplicate' value instead of a failure.
--
--   Only a 'DbSqlError' whose SQLSTATE is @23505@ (PostgreSQL
--   @unique_violation@) is converted; every other failure is propagated
--   unchanged and a success is wrapped in 'Unique'.
withUniqueCheckSavepoint :: MonadDb m => Savepoint -> Db a -> m (Unique a)
withUniqueCheckSavepoint savepoint db =
  liftDb . Db $ ask >>= \c -> lift $ do
    -- Capture the failure as a value so it can be inspected.
    r <- liftIO . runEitherT $ flip runReaderT c $ _runDb (bracketSavepoint savepoint db)
    case r of
      Left (DbSqlError q e) -> do
        if Postgresql.sqlState e == "23505"
          then pure (Duplicate q e)
          else fmap Unique $ newEitherT $ pure r
      _ ->
        fmap Unique $ newEitherT $ pure r