{-# LINE 1 "Database/HSQL/PostgreSQL.hsc" #-}
-----------------------------------------------------------------------------------------
{-# LINE 2 "Database/HSQL/PostgreSQL.hsc" #-}
{-| Module      :  Database.HSQL.PostgreSQL
    Copyright   :  (c) Krasimir Angelov 2003
    License     :  BSD-style

    Maintainer  :  ka2_mail@yahoo.com
    Stability   :  provisional
    Portability :  portable

    This module provides an interface to PostgreSQL databases.
-}
-----------------------------------------------------------------------------------------

module Database.HSQL.PostgreSQL(connect, 
                                connectWithOptions, 
                                module Database.HSQL) where

import Database.HSQL
import Database.HSQL.Types
import Data.Dynamic
import Data.Char
import Foreign
import Foreign.C
import Control.OldException (throwDyn, catchDyn, dynExceptions, Exception(..))
import Control.Monad(when,unless,mplus)
import Control.Concurrent.MVar
import System.Time
import System.IO.Unsafe
import Text.ParserCombinators.ReadP
import Text.Read
import Numeric


{-# LINE 34 "Database/HSQL/PostgreSQL.hsc" #-}

{-# LINE 35 "Database/HSQL/PostgreSQL.hsc" #-}

{-# LINE 36 "Database/HSQL/PostgreSQL.hsc" #-}

{-# LINE 37 "Database/HSQL/PostgreSQL.hsc" #-}

-- Opaque libpq handles.  In this module they are only handed back to libpq
-- functions or compared with nullPtr; they are never dereferenced from Haskell.
type PGconn   = Ptr ()
type PGresult = Ptr ()

-- Integral libpq types.  The widths appear to have been filled in by hsc2hs
-- (note the LINE pragmas pointing back at the .hsc source).
type ConnStatusType = Word32
type ExecStatusType = Word32
type Oid            = Word32
{-# LINE 43 "Database/HSQL/PostgreSQL.hsc" #-}

{- Bindings to libpq.  Refer to the PostgreSQL manual, chapter
   "libpq - C Library"
   (e.g. http://www.postgresql.org/docs/8.3/interactive/libpq.html).

   Every libpq function that takes or returns a C int (or size_t) is imported
   at its exact C type and wrapped below, so that the rest of the module keeps
   working with 'Int'.  Importing such functions directly at 'Int' is only
   correct where 'Int' and C int have the same width.  On 64-bit targets the
   upper half of the returned register is unspecified, so a negative result
   (e.g. PQfmod's -1 for "no type modifier", per the libpq docs) may come back
   as 4294967295.
-}

foreign import ccall "libpq-fe.h PQsetdbLogin" pqSetdbLogin :: CString -> CString -> CString -> CString -> CString -> CString -> CString -> IO PGconn
foreign import ccall "libpq-fe.h PQstatus" pqStatus :: PGconn -> IO ConnStatusType
foreign import ccall "libpq-fe.h PQerrorMessage" pqErrorMessage :: PGconn -> IO CString
foreign import ccall "libpq-fe.h PQfinish" pqFinish :: PGconn -> IO ()
foreign import ccall "libpq-fe.h PQexec" pqExec :: PGconn -> CString -> IO PGresult
foreign import ccall "libpq-fe.h PQresultStatus" pqResultStatus :: PGresult -> IO ExecStatusType
foreign import ccall "libpq-fe.h PQresStatus" pqResStatus :: ExecStatusType -> IO CString
foreign import ccall "libpq-fe.h PQresultErrorMessage" pqResultErrorMessage :: PGresult -> IO CString

-- Raw imports at the exact C types; use the Int-based wrappers below instead.
foreign import ccall "libpq-fe.h PQnfields"   c_PQnfields   :: PGresult -> IO CInt
foreign import ccall "libpq-fe.h PQntuples"   c_PQntuples   :: PGresult -> IO CInt
foreign import ccall "libpq-fe.h PQfname"     c_PQfname     :: PGresult -> CInt -> IO CString
foreign import ccall "libpq-fe.h PQftype"     c_PQftype     :: PGresult -> CInt -> IO Oid
foreign import ccall "libpq-fe.h PQfmod"      c_PQfmod      :: PGresult -> CInt -> IO CInt
foreign import ccall "libpq-fe.h PQfnumber"   c_PQfnumber   :: PGresult -> CString -> IO CInt
foreign import ccall "libpq-fe.h PQgetvalue"  c_PQgetvalue  :: PGresult -> CInt -> CInt -> IO CString
foreign import ccall "libpq-fe.h PQgetisnull" c_PQgetisnull :: PGresult -> CInt -> CInt -> IO CInt
foreign import ccall "string.h strlen"        c_strlen      :: CString -> IO CSize

-- | Number of columns in a result.
pgNFields :: PGresult -> IO Int
pgNFields res = fmap fromIntegral (c_PQnfields res)

-- | Number of rows in a result.
pqNTuples :: PGresult -> IO Int
pqNTuples res = fmap fromIntegral (c_PQntuples res)

-- | Name of the given (0-based) column.
pgFName :: PGresult -> Int -> IO CString
pgFName res col = c_PQfname res (fromIntegral col)

-- | Type OID of the given column.
pqFType :: PGresult -> Int -> IO Oid
pqFType res col = c_PQftype res (fromIntegral col)

-- | Type modifier of the given column, sign preserved.
pqFMod :: PGresult -> Int -> IO Int
pqFMod res col = fmap fromIntegral (c_PQfmod res (fromIntegral col))

-- | Column index for a column name, sign preserved.
pqFNumber :: PGresult -> CString -> IO Int
pqFNumber res name = fmap fromIntegral (c_PQfnumber res name)

-- | Pointer to the text of the field at (row, column).
pqGetvalue :: PGresult -> Int -> Int -> IO CString
pqGetvalue res row col = c_PQgetvalue res (fromIntegral row) (fromIntegral col)

-- | 1 when the field at (row, column) is SQL NULL.
pqGetisnull :: PGresult -> Int -> Int -> IO Int
pqGetisnull res row col =
    fmap fromIntegral (c_PQgetisnull res (fromIntegral row) (fromIntegral col))

-- | Length of a NUL-terminated C string.
strlen :: CString -> IO Int
strlen s = fmap fromIntegral (c_strlen s)

-----------------------------------------------------------------------------------------
-- Connect/Disconnect
-----------------------------------------------------------------------------------------

-- | Makes a new connection to the database server.
--
-- The server argument may carry a port suffix, as in @\"host:5432\"@:
-- everything after the first @:@ is used as the port.  When the suffix is
-- missing or empty, no port is passed on (the port is left unspecified).
connect :: String   -- ^ Server name : port nr
        -> String   -- ^ Database name
        -> String   -- ^ User identifier
        -> String   -- ^ Authentication string (password)
        -> IO Connection
connect server database user authentication =
    connectWithOptions host port Nothing Nothing database user authentication
  where
    (host, suffix) = break (== ':') server
    -- suffix is either "" or starts with ':'; require at least one port character
    port = case suffix of
        (_ : digits@(_ : _)) -> Just digits
        _                    -> Nothing


-- | Releases a PGresult (libpq's PQclear).  Used by 'connectWithOptions'.
foreign import ccall "libpq-fe.h PQclear" pqClear :: PGresult -> IO ()

-- | Address of PQclear, installed as the 'ForeignPtr' finalizer of query results.
foreign import ccall "libpq-fe.h &PQclear" pqClearFinalizer :: FunPtr (PGresult -> IO ())

-- | Makes a new connection to the database server, with specification of port, options & tty.
--
-- Absent optional arguments are passed to PQsetdbLogin as NULL.  Throws an
-- 'SqlError' with state @\"C\"@ when the connection cannot be established.
connectWithOptions :: String   -- ^ Server name
             -> Maybe String   -- ^ Port number
             -> Maybe String   -- ^ Options
             -> Maybe String   -- ^ TTY
             -> String   -- ^ Database name
             -> String   -- ^ User identifier
             -> String   -- ^ Authentication string (password)
             -> IO Connection
connectWithOptions server port options tty database user authentication = do
    pServer <- newCString server
    pPort <- newCStringElseNullPtr port
    pOptions <- newCStringElseNullPtr options
    pTty <- newCStringElseNullPtr tty
    pDatabase <- newCString database
    pUser <- newCString user
    pAuthentication <- newCString authentication
    pConn <- pqSetdbLogin pServer pPort pOptions pTty pDatabase pUser pAuthentication
    -- The buffers are not needed once PQsetdbLogin has returned.  Every one of
    -- them is released here, including the database name, which used to leak.
    -- (free on the nullPtr used for an absent optional value is a no-op.)
    mapM_ free [pServer, pPort, pOptions, pTty, pDatabase, pUser, pAuthentication]
    status <- pqStatus pConn
    unless (status == connectionOk) $ do
        errMsg <- pqErrorMessage pConn >>= peekCString
        pqFinish pConn
        throwDyn (SqlError {seState="C", seNativeError=fromIntegral status, seErrorMsg=errMsg})
    refFalse <- newMVar False
    let connection = Connection
            { connDisconnect = pqFinish pConn
            , connExecute = execute pConn
            , connQuery = query connection pConn
            , connTables = tables connection pConn
            , connDescribe = describe connection pConn
            , connBeginTransaction = execute pConn "begin"
            , connCommitTransaction = execute pConn "commit"
            , connRollbackTransaction = execute pConn "rollback"
            , connClosed = refFalse
            }
    return connection
  where
    -- libpq status codes compared against below (values of the libpq-fe.h enums).
    connectionOk :: ConnStatusType
    connectionOk = 0            -- CONNECTION_OK

    pgresCommandOk, pgresTuplesOk, pgresFatalError :: ExecStatusType
    pgresCommandOk  = 1         -- PGRES_COMMAND_OK
    pgresTuplesOk   = 2         -- PGRES_TUPLES_OK
    pgresFatalError = 7         -- PGRES_FATAL_ERROR

    -- The two result statuses accepted as success.
    succeeded :: ExecStatusType -> Bool
    succeeded status = status == pgresCommandOk || status == pgresTuplesOk

    -- Raised when PQexec returns no result at all; the message comes from the connection.
    execFailure :: PGconn -> IO a
    execFailure pConn = do
        errMsg <- pqErrorMessage pConn >>= peekCString
        throwDyn (SqlError {seState="E", seNativeError=fromIntegral pgresFatalError, seErrorMsg=errMsg})

    -- Runs a command for its effect only.  The result is cleared on both the
    -- success and the error path, so nothing is leaked.
    execute :: PGconn -> String -> IO ()
    execute pConn sqlExpr = do
        pRes <- withCString sqlExpr (pqExec pConn)
        when (pRes == nullPtr) (execFailure pConn)
        status <- pqResultStatus pRes
        if succeeded status
            then pqClear pRes
            else do
                -- peekCString copies the message, so the result can be cleared first
                errMsg <- pqResultErrorMessage pRes >>= peekCString
                pqClear pRes
                throwDyn (SqlError {seState="E", seNativeError=fromIntegral status, seErrorMsg=errMsg})

    -- Runs a query and exposes its result as a Statement.
    query :: Connection -> PGconn -> String -> IO Statement
    query conn pConn sqlExpr = do
        pRes <- withCString sqlExpr (pqExec pConn)
        when (pRes == nullPtr) (execFailure pConn)
        -- The Statement keeps reading the result after this function returns,
        -- so ownership goes to a ForeignPtr.  PQclear then runs exactly once:
        -- when the statement is closed, or at garbage collection otherwise.
        fpRes <- newForeignPtr pqClearFinalizer pRes
        status <- withForeignPtr fpRes pqResultStatus
        unless (succeeded status) $ do
            errMsg <- withForeignPtr fpRes (\p -> pqResultErrorMessage p >>= peekCString)
            finalizeForeignPtr fpRes
            throwDyn (SqlError {seState="E", seNativeError=fromIntegral status, seErrorMsg=errMsg})
        defs <- if status == pgresTuplesOk
                    then withForeignPtr fpRes (\p -> pgNFields p >>= getFieldDefs p)
                    else return []
        countTuples <- withForeignPtr fpRes pqNTuples
        tupleIndex <- newMVar (-1)
        refFalse <- newMVar False
        return Statement
            { stmtConn   = conn
            , stmtClose  = finalizeForeignPtr fpRes
            , stmtFetch  = fetch tupleIndex countTuples
            , stmtGetCol = getColValue fpRes tupleIndex countTuples
            , stmtFields = defs
            , stmtClosed = refFalse
            }

    -- Column definitions for columns 0 .. n-1.  Every column is reported as nullable.
    getFieldDefs :: PGresult -> Int -> IO [FieldDef]
    getFieldDefs pRes n = mapM fieldDef [0 .. n - 1]
      where
        fieldDef i = do
            name <- pgFName pRes i >>= peekCString
            dataType <- pqFType pRes i
            modifier <- pqFMod pRes i
            return (name, mkSqlType dataType modifier, True)

    -- Maps a PostgreSQL type OID and its type modifier to an SqlType.
    -- For char/varchar/numeric the modifier is offset by 4, hence the (size-4).
    mkSqlType :: Oid -> Int -> SqlType
    mkSqlType 1042 size = SqlChar (size-4)                    -- bpchar
    mkSqlType 1043 size = SqlVarChar (size-4)                 -- varchar
    mkSqlType 19   _    = SqlVarChar 31                       -- name
    mkSqlType 25   _    = SqlText                             -- text
    mkSqlType 1700 size = SqlNumeric ((size-4) `div` 0x10000) ((size-4) `mod` 0x10000) -- numeric: precision high 16 bits, scale low 16 bits
    mkSqlType 21   _    = SqlSmallInt                         -- int2
    mkSqlType 23   _    = SqlInteger                          -- int4
    mkSqlType 700  _    = SqlReal                             -- float4
    mkSqlType 701  _    = SqlDouble                           -- float8
    mkSqlType 16   _    = SqlBit                              -- bool
    mkSqlType 1560 size = SqlBinary size                      -- bit
    mkSqlType 1562 size = SqlVarBinary size                   -- varbit
    mkSqlType 17   _    = SqlTinyInt                          -- NOTE(review): OID 17 is bytea; mapping kept as-is, confirm intent
    mkSqlType 20   _    = SqlBigInt                           -- int8
    mkSqlType 1082 _    = SqlDate                             -- date
    mkSqlType 1083 _    = SqlTime                             -- time
    mkSqlType 1266 _    = SqlTimeTZ                           -- timetz
    mkSqlType 702  _    = SqlAbsTime                          -- abstime
    mkSqlType 703  _    = SqlRelTime                          -- reltime
    mkSqlType 1186 _    = SqlTimeInterval                     -- interval
    mkSqlType 704  _    = SqlAbsTimeInterval                  -- tinterval
    mkSqlType 1114 _    = SqlDateTime                         -- timestamp
    mkSqlType 1184 _    = SqlDateTimeTZ                       -- timestamptz
    mkSqlType 790  _    = SqlMoney                            -- money
    mkSqlType 869  _    = SqlINetAddr                         -- inet
    mkSqlType 829  _    = SqlMacAddr                          -- macaddr (hack: literal OID)
    mkSqlType 650  _    = SqlCIDRAddr                         -- cidr
    mkSqlType 600  _    = SqlPoint                            -- point
    mkSqlType 601  _    = SqlLSeg                             -- lseg
    mkSqlType 602  _    = SqlPath                             -- path
    mkSqlType 603  _    = SqlBox                              -- box
    mkSqlType 604  _    = SqlPolygon                          -- polygon
    mkSqlType 628  _    = SqlLine                             -- line
    mkSqlType 718  _    = SqlCircle                           -- circle
    mkSqlType tp   _    = SqlUnknown (fromIntegral tp)

    -- Reads a column of the current row, substituting v for SQL NULL.
    getFieldValue stmt colNumber fieldDef v = do
        mb_v <- stmtGetCol stmt colNumber fieldDef fromSqlCStringLen
        return (case mb_v of { Nothing -> v; Just a -> a })

    -- Names of the user tables (catalog tables prefixed pg_ are excluded).
    tables :: Connection -> PGconn -> IO [String]
    tables connection pConn = do
        stmt <- query connection pConn "select relname from pg_class where relkind='r' and relname !~ '^pg_'"
        collectRows (\s -> getFieldValue s 0 ("relname", SqlVarChar 0, False) "") stmt

    -- Column definitions of a table, read from pg_attribute (dropped columns excluded).
    describe :: Connection -> PGconn -> String -> IO [FieldDef]
    describe connection pConn table = do
        stmt <- query connection pConn
                ("select attname, atttypid, atttypmod, attnotnull " ++
                 "from pg_attribute as cols join pg_class as ts on cols.attrelid=ts.oid " ++
                 "where cols.attnum > 0 and ts.relname=" ++ toSqlValue table ++
                 " and cols.attisdropped = False ")
        collectRows getColumnInfo stmt
      where
        getColumnInfo stmt = do
            column_name <- getFieldValue stmt 0 ("attname", SqlVarChar 0, False) ""
            data_type <- getFieldValue stmt 1 ("atttypid", SqlInteger, False) 0
            type_mod <- getFieldValue stmt 2 ("atttypmod", SqlInteger, False) 0
            notnull <- getFieldValue stmt 3 ("attnotnull", SqlBit, False) False
            let sqlType = mkSqlType (fromIntegral (data_type :: Int)) (fromIntegral (type_mod :: Int))
            return (column_name, sqlType, not notnull)

    -- Advances the row cursor.  Returns True while the new position is a valid row.
    fetch :: MVar Int -> Int -> IO Bool
    fetch tupleIndex countTuples =
        modifyMVar tupleIndex (\index -> return (index + 1, index < countTuples - 1))

    -- Hands the current row's field (or nullPtr/0 for SQL NULL) to the converter f.
    -- f runs inside withForeignPtr because the field text lives in the result.
    getColValue :: ForeignPtr () -> MVar Int -> Int -> Int -> FieldDef -> (FieldDef -> CString -> Int -> IO a) -> IO a
    getColValue fpRes tupleIndex countTuples colNumber fieldDef f = do
        index <- readMVar tupleIndex
        when (index >= countTuples) (throwDyn SqlNoData)
        withForeignPtr fpRes $ \pRes -> do
            isnull <- pqGetisnull pRes index colNumber
            if isnull == 1
                then f fieldDef nullPtr 0
                else do
                    pStr <- pqGetvalue pRes index colNumber
                    strLen <- strlen pStr
                    f fieldDef pStr strLen


-- | Marshals an optional string for C: 'Nothing' becomes 'nullPtr'.  'Just' s
-- becomes a freshly allocated NUL-terminated copy (made with 'newCString'),
-- which the caller must release with 'free'.
newCStringElseNullPtr :: Maybe String -> IO CString
newCStringElseNullPtr = maybe (return nullPtr) newCString