module Database.HSQL.PostgreSQL(connect,
connectWithOptions,
module Database.HSQL) where
import Database.HSQL
import Database.HSQL.Types
import Data.Dynamic
import Data.Char
import Foreign
import Foreign.C
import Control.OldException (throwDyn, catchDyn, dynExceptions, Exception(..))
import Control.Monad(when,unless,mplus)
import Control.Concurrent.MVar
import System.Time
import System.IO.Unsafe
import Text.ParserCombinators.ReadP
import Text.Read
import Numeric
-- Raw libpq bindings used by this backend.
--
-- Connection and result handles are treated as untyped pointers; the
-- status codes and type OIDs are carried as plain 32-bit words.
type PGconn = Ptr ()
type PGresult = Ptr ()
type ConnStatusType = Word32
type ExecStatusType = Word32
type Oid = Word32
-- NOTE(review): several bindings below use Haskell 'Int' where the C
-- prototypes in libpq-fe.h use C @int@ (presumably 'CInt' would be the
-- exact match) -- confirm on platforms where the two differ in width.

-- Connection management.
foreign import ccall "libpq-fe.h PQsetdbLogin" pqSetdbLogin :: CString -> CString -> CString -> CString -> CString -> CString -> CString -> IO PGconn
foreign import ccall "libpq-fe.h PQstatus" pqStatus :: PGconn -> IO ConnStatusType
foreign import ccall "libpq-fe.h PQerrorMessage" pqErrorMessage :: PGconn -> IO CString
foreign import ccall "libpq-fe.h PQfinish" pqFinish :: PGconn -> IO ()
-- Statement execution and result status.
foreign import ccall "libpq-fe.h PQexec" pqExec :: PGconn -> CString -> IO PGresult
foreign import ccall "libpq-fe.h PQresultStatus" pqResultStatus :: PGresult -> IO ExecStatusType
foreign import ccall "libpq-fe.h PQresStatus" pqResStatus :: ExecStatusType -> IO CString
foreign import ccall "libpq-fe.h PQresultErrorMessage" pqResultErrorMessage :: PGresult -> IO CString
-- Result-set shape and column metadata.
foreign import ccall "libpq-fe.h PQnfields" pgNFields :: PGresult -> IO Int
foreign import ccall "libpq-fe.h PQntuples" pqNTuples :: PGresult -> IO Int
foreign import ccall "libpq-fe.h PQfname" pgFName :: PGresult -> Int -> IO CString
foreign import ccall "libpq-fe.h PQftype" pqFType :: PGresult -> Int -> IO Oid
foreign import ccall "libpq-fe.h PQfmod" pqFMod :: PGresult -> Int -> IO Int
foreign import ccall "libpq-fe.h PQfnumber" pqFNumber :: PGresult -> CString -> IO Int
-- Cell access; both take the result, then a row index, then a column index
-- (that is the order 'getColValue' passes them in).
foreign import ccall "libpq-fe.h PQgetvalue" pqGetvalue :: PGresult -> Int -> Int -> IO CString
foreign import ccall "libpq-fe.h PQgetisnull" pqGetisnull :: PGresult -> Int -> Int -> IO Int
-- C library strlen, used to measure the NUL-terminated text returned by PQgetvalue.
foreign import ccall "strlen" strlen :: CString -> IO Int
-- | Open a connection, taking the server as either @host@ or @host:port@.
--
-- The server string is split at its first @':'@. A non-empty remainder after
-- the colon is forwarded as the port; otherwise no port is passed. Options
-- and tty are always left unset. Everything else is handed unchanged to
-- 'connectWithOptions'.
connect :: String  -- ^ server, optionally followed by @:port@
        -> String  -- ^ database name
        -> String  -- ^ user name
        -> String  -- ^ authentication string, passed through as-is
        -> IO Connection
connect server database user authentication =
  connectWithOptions host portSpec Nothing Nothing database user authentication
  where
    -- afterHost is either empty or begins with the ':' separator.
    (host, afterHost) = break (== ':') server
    -- Only a separator followed by at least one character counts as a port.
    portSpec = case afterHost of
      (_ : rest@(_ : _)) -> Just rest
      _                  -> Nothing
-- | Open a PostgreSQL connection through libpq's @PQsetdbLogin@ and wrap it
-- in an HSQL 'Connection'.
--
-- Optional arguments given as 'Nothing' are passed to libpq as null pointers.
-- If the resulting connection status is not 0 (libpq @CONNECTION_OK@) the
-- handle is closed and an 'SqlError' with state @"C"@ is thrown via
-- 'throwDyn'.
--
-- NOTE(review): no code visible in this module ever releases a 'PGresult'
-- (the statement's close action is a no-op and there is no @PQclear@
-- binding), so every executed statement keeps its result allocated.
connectWithOptions :: String        -- ^ server host
                   -> Maybe String  -- ^ port
                   -> Maybe String  -- ^ backend options
                   -> Maybe String  -- ^ tty
                   -> String        -- ^ database name
                   -> String        -- ^ user name
                   -> String        -- ^ authentication string (last PQsetdbLogin argument)
                   -> IO Connection
connectWithOptions server port options tty database user authentication = do
  pServer <- newCString server
  pPort <- newCStringElseNullPtr port
  pOptions <- newCStringElseNullPtr options
  pTty <- newCStringElseNullPtr tty
  pDatabase <- newCString database
  pUser <- newCString user
  pAuthentication <- newCString authentication
  pConn <- pqSetdbLogin pServer pPort pOptions pTty pDatabase pUser pAuthentication
  -- Release every argument buffer once the call has returned, including the
  -- database name; freeing a null pointer (absent optional argument) is a no-op.
  mapM_ free [pServer, pPort, pOptions, pTty, pDatabase, pUser, pAuthentication]
  status <- pqStatus pConn
  unless (status == 0) $ do
    errMsg <- pqErrorMessage pConn >>= peekCString
    pqFinish pConn
    throwDyn (SqlError {seState = "C", seNativeError = fromIntegral status, seErrorMsg = errMsg})
  refFalse <- newMVar False
  let connection = Connection
        { connDisconnect = pqFinish pConn
        , connExecute = execute pConn
        , connQuery = query connection pConn
        , connTables = tables connection pConn
        , connDescribe = describe connection pConn
        , connBeginTransaction = execute pConn "begin"
        , connCommitTransaction = execute pConn "commit"
        , connRollbackTransaction = execute pConn "rollback"
        , connClosed = refFalse
        }
  return connection
  where
    -- Send one SQL string to the server and return the result handle together
    -- with its status.  Throws 'SqlError' (state "E") when libpq returns no
    -- result at all, or when the status is neither 1 (PGRES_COMMAND_OK) nor
    -- 2 (PGRES_TUPLES_OK).  7 is libpq's PGRES_FATAL_ERROR.
    runStatement :: PGconn -> String -> IO (PGresult, ExecStatusType)
    runStatement pConn sqlExpr = do
      pRes <- withCString sqlExpr (pqExec pConn)
      when (pRes == nullPtr) $ do
        errMsg <- pqErrorMessage pConn >>= peekCString
        throwDyn (SqlError {seState = "E", seNativeError = 7, seErrorMsg = errMsg})
      status <- pqResultStatus pRes
      unless (status == 1 || status == 2) $ do
        errMsg <- pqResultErrorMessage pRes >>= peekCString
        throwDyn (SqlError {seState = "E", seNativeError = fromIntegral status, seErrorMsg = errMsg})
      return (pRes, status)

    -- Run a statement purely for its effect; the result is discarded.
    execute :: PGconn -> String -> IO ()
    execute pConn sqlExpr = runStatement pConn sqlExpr >> return ()

    -- Run a statement and expose its rows as an HSQL 'Statement'.
    query :: Connection -> PGconn -> String -> IO Statement
    query conn pConn sqlExpr = do
      (pRes, status) <- runStatement pConn sqlExpr
      -- Column metadata is only collected when the status says rows came back.
      defs <- if status == 2
                then pgNFields pRes >>= getFieldDefs pRes
                else return []
      countTuples <- pqNTuples pRes
      -- The cursor starts one step before row 0; the first successful
      -- 'fetch' advances it onto the first row.
      tupleIndex <- newMVar (-1)
      refFalse <- newMVar False
      return (Statement
        { stmtConn = conn
        , stmtClose = return ()
        , stmtFetch = fetch tupleIndex countTuples
        , stmtGetCol = getColValue pRes tupleIndex countTuples
        , stmtFields = defs
        , stmtClosed = refFalse
        })
      where
        -- Describe columns 0 .. n-1 of the result, in order.
        getFieldDefs res n = mapM (fieldDefAt res) [0 .. n - 1]
        fieldDefAt res i = do
          name <- pgFName res i >>= peekCString
          dataType <- pqFType res i
          modifier <- pqFMod res i
          return (name, mkSqlType dataType modifier, True)

    -- Translate a PostgreSQL type OID plus type modifier into an 'SqlType'.
    -- For the character and numeric types the modifier carries a 4-byte
    -- header offset, which is removed before the length / precision / scale
    -- is read.  OIDs not listed here become 'SqlUnknown'.
    mkSqlType :: Oid -> Int -> SqlType
    mkSqlType 1042 size = SqlChar (size - 4)        -- bpchar
    mkSqlType 1043 size = SqlVarChar (size - 4)     -- varchar
    mkSqlType 19   _    = SqlVarChar 31             -- name
    mkSqlType 25   _    = SqlText                   -- text
    -- numeric: precision in the high 16 bits, scale in the low 16 bits
    mkSqlType 1700 size = SqlNumeric ((size - 4) `div` 0x10000) ((size - 4) `mod` 0x10000)
    mkSqlType 21   _    = SqlSmallInt               -- int2
    mkSqlType 23   _    = SqlInteger                -- int4
    mkSqlType 700  _    = SqlReal                   -- float4
    mkSqlType 701  _    = SqlDouble                 -- float8
    mkSqlType 16   _    = SqlBit                    -- bool
    mkSqlType 1560 size = SqlBinary size            -- bit
    mkSqlType 1562 size = SqlVarBinary size         -- varbit
    mkSqlType 17   _    = SqlTinyInt                -- bytea
    mkSqlType 20   _    = SqlBigInt                 -- int8
    mkSqlType 1082 _    = SqlDate                   -- date
    mkSqlType 1083 _    = SqlTime                   -- time
    mkSqlType 1266 _    = SqlTimeTZ                 -- timetz
    mkSqlType 702  _    = SqlAbsTime                -- abstime
    mkSqlType 703  _    = SqlRelTime                -- reltime
    mkSqlType 1186 _    = SqlTimeInterval           -- interval
    mkSqlType 704  _    = SqlAbsTimeInterval        -- tinterval
    mkSqlType 1114 _    = SqlDateTime               -- timestamp
    mkSqlType 1184 _    = SqlDateTimeTZ             -- timestamptz
    mkSqlType 790  _    = SqlMoney                  -- money
    mkSqlType 869  _    = SqlINetAddr               -- inet
    mkSqlType 829  _    = SqlMacAddr                -- macaddr
    mkSqlType 650  _    = SqlCIDRAddr               -- cidr
    mkSqlType 600  _    = SqlPoint                  -- point
    mkSqlType 601  _    = SqlLSeg                   -- lseg
    mkSqlType 602  _    = SqlPath                   -- path
    mkSqlType 603  _    = SqlBox                    -- box
    mkSqlType 604  _    = SqlPolygon                -- polygon
    mkSqlType 628  _    = SqlLine                   -- line
    mkSqlType 718  _    = SqlCircle                 -- circle
    mkSqlType tp   _    = SqlUnknown (fromIntegral tp)

    -- Read one column of the current row, substituting the given default
    -- when the column decodes to Nothing.
    getFieldValue stmt colNumber fieldDef v = do
      mb_v <- stmtGetCol stmt colNumber fieldDef fromSqlCStringLen
      return (maybe v id mb_v)

    -- List ordinary tables (relkind 'r') whose names do not start with "pg_".
    tables :: Connection -> PGconn -> IO [String]
    tables connection pConn = do
      stmt <- query connection pConn "select relname from pg_class where relkind='r' and relname !~ '^pg_'"
      collectRows (\s -> getFieldValue s 0 ("relname", SqlVarChar 0, False) "") stmt

    -- Describe the live (non-dropped, attnum > 0) columns of a table from the
    -- system catalogs.  The table name is rendered with 'toSqlValue' before
    -- being spliced into the query text.
    describe :: Connection -> PGconn -> String -> IO [FieldDef]
    describe connection pConn table = do
      stmt <- query connection pConn
                ("select attname, atttypid, atttypmod, attnotnull " ++
                 "from pg_attribute as cols join pg_class as ts on cols.attrelid=ts.oid " ++
                 "where cols.attnum > 0 and ts.relname="++toSqlValue table++
                 " and cols.attisdropped = False ")
      collectRows getColumnInfo stmt
      where
        getColumnInfo stmt = do
          column_name <- getFieldValue stmt 0 ("attname", SqlVarChar 0, False) ""
          data_type <- getFieldValue stmt 1 ("atttypid", SqlInteger, False) 0
          type_mod <- getFieldValue stmt 2 ("atttypmod", SqlInteger, False) 0
          notnull <- getFieldValue stmt 3 ("attnotnull", SqlBit, False) False
          let sqlType = mkSqlType (fromIntegral (data_type :: Int)) (fromIntegral (type_mod :: Int))
          -- The third FieldDef component is "nullable", the inverse of attnotnull.
          return (column_name, sqlType, not notnull)

    -- Advance the cursor by one row.  Returns True when the new position is
    -- a valid row, i.e. the old position was before the last row.
    fetch :: MVar Int -> Int -> IO Bool
    fetch tupleIndex countTuples =
      modifyMVar tupleIndex (\index -> return (index + 1, index < countTuples - 1))

    -- Hand the text of one cell of the current row to a decoding callback.
    -- A SQL NULL is passed as a null pointer with length 0.  Throws
    -- 'SqlNoData' when the cursor has moved past the last row.
    getColValue :: PGresult -> MVar Int -> Int -> Int -> FieldDef -> (FieldDef -> CString -> Int -> IO a) -> IO a
    getColValue pRes tupleIndex countTuples colNumber fieldDef f = do
      index <- readMVar tupleIndex
      when (index >= countTuples) (throwDyn SqlNoData)
      isnull <- pqGetisnull pRes index colNumber
      if isnull == 1
        then f fieldDef nullPtr 0
        else do
          pStr <- pqGetvalue pRes index colNumber
          strLen <- strlen pStr
          f fieldDef pStr strLen
-- | Marshal an optional string for a C call: 'Nothing' becomes a null
-- pointer, @Just s@ becomes a freshly allocated C string made with
-- 'newCString' (which the caller is responsible for freeing).
newCStringElseNullPtr :: Maybe String -> IO CString
newCStringElseNullPtr = maybe (return nullPtr) newCString