-- |
-- Module      : Database.PostgreSQL.Simple.Util
-- Copyright   : (c) 2014 Andreas Meingast <ameingast@gmail.com>
--
-- License     : BSD-style
-- Maintainer  : andre@andrevdm.com
-- Stability   : experimental
-- Portability : GHC
--
-- A collection of utilities for database migrations.

{-# LANGUAGE OverloadedStrings #-}

module Database.PostgreSQL.Simple.Util
    ( existsTable
    , withTransactionRolledBack
    ) where

import           Control.Exception ( finally
                                   , mask
                                   )
import           Database.PostgreSQL.Simple ( Connection
                                            , Only (..)
                                            , begin
                                            , query
                                            , rollback
                                            )
import           GHC.Int (Int64)

-- | Checks if a table (relation) with the given name exists in the database.
--
-- The lookup matches @relname@ in @pg_class@ without any schema filter, so a
-- relation with that name in /any/ schema counts as existing. The same name
-- may therefore be matched several times (once per schema). Any positive
-- count is treated as "exists".
--
-- NOTE(review): @pg_class@ also lists non-table relations (indexes,
-- sequences, views, ...). This query does not filter on @relkind@, so any
-- relation with the given name will make this return 'True' — confirm that
-- callers are fine with that.
existsTable :: Connection -> String -> IO Bool
existsTable con table =
    checkRowCount <$> (query con q (Only table) :: IO [[Int64]])
  where
    q = "select count(relname) from pg_class where relname = ?"

    -- The query returns a single row holding the count. Previously only a
    -- count of exactly 1 was accepted, which wrongly reported 'False' when
    -- the name existed in more than one schema.
    checkRowCount :: [[Int64]] -> Bool
    checkRowCount ((n:_):_) = n > 0
    checkRowCount _         = False

-- | Runs the given IO action inside a transaction and always rolls the
-- transaction back afterwards, whether the action returns normally or
-- throws an exception.
--
-- Asynchronous exceptions are masked from just before @begin@ until the
-- rollback handler is installed. Once a transaction has been started, the
-- rollback therefore always runs. The action itself runs with the caller's
-- masking state restored.
--
-- If @begin@ itself throws, no rollback is attempted and the exception
-- propagates to the caller.
withTransactionRolledBack :: Connection -> IO a -> IO a
withTransactionRolledBack con f = mask $ \restore -> do
    begin con
    -- Installing the handler while still masked closes the window in which
    -- an async exception could arrive after 'begin' but before 'finally'.
    restore f `finally` rollback con