{-# LANGUAGE CPP               #-}
{-# LANGUAGE LambdaCase        #-}
{-# LANGUAGE OverloadedStrings #-}
module Database.PostgreSQL.Simple.Util
    ( existsTable
    , withTransactionRolledBack
    ) where
import           Control.Exception          (bracket_, finally)
import           Database.PostgreSQL.Simple (Connection, Only (..), begin,
                                             query, rollback)
import           GHC.Int                    (Int64)
-- | Checks whether a relation with the given name is present in the
-- @pg_class@ system catalog.
--
-- The lookup matches on @relname@ only; it is not restricted to a particular
-- schema or relation kind. The result is 'True' as soon as at least one
-- catalog entry carries the name: the same name can appear more than once
-- (for example in different schemas), so the count is deliberately not
-- compared against exactly 1.
existsTable :: Connection -> String -> IO Bool
existsTable con table =
    fmap hasMatch (query con q (Only table) :: IO [[Int64]])
    where
        q = "select count(relname) from pg_class where relname = ?"
        -- The first column of the first row is the count; any other result
        -- shape (no rows, empty row) is treated as "not found".
        hasMatch :: [[Int64]] -> Bool
        hasMatch ((n:_):_) = n > 0
        hasMatch _         = False
-- | Runs an action between 'begin' and 'rollback' on the given connection.
--
-- 'begin' is issued first; 'rollback' is then issued after the action,
-- whether it returns normally or throws, and the action's result (or
-- exception) is propagated to the caller. If 'begin' itself fails, neither
-- the action nor 'rollback' is run.
--
-- 'bracket_' is used rather than @begin >> finally@ so that an asynchronous
-- exception cannot slip in between a successful 'begin' and the point where
-- the 'rollback' handler is installed, which would leave the transaction
-- open.
withTransactionRolledBack :: Connection -> IO a -> IO a
withTransactionRolledBack con f =
    bracket_ (begin con) (rollback con) f