-----------------------------------------------------------------------------------------
{-| Module      : Database.HSQL.PostgreSQL
    Copyright   : (c) Krasimir Angelov 2003
    License     : BSD-style

    Maintainer  : ka2_mail@yahoo.com
    Stability   : provisional
    Portability : portable

    This module provides an interface to the PostgreSQL database, implemented
    as a binding to the libpq C client library.
-}
-----------------------------------------------------------------------------------------

module Database.HSQL.PostgreSQL(connect, connectWithOptions, module Database.HSQL) where

import Database.HSQL
import Database.HSQL.Types
import Data.Dynamic
import Data.Char
import Foreign
import Foreign.C
import Control.OldException (throwDyn, catchDyn, dynExceptions, Exception(..))
import Control.Monad(when,unless,mplus)
import Control.Concurrent.MVar
import System.Time
import System.IO.Unsafe
import Text.ParserCombinators.ReadP
import Text.Read
import Numeric

-- NOTE(review): the header names of the four #include directives below are
-- missing in this copy of the file. hsc2hs needs the headers that define the
-- #type names and the #const values used in this module (libpq's status enums
-- and PostgreSQL type OIDs). Confirm them against the original source.
#include
#include
#include
#include

-- | Opaque pointer to a libpq connection object, as returned by @PQsetdbLogin@.
type PGconn = Ptr ()
-- | Opaque pointer to a libpq query result, as returned by @PQexec@.
type PGresult = Ptr ()
-- | C type of the connection status returned by @PQstatus@ (resolved by hsc2hs).
type ConnStatusType = #type ConnStatusType
-- | C type of the result status returned by @PQresultStatus@ (resolved by hsc2hs).
type ExecStatusType = #type ExecStatusType
-- | PostgreSQL object identifier; used here for column type ids (see @PQftype@).
type Oid = #type Oid

{-| Refer to PostgreSQL manual, chapter 30, `libpq - C library' (e.g.
http://www.postgresql.org/docs/8.3/interactive/libpq.html) -}
foreign import ccall "libpq-fe.h PQsetdbLogin" pqSetdbLogin :: CString -> CString -> CString -> CString -> CString -> CString -> CString -> IO PGconn
foreign import ccall "libpq-fe.h PQstatus" pqStatus :: PGconn -> IO ConnStatusType
foreign import ccall "libpq-fe.h PQerrorMessage" pqErrorMessage :: PGconn -> IO CString
foreign import ccall "libpq-fe.h PQfinish" pqFinish :: PGconn -> IO ()
foreign import ccall "libpq-fe.h PQexec" pqExec :: PGconn -> CString -> IO PGresult
foreign import ccall "libpq-fe.h PQresultStatus" pqResultStatus :: PGresult -> IO ExecStatusType
foreign import ccall "libpq-fe.h PQresStatus" pqResStatus :: ExecStatusType -> IO CString
foreign import ccall "libpq-fe.h PQresultErrorMessage" pqResultErrorMessage :: PGresult -> IO CString
-- Releases a PGresult; every result obtained from PQexec must be cleared exactly once.
foreign import ccall "libpq-fe.h PQclear" pqClear :: PGresult -> IO ()

-- libpq's C @int@ arguments and results are bound as 'CInt' (and strlen's
-- @size_t@ as 'CSize') rather than 'Int': 'Int' is 64 bits wide on common
-- 64-bit targets, and reading a 32-bit C @int@ result as 'Int' may yield
-- garbage in the upper bits (e.g. when PQfmod returns -1).
foreign import ccall "libpq-fe.h PQnfields" pgNFields :: PGresult -> IO CInt
foreign import ccall "libpq-fe.h PQntuples" pqNTuples :: PGresult -> IO CInt
foreign import ccall "libpq-fe.h PQfname" pgFName :: PGresult -> CInt -> IO CString
foreign import ccall "libpq-fe.h PQftype" pqFType :: PGresult -> CInt -> IO Oid
foreign import ccall "libpq-fe.h PQfmod" pqFMod :: PGresult -> CInt -> IO CInt
foreign import ccall "libpq-fe.h PQfnumber" pqFNumber :: PGresult -> CString -> IO CInt
foreign import ccall "libpq-fe.h PQgetvalue" pqGetvalue :: PGresult -> CInt -> CInt -> IO CString
foreign import ccall "libpq-fe.h PQgetisnull" pqGetisnull :: PGresult -> CInt -> CInt -> IO CInt
foreign import ccall "strlen" strlen :: CString -> IO CSize

-----------------------------------------------------------------------------------------
-- Connect/Disconnect
-----------------------------------------------------------------------------------------

-- | Makes a new connection to the database server.
--
-- The server string may carry a port after a colon (@\"host:5432\"@). A
-- missing or empty port (@\"host\"@ or @\"host:\"@) is passed to libpq as a
-- NULL pointer.
connect :: String -- ^ Server name : port nr
        -> String -- ^ Database name
        -> String -- ^ User identifier
        -> String -- ^ Authentication string (password)
        -> IO Connection
connect server database user authentication =
  connectWithOptions serverAddress port Nothing Nothing database user authentication
  where
    (serverAddress, portInput) = break (== ':') server
    -- portInput is either "" or starts with the ':' separator; at least one
    -- character must follow the ':' for it to count as a port.
    port = case portInput of
             (_ : p@(_ : _)) -> Just p
             _               -> Nothing

-- | Makes a new connection to the database server, with specification of
-- port, options & tty.
--
-- 'Nothing' for the port, options or tty is passed to libpq as a NULL
-- pointer. Throws 'SqlError' with state @\"C\"@ (and the libpq connection
-- status as native error) if the connection cannot be established.
connectWithOptions :: String       -- ^ Server name
                   -> Maybe String -- ^ Port number
                   -> Maybe String -- ^ Options
                   -> Maybe String -- ^ TTY
                   -> String       -- ^ Database name
                   -> String       -- ^ User identifier
                   -> String       -- ^ Authentication string (password)
                   -> IO Connection
connectWithOptions server port options tty database user authentication = do
  -- Each withCString frees its C copy once PQsetdbLogin returns, even if an
  -- exception is raised, so no argument string can leak.
  pConn <-
    withCString server             $ \pServer ->
    withCStringElseNullPtr port    $ \pPort ->
    withCStringElseNullPtr options $ \pOptions ->
    withCStringElseNullPtr tty     $ \pTty ->
    withCString database           $ \pDatabase ->
    withCString user               $ \pUser ->
    withCString authentication     $ \pAuthentication ->
      pqSetdbLogin pServer pPort pOptions pTty pDatabase pUser pAuthentication
  status <- pqStatus pConn
  unless (status == (#const CONNECTION_OK)) $ do
    errMsg <- pqErrorMessage pConn >>= peekCString
    -- Release the connection object before reporting the failure.
    pqFinish pConn
    throwDyn (SqlError {seState="C", seNativeError=fromIntegral status, seErrorMsg=errMsg})
  refFalse <- newMVar False
  let connection = Connection
        { connDisconnect          = pqFinish pConn
        , connExecute             = execute pConn
        , connQuery               = query connection pConn
        , connTables              = tables connection pConn
        , connDescribe            = describe connection pConn
        , connBeginTransaction    = execute pConn "begin"
        , connCommitTransaction   = execute pConn "commit"
        , connRollbackTransaction = execute pConn "rollback"
        , connClosed              = refFalse
        }
  return connection
  where
    -- Runs a SQL statement and returns its result. Throws SqlError (state
    -- "E") when libpq returns no result object or a status other than
    -- COMMAND_OK / TUPLES_OK. A failed result is cleared before throwing. On
    -- success the caller owns the result and must pqClear it.
    execChecked :: PGconn -> String -> IO PGresult
    execChecked pConn sqlExpr = do
      pRes <- withCString sqlExpr (pqExec pConn)
      when (pRes == nullPtr) $ do
        errMsg <- pqErrorMessage pConn >>= peekCString
        throwDyn (SqlError {seState="E", seNativeError=(#const PGRES_FATAL_ERROR), seErrorMsg=errMsg})
      status <- pqResultStatus pRes
      unless (status == (#const PGRES_COMMAND_OK) || status == (#const PGRES_TUPLES_OK)) $ do
        errMsg <- pqResultErrorMessage pRes >>= peekCString
        pqClear pRes
        throwDyn (SqlError {seState="E", seNativeError=fromIntegral status, seErrorMsg=errMsg})
      return pRes

    -- Runs a statement and discards (clears) its result.
    execute :: PGconn -> String -> IO ()
    execute pConn sqlExpr = execChecked pConn sqlExpr >>= pqClear

    -- Runs a query and wraps its result in a Statement. The result stays
    -- alive until the statement is closed, at which point it is cleared.
    query :: Connection -> PGconn -> String -> IO Statement
    query conn pConn sqlExpr = do
      pRes <- execChecked pConn sqlExpr
      status <- pqResultStatus pRes
      -- Only a TUPLES_OK result carries column descriptions.
      defs <- if status == (#const PGRES_TUPLES_OK)
                then pgNFields pRes >>= getFieldDefs pRes
                else return []
      countTuples <- fmap fromIntegral (pqNTuples pRes)
      -- The cursor starts at -1 (before the first row); fetch advances it.
      tupleIndex <- newMVar (-1)
      refFalse <- newMVar False
      return (Statement
        { stmtConn   = conn
        , stmtClose  = pqClear pRes
        , stmtFetch  = fetch tupleIndex countTuples
        , stmtGetCol = getColValue pRes tupleIndex countTuples
        , stmtFields = defs
        , stmtClosed = refFalse
        })

    -- Describes result columns 0 .. n-1 as (name, type, True) triples; the
    -- nullable flag is always True for query results.
    getFieldDefs :: PGresult -> CInt -> IO [FieldDef]
    getFieldDefs pRes n = mapM describeField [0 .. n - 1]
      where
        describeField i = do
          name     <- pgFName pRes i >>= peekCString
          dataType <- pqFType pRes i
          modifier <- pqFMod pRes i
          return (name, mkSqlType dataType (fromIntegral modifier), True)

    -- Maps a PostgreSQL type OID plus its type modifier to an SqlType.
    -- The "- 4" presumably strips PostgreSQL's VARHDRSZ from the typmod (TODO
    -- confirm). For numeric, the adjusted modifier is split into precision
    -- (high 16 bits) and scale (low 16 bits).
    -- NOTE(review): a modifier of -1 (none declared) yields negative sizes.
    mkSqlType :: Oid -> Int -> SqlType
    mkSqlType (#const BPCHAROID)      size = SqlChar (size-4)
    mkSqlType (#const VARCHAROID)     size = SqlVarChar (size-4)
    mkSqlType (#const NAMEOID)        _    = SqlVarChar 31
    mkSqlType (#const TEXTOID)        _    = SqlText
    mkSqlType (#const NUMERICOID)     size = SqlNumeric ((size-4) `div` 0x10000) ((size-4) `mod` 0x10000)
    mkSqlType (#const INT2OID)        _    = SqlSmallInt
    mkSqlType (#const INT4OID)        _    = SqlInteger
    mkSqlType (#const FLOAT4OID)      _    = SqlReal
    mkSqlType (#const FLOAT8OID)      _    = SqlDouble
    mkSqlType (#const BOOLOID)        _    = SqlBit
    mkSqlType (#const BITOID)         size = SqlBinary size
    mkSqlType (#const VARBITOID)      size = SqlVarBinary size
    -- NOTE(review): bytea is mapped to SqlTinyInt, which looks wrong for
    -- binary data; kept because callers may depend on it — confirm.
    mkSqlType (#const BYTEAOID)       _    = SqlTinyInt
    mkSqlType (#const INT8OID)        _    = SqlBigInt
    mkSqlType (#const DATEOID)        _    = SqlDate
    mkSqlType (#const TIMEOID)        _    = SqlTime
    mkSqlType (#const TIMETZOID)      _    = SqlTimeTZ
    mkSqlType (#const ABSTIMEOID)     _    = SqlAbsTime
    mkSqlType (#const RELTIMEOID)     _    = SqlRelTime
    mkSqlType (#const INTERVALOID)    _    = SqlTimeInterval
    mkSqlType (#const TINTERVALOID)   _    = SqlAbsTimeInterval
    mkSqlType (#const TIMESTAMPOID)   _    = SqlDateTime
    mkSqlType (#const TIMESTAMPTZOID) _    = SqlDateTimeTZ
    mkSqlType (#const CASHOID)        _    = SqlMoney
    mkSqlType (#const INETOID)        _    = SqlINetAddr
    -- hack: 829 is hard-coded as the macaddr OID (presumably because no
    -- named constant was available from the included headers — confirm).
    mkSqlType (#const 829)            _    = SqlMacAddr
    mkSqlType (#const CIDROID)        _    = SqlCIDRAddr
    mkSqlType (#const POINTOID)       _    = SqlPoint
    mkSqlType (#const LSEGOID)        _    = SqlLSeg
    mkSqlType (#const PATHOID)        _    = SqlPath
    mkSqlType (#const BOXOID)         _    = SqlBox
    mkSqlType (#const POLYGONOID)     _    = SqlPolygon
    mkSqlType (#const LINEOID)        _    = SqlLine
    mkSqlType (#const CIRCLEOID)      _    = SqlCircle
    mkSqlType tp                      _    = SqlUnknown (fromIntegral tp)

    -- Reads column colNumber of the current row, returning the default v when
    -- the decoded value is Nothing (presumably SQL NULL).
    getFieldValue stmt colNumber fieldDef v = do
      mb_v <- stmtGetCol stmt colNumber fieldDef fromSqlCStringLen
      return (maybe v id mb_v)

    -- Lists ordinary relations (relkind 'r') whose names do not start with "pg_".
    tables :: Connection -> PGconn -> IO [String]
    tables connection pConn = do
      stmt <- query connection pConn "select relname from pg_class where relkind='r' and relname !~ '^pg_'"
      collectRows (\s -> getFieldValue s 0 ("relname", SqlVarChar 0, False) "") stmt

    -- Lists the non-dropped columns with attnum > 0 of the named table as
    -- (name, type, nullable) triples read from pg_attribute. The table name
    -- is embedded via toSqlValue.
    describe :: Connection -> PGconn -> String -> IO [FieldDef]
    describe connection pConn table = do
      stmt <- query connection pConn
                ("select attname, atttypid, atttypmod, attnotnull " ++
                 "from pg_attribute as cols join pg_class as ts on cols.attrelid=ts.oid " ++
                 "where cols.attnum > 0 and ts.relname="++toSqlValue table++
                 " and cols.attisdropped = False ")
      collectRows getColumnInfo stmt

    -- Decodes one row of the pg_attribute query issued by describe.
    getColumnInfo :: Statement -> IO FieldDef
    getColumnInfo stmt = do
      column_name <- getFieldValue stmt 0 ("attname", SqlVarChar 0, False) ""
      data_type   <- getFieldValue stmt 1 ("atttypid", SqlInteger, False) 0
      type_mod    <- getFieldValue stmt 2 ("atttypmod", SqlInteger, False) 0
      notnull     <- getFieldValue stmt 3 ("attnotnull", SqlBit, False) False
      let sqlType = mkSqlType (fromIntegral (data_type :: Int)) (fromIntegral (type_mod :: Int))
      return (column_name, sqlType, not notnull)

    -- Advances the row cursor. Returns True iff the cursor now points at an
    -- existing row (index < countTuples).
    fetch :: MVar Int -> Int -> IO Bool
    fetch tupleIndex countTuples =
      modifyMVar tupleIndex (\index -> return (index+1, index < countTuples-1))

    -- Hands column colNumber of the current row to the decoder f: a null
    -- pointer and length 0 for SQL NULL, otherwise the value's C string and
    -- its strlen. Throws SqlNoData unless the cursor is on a valid row, i.e.
    -- also before the first fetch (index -1), not only past the last row.
    getColValue :: PGresult -> MVar Int -> Int -> Int -> FieldDef -> (FieldDef -> CString -> Int -> IO a) -> IO a
    getColValue pRes tupleIndex countTuples colNumber fieldDef f = do
      index <- readMVar tupleIndex
      when (index < 0 || index >= countTuples) (throwDyn SqlNoData)
      let row = fromIntegral index
          col = fromIntegral colNumber
      isnull <- pqGetisnull pRes row col
      if isnull /= 0
        then f fieldDef nullPtr 0
        else do
          pStr   <- pqGetvalue pRes row col
          strLen <- strlen pStr
          f fieldDef pStr (fromIntegral strLen)

-- | Runs the action with a temporary C copy of the string, or with 'nullPtr'
-- for 'Nothing'. The copy is freed when the action finishes.
withCStringElseNullPtr :: Maybe String -> (CString -> IO a) -> IO a
withCStringElseNullPtr Nothing       action = action nullPtr
withCStringElseNullPtr (Just string) action = withCString string action

-- | Convert string by newCString, if provided, else return nullPtr.
-- A non-null result must be released with 'free' by the caller.
newCStringElseNullPtr :: Maybe String -> IO CString
newCStringElseNullPtr Nothing = return nullPtr
newCStringElseNullPtr (Just string) = newCString string