module Database.PostgreSQL where
import Control.Monad.Error
import Control.Monad.State
import Control.Monad.Trans
import Data.List
import Data.Maybe
import Foreign
import Foreign.C
-- | Monad transformer that carries a 'DatabaseHandle' alongside a computation
-- in the base monad @m@. Internally it is a 'StateT' layer whose state is the
-- handle; the 'Monad', 'MonadIO' and 'MonadTrans' instances are derived from
-- that underlying representation.
newtype DatabaseT m a
    = DatabaseT (StateT DatabaseHandle m a)
    deriving (Monad, MonadIO, MonadTrans)
-- | Monads that can supply a 'DatabaseHandle'.
--
-- Every member must also support 'IO' actions ('MonadIO') and error
-- signalling with an error type @e@ that is an instance of 'Error'.
class (MonadIO m, Error e, MonadError e m) => MonadDatabase e m where
    -- | Retrieve the handle of the connection the computation runs against.
    getConnection :: m DatabaseHandle
-- | Base instance: inside 'DatabaseT' the handle is simply the current value
-- of the wrapped 'StateT' state.
instance (MonadIO m, Error e, MonadError e m)
      => MonadDatabase e (DatabaseT m) where
    getConnection = DatabaseT get
-- | Lifting instance: any transformer @t@ stacked on top of a
-- 'MonadDatabase' monad is itself a 'MonadDatabase', obtaining the handle by
-- 'lift'ing the inner 'getConnection'.
--
-- NOTE(review): the instance head @t m@ also matches @DatabaseT m@, so this
-- presumably relies on overlapping/undecidable instance extensions enabled
-- outside the visible source — confirm the build flags.
instance ( MonadIO (t m)
         , MonadError e (t m)
         , MonadTrans t
         , MonadDatabase e m
         , Monad (t m)
         )
      => MonadDatabase e (t m) where
    getConnection = lift getConnection
-- | Error handling for 'DatabaseT' is delegated to the wrapped computation:
-- throwing re-wraps the inner 'throwError', and catching unwraps the handler's
-- result so the inner 'catchError' can run it.
instance MonadError e m => MonadError e (DatabaseT m) where
    throwError = DatabaseT . throwError

    catchError (DatabaseT action) handler =
        DatabaseT (action `catchError` (unwrap . handler))
      where
        -- Strip the newtype so the handler's result fits the inner monad.
        unwrap (DatabaseT inner) = inner
-- | Constructor-less tag type used only as the target of @Ptr PGconn@, the
-- pointer type returned by 'pqConnectDB'. It has no Haskell-side values.
data PGconn
-- | A connection handle: a newtype around the raw @Ptr PGconn@ obtained from
-- 'pqConnectDB'. This is the value threaded through 'DatabaseT'.
newtype DatabaseHandle = DatabaseHandle (Ptr PGconn)
-- | Binding to libpq's @PQconnectdb@: takes a connection-info C string and
-- returns a pointer to the resulting connection object.
--
-- Imported @safe@ rather than @unsafe@: establishing a connection waits on
-- the network, and an @unsafe@ foreign call that blocks prevents the GHC
-- runtime from running other Haskell threads (and garbage collection) until
-- it returns.
foreign import ccall safe "static libpq-fe.h PQconnectdb"
    pqConnectDB :: CString -> IO (Ptr PGconn)
-- | Numeric connection status code, as returned by 'pqStatus'.
type ConnStatusType = Word32

-- | Known status codes. 'withDatabaseRaw' treats every value other than
-- 'connection_OK' as a failed connection.
-- NOTE(review): presumably these mirror the first two members of libpq's
-- @ConnStatusType@ enum — confirm against @libpq-fe.h@.
connection_OK, connection_bad :: ConnStatusType
connection_OK  = 0
connection_bad = 1
-- | Binding to libpq's @PQstatus@: reports the status code of a connection
-- object.
foreign import ccall unsafe "static libpq-fe.h PQstatus"
    pqStatus
        :: Ptr PGconn
        -> IO ConnStatusType
-- | Binding to libpq's @PQerrorMessage@: yields the error text associated
-- with a connection as a C string. Callers in this module copy it into a
-- Haskell 'String' with 'peekCString' straight away.
foreign import ccall unsafe "static libpq-fe.h PQerrorMessage"
    pqErrorMessage
        :: Ptr PGconn
        -> IO CString
-- | Constructor-less tag type used only as the target of @Ptr PGresult@, the
-- pointer type returned by 'pqExec' and 'pqExecParams'.
data PGresult
-- | Binding to libpq's @PQexec@: submits the given command text on a
-- connection and returns a pointer to the result object.
--
-- Imported @safe@ rather than @unsafe@: the call waits for the server to
-- answer, and a blocking @unsafe@ foreign call would stall every other
-- Haskell thread (and garbage collection) for the duration of the query.
foreign import ccall safe "static libpq-fe.h PQexec"
    pqExec :: Ptr PGconn -> CString -> IO (Ptr PGresult)
-- | Run one SQL string on the current connection, discarding any rows.
--
-- The SQL text is marshalled to a C string and passed to 'pqExec';
-- 'checkResultStatus' then raises an error through 'throwError' unless the
-- result status is one of the accepted success codes.
execute :: MonadDatabase e m => String -> m ()
execute sql =
    getConnection >>= \(DatabaseHandle conn) ->
        checkResultStatus "execute" (withCString sql (pqExec conn))
-- | Element type of the @Ptr Oid@ argument of 'pqExecParams', represented as
-- a 32-bit unsigned word. 'execParams' always passes a null pointer for that
-- argument, so no 'Oid' values are constructed in this module.
type Oid = Word32
-- | Marshal a list of Haskell strings into a temporary C array of C strings
-- and run an action on a pointer to that array.
--
-- The array elements appear in the same order as the input list. Both the
-- strings and the array are only valid while the supplied action runs.
withCStrings :: [String] -> (Ptr CString -> IO a) -> IO a
withCStrings strs body =
    withMany withCString strs $ \cstrs ->
        withArray cstrs body
-- | Binding to libpq's @PQexecParams@: submits a command with separately
-- supplied parameter values and returns a pointer to the result object.
--
-- Imported @safe@ rather than @unsafe@: the call waits for the server to
-- answer, and a blocking @unsafe@ foreign call would stall every other
-- Haskell thread (and garbage collection) for the duration of the query.
--
-- The per-argument notes below describe how 'execParams' uses each slot.
foreign import ccall safe "static libpq-fe.h PQexecParams"
    pqExecParams :: Ptr PGconn      -- connection to run the command on
                 -> CString         -- command text
                 -> CInt            -- number of parameters
                 -> Ptr Oid         -- parameter type OIDs (execParams passes nullPtr)
                 -> Ptr CString     -- parameter values, one C string each
                 -> Ptr CInt        -- parameter lengths (execParams passes nullPtr)
                 -> Ptr CInt        -- parameter formats (execParams passes all zeros)
                 -> CInt            -- result format (execParams passes 0)
                 -> IO (Ptr PGresult)
-- | Run one SQL string with separately supplied string parameters on the
-- current connection, discarding any rows.
--
-- Every parameter is marshalled as a C string; the type-OID and length
-- arrays are passed as null pointers, every parameter format is 0 and the
-- requested result format is 0. 'checkResultStatus' raises an error through
-- 'throwError' unless the result status is one of the accepted success codes.
execParams :: MonadDatabase e m => String -> [String] -> m ()
execParams sql params =
    getConnection >>= \(DatabaseHandle conn) ->
        checkResultStatus "execParams" (runOn conn)
  where
    -- Number of parameters, in the C type the binding expects.
    paramCount :: CInt
    paramCount = genericLength params

    -- One format code (0) per parameter.
    paramFormats :: [CInt]
    paramFormats = genericReplicate paramCount 0

    -- Marshal all arguments and issue the call; the temporaries live only
    -- for the duration of the foreign call.
    runOn conn =
        withCString sql $ \sqlPtr ->
        withCStrings params $ \valuePtrs ->
        withArray paramFormats $ \formatPtr ->
            pqExecParams conn sqlPtr paramCount nullPtr valuePtrs nullPtr formatPtr 0
-- | Numeric result status code, as returned by 'pqResultStatus'.
type ExecStatusType = Word32

-- | Known result status codes. 'checkResultStatus' accepts
-- 'pgres_command_OK' and 'pgres_tuples_OK' as success; anything else
-- (including 'pgres_empty_query') is reported as a failure.
-- NOTE(review): presumably these mirror the first members of libpq's
-- @ExecStatusType@ enum — confirm against @libpq-fe.h@.
pgres_empty_query, pgres_command_OK, pgres_tuples_OK :: ExecStatusType
pgres_empty_query = 0
pgres_command_OK  = 1
pgres_tuples_OK   = 2
-- | Binding to libpq's @PQresultStatus@: reports the status code of a result
-- object.
foreign import ccall unsafe "static libpq-fe.h PQresultStatus"
    pqResultStatus
        :: Ptr PGresult
        -> IO ExecStatusType
-- | Binding to libpq's @PQresStatus@: maps a result status code to a C
-- string. 'checkResultStatus' uses that string as the code label in its
-- error messages.
foreign import ccall unsafe "static libpq-fe.h PQresStatus"
    pqResStatus
        :: ExecStatusType
        -> IO CString
-- | Binding to libpq's @PQresultErrorMessage@: yields the error text
-- associated with a result object as a C string.
foreign import ccall unsafe "static libpq-fe.h PQresultErrorMessage"
    pqResultErrorMessage
        :: Ptr PGresult
        -> IO CString
-- | Binding to libpq's @PQclear@, which releases a result object.
foreign import ccall unsafe "static libpq-fe.h PQclear"
    pqClear :: Ptr PGresult -> IO ()

-- | Run an action that produces a libpq result, verify that it succeeded,
-- and release the result.
--
-- Parameters:
--
-- * @s@ – name of the calling operation, used as the prefix of the error
--   message.
-- * @f@ – the 'IO' action that issues the command and returns the result
--   pointer.
--
-- The result counts as a success when its status is 'pgres_command_OK' or
-- 'pgres_tuples_OK'. Otherwise an error built with 'strMsg' is raised via
-- 'throwError'; its text is @\<s\> failed (\<status name\>): \<message\>@.
--
-- The result object is always handed to @PQclear@ once its status and
-- message have been copied into Haskell values; previously it was never
-- released, leaking one result per executed statement.
-- NOTE(review): this relies on libpq accepting a NULL result pointer in
-- @PQclear@ (documented as a no-op) — confirm for the targeted libpq version.
checkResultStatus :: MonadDatabase e m => String -> IO (Ptr PGresult) -> m ()
checkResultStatus s f = do
    failure <- liftIO $ do
        res <- f
        status <- pqResultStatus res
        -- Copy everything we need out of the result *before* freeing it.
        outcome <- describeFailure res status
        pqClear res
        return outcome
    -- Raise only after the result has been released.
    maybe (return ()) (throwError . strMsg) failure
  where
    succeeded status = status `elem` [pgres_command_OK, pgres_tuples_OK]

    -- Nothing on success, otherwise the fully formatted error message.
    describeFailure res status
        | succeeded status = return Nothing
        | otherwise = do
            err_msg <- pqResultErrorMessage res >>= peekCString
            err_code <- pqResStatus status >>= peekCString
            return (Just (s ++ " failed (" ++ err_code ++ "): " ++ err_msg))
-- | Binding to libpq's @PQfinish@. 'withDatabaseRaw' calls it on the
-- connection pointer once the wrapped computation has produced its result.
foreign import ccall unsafe "static libpq-fe.h PQfinish"
    pqFinish
        :: Ptr PGconn
        -> IO ()
-- | Open a connection from a raw libpq connection-info string, run a
-- 'DatabaseT' computation against it, and close the connection afterwards.
--
-- Parameters:
--
-- * @conninfo@ – connection string passed verbatim to 'pqConnectDB'.
-- * the 'DatabaseT' computation to run with the new handle.
--
-- Returns the computation's result in the base monad.
--
-- Failures to connect are reported with 'error': a fixed message if the
-- connection pointer is NULL, otherwise the text from 'pqErrorMessage'.
--
-- When the status check fails the connection object is now released with
-- 'pqFinish' before the error is raised; libpq allocates the object even
-- for failed attempts, so previously it leaked on every failed connection.
--
-- NOTE(review): if the wrapped computation aborts (an exception, or an error
-- in the base monad) 'pqFinish' is still skipped, because a general
-- 'MonadIO' offers no bracket-style cleanup.
withDatabaseRaw :: MonadIO m => String -> DatabaseT m a -> m a
withDatabaseRaw conninfo (DatabaseT f)
    = do dbh <- liftIO $ withCString conninfo pqConnectDB
         if dbh == nullPtr
             then error "XXX dbh was NULL - can't happen?"
             else do
                 stat <- liftIO $ pqStatus dbh
                 if stat /= connection_OK
                     then do
                         -- Copy the message first: it lives inside the
                         -- connection object that is about to be freed.
                         err <- liftIO $ pqErrorMessage dbh >>= peekCString
                         liftIO $ pqFinish dbh
                         error err
                     else do
                         res <- evalStateT f (DatabaseHandle dbh)
                         liftIO $ pqFinish dbh
                         return res
-- | Structured connection settings. Each field is optional; 'withDatabase'
-- renders every field that is set as a @name='value'@ entry of the
-- connection string, using the field's own name as the key.
data ConnectionInfo = ConnectionInfo
    { host            :: Maybe String
    , hostaddr        :: Maybe String
    , port            :: Maybe String
    , dbname          :: Maybe String
    , user            :: Maybe String
    , password        :: Maybe String
    , connect_timeout :: Maybe String
    , options         :: Maybe String
    , sslmode         :: Maybe String
    , service         :: Maybe String
    }
-- | A 'ConnectionInfo' with every setting left unset. Intended as a starting
-- point for record update, e.g. @defaultConnectionInfo { dbname = Just "x" }@.
defaultConnectionInfo :: ConnectionInfo
defaultConnectionInfo = ConnectionInfo
    { host            = Nothing
    , hostaddr        = Nothing
    , port            = Nothing
    , dbname          = Nothing
    , user            = Nothing
    , password        = Nothing
    , connect_timeout = Nothing
    , options         = Nothing
    , sslmode         = Nothing
    , service         = Nothing
    }
-- | Open a connection described by a 'ConnectionInfo', run a 'DatabaseT'
-- computation against it, and close the connection afterwards.
--
-- The settings that are present are rendered as @name='value'@ pairs, in the
-- fixed order listed below, separated by single spaces, and the resulting
-- string is handed to 'withDatabaseRaw'. Inside a value, every single quote
-- and every backslash is prefixed with a backslash.
withDatabase :: MonadIO m => ConnectionInfo -> DatabaseT m a -> m a
withDatabase conninfo = withDatabaseRaw (unwords rendered)
  where
    -- One "name='value'" entry per setting that is actually set.
    rendered =
        [ key ++ "='" ++ concatMap quote value ++ "'"
        | (key, field) <- settings
        , Just value <- [field conninfo]
        ]

    -- Keyword names paired with the record accessor that supplies them.
    settings =
        [ ("host", host)
        , ("hostaddr", hostaddr)
        , ("port", port)
        , ("dbname", dbname)
        , ("user", user)
        , ("password", password)
        , ("connect_timeout", connect_timeout)
        , ("options", options)
        , ("sslmode", sslmode)
        , ("service", service)
        ]

    -- Backslash-escape the two characters that are special inside quotes.
    quote '\'' = "\\'"
    quote '\\' = "\\\\"
    quote c    = [c]