{-# LANGUAGE OverloadedStrings #-}

-- | Helpers for running an action against a postgresql-simple 'Connection'
-- and cleaning the database up afterwards.
module Database.DBCleaner.PostgreSQLSimple
  ( withConnection
  ) where

import Control.Exception (bracket, finally)
import Control.Monad (void)
import Data.Monoid (mconcat)
import Data.Text (Text)
import Database.PostgreSQL.Simple
  ( Connection
  , Only (..)
  , begin
  , execute
  , query_
  , rollback
  )
import Database.PostgreSQL.Simple.Types (Identifier (..))

import Database.DBCleaner.Types (Strategy (..))

-- | Run an action against a connection and clean up according to the given
-- 'Strategy'.
--
-- * 'Transaction': a transaction is begun before the action runs and is
--   rolled back once the action finishes ('bracket' runs the rollback
--   whether the action returns or throws).
--
-- * 'Truncation': the action runs first; afterwards every table reported by
--   'listTables' is truncated ('finally' runs the cleanup whether the action
--   returns or throws).
withConnection :: Strategy -> (Connection -> IO a) -> Connection -> IO a
withConnection strategy action conn = case strategy of
  Transaction -> bracket (begin conn >> return conn) rollback action
  Truncation  -> action conn `finally` truncateAll
  where
    -- Look the tables up at cleanup time, then truncate them one by one.
    truncateAll = listTables conn >>= mapM_ (truncateTable conn)

-- | Fetch the names of the relations selected by the catalog query below:
-- entries of @pg_class@ whose @relkind@ is @'r'@ or @''@, that are not in the
-- @pg_catalog@ or @information_schema@ namespaces or a namespace matching
-- @^pg_toast@, and for which @pg_table_is_visible@ holds.
listTables :: Connection -> IO [Text]
listTables conn = fmap (map fromOnly) (query_ conn tablesQuery)
  where
    -- Assembled from fragments purely for readability; the pieces are
    -- concatenated as-is, so each continuation starts with a space.
    tablesQuery = mconcat
      [ "SELECT c.relname FROM pg_catalog.pg_class c"
      , " LEFT JOIN pg_catalog.pg_namespace n"
      , " ON c.relnamespace = n.oid"
      , " WHERE c.relkind IN ('r', '')"
      , " AND n.nspname <> 'pg_catalog'"
      , " AND n.nspname <> 'information_schema'"
      , " AND n.nspname !~ '^pg_toast'"
      , " AND pg_catalog.pg_table_is_visible(c.oid)"
      ]

-- | Issue @TRUNCATE … CASCADE@ for a single table. The name is passed as an
-- 'Identifier' parameter rather than spliced into the SQL text; the affected
-- row count returned by 'execute' is discarded.
truncateTable :: Connection -> Text -> IO ()
truncateTable conn table =
  void (execute conn "TRUNCATE ? CASCADE" (Only (Identifier table)))