{-| Module      :  Database.HSQL.PostgreSQL
    Copyright   :  (c) Krasimir Angelov 2003
    License     :  BSD-style

    Maintainer  :  ka2_mail@yahoo.com
    Stability   :  provisional
    Portability :  portable

    This module provides an interface to PostgreSQL databases.
-}
module Database.HSQL.PostgreSQL(connect, 
                                connectWithOptions, 
                                module Database.HSQL) where

import Foreign(free)
import Foreign.C(newCString,peekCString)
import Control.Exception (throw, throwIO)
import Control.Monad(unless)
import Control.Concurrent.MVar(newMVar)

import Database.HSQL hiding(query,execute)
import Database.HSQL.Types(Connection(..),Statement(stmtGetCol))
import DB.HSQL.PG.Functions
import DB.HSQL.PG.Type(mkSqlType)
import DB.HSQL.PG.Core(query,execute,newCStringElseNullPtr)
import DB.HSQL.PG.Status(connectionOk)
import DB.HSQL.PG.Sql(sqlAllTableNames,sqlAllFieldDefsForTableName)

------------------------------------------------------------------------------
-- Connect/Disconnect
------------------------------------------------------------------------------
-- | Makes a new connection to the database server.
--
-- The server argument has the form @host@ or @host:port@.  Everything up to
-- the first colon is used as the host.  If at least one character follows
-- that colon, the remainder is passed on as the port; otherwise no port is
-- given.  No connection options and no TTY are specified.
connect :: String   -- ^ Server name : port nr
        -> String   -- ^ Database name
        -> String   -- ^ User identifier
        -> String   -- ^ Authentication string (password)
        -> IO Connection
connect server database user authentication =
  connectWithOptions host mbPort Nothing Nothing database user authentication
  where
    -- 'rest' is either empty or starts with the separating ':'.
    (host, rest) = break (== ':') server
    -- Only a non-empty text after the colon counts as a port.
    mbPort = case rest of
      _ : portText@(_ : _) -> Just portText
      _                    -> Nothing


-- | Makes a new connection to the database server,
-- with specification of port, options & tty.
--
-- Every argument is marshalled to a C string for the 'pqSetdbLogin' call.
-- All of those strings are freed right after the call returns.  If
-- 'pqStatus' does not report 'connectionOk', the server's error message is
-- read, the handle is closed with 'pqFinish', and an 'SqlError' is thrown.
-- That error has state @"C"@, the numeric status as its native error, and
-- the message text.
connectWithOptions :: String         -- ^ Server name
                   -> Maybe String   -- ^ Port number
                   -> Maybe String   -- ^ Options
                   -> Maybe String   -- ^ TTY
                   -> String         -- ^ Database name
                   -> String         -- ^ User identifier
                   -> String         -- ^ Authentication string (password)
                   -> IO Connection
connectWithOptions server port options tty database user authentication = do
  pServer <- newCString server
  -- Optional arguments are marshalled by 'newCStringElseNullPtr'
  -- (presumably NULL for 'Nothing'; 'free' on NULL is a no-op).
  pPort <- newCStringElseNullPtr port
  pOptions <- newCStringElseNullPtr options
  pTty <- newCStringElseNullPtr tty
  pDatabase <- newCString database
  pUser <- newCString user
  pAuthentication <- newCString authentication
  pConn <- pqSetdbLogin pServer pPort pOptions pTty
                        pDatabase pUser pAuthentication
  -- Release every marshalled argument; pDatabase must be freed too,
  -- otherwise each connection attempt leaks the database-name string.
  free pServer
  free pPort
  free pOptions
  free pTty
  free pDatabase
  free pUser
  free pAuthentication
  status <- pqStatus pConn
  unless (status == connectionOk) $ do
    -- Copy the error text out before the handle is finished.
    errMsg <- pqErrorMessage pConn >>= peekCString
    pqFinish pConn
    throwIO (SqlError { seState = "C"
                      , seNativeError = fromIntegral status
                      , seErrorMsg = errMsg })
  refFalse <- newMVar False
  -- 'connection' is defined recursively: the query/tables/describe
  -- operations are handed the finished record itself.
  let connection =
        Connection { connDisconnect          = pqFinish pConn
                   , connExecute             = execute pConn
                   , connQuery               = query connection pConn
                   , connTables              = tables connection pConn
                   , connDescribe            = describe connection pConn
                   , connBeginTransaction    = execute pConn "BEGIN"
                   , connCommitTransaction   = execute pConn "COMMIT"
                   , connRollbackTransaction = execute pConn "ROLLBACK"
                   , connClosed              = refFalse }
  return connection
  where
    -- Reads column 'colNumber' of the current row of 'stmt', falling back
    -- to 'v' when 'stmtGetCol' yields 'Nothing' (presumably SQL NULL).
    getFieldValue stmt colNumber fieldDef v =
      maybe v id <$> stmtGetCol stmt colNumber fieldDef fromSqlCStringLen

    -- Lists table names: runs 'sqlAllTableNames' and collects the
    -- @relname@ column of every result row.
    tables :: Connection -> PGconn -> IO [String]
    tables connection pConn = do
      stmt <- query connection pConn sqlAllTableNames
      collectRows (\s ->
          getFieldValue s 0 ("relname", SqlVarChar 0, False) "") stmt

    -- Describes the columns of 'table' by running
    -- 'sqlAllFieldDefsForTableName' and converting each row.
    describe :: Connection -> PGconn -> String -> IO [FieldDef]
    describe connection pConn table = do
      stmt <- query connection pConn (sqlAllFieldDefsForTableName table)
      collectRows getColumnInfo stmt

    -- Converts one catalogue row into a 'FieldDef'.  The name comes from
    -- @attname@ and the SQL type from 'mkSqlType' applied to
    -- @atttypid@/@atttypmod@.  The column is nullable exactly when
    -- @attnotnull@ is false.
    getColumnInfo :: Statement -> IO FieldDef
    getColumnInfo stmt = do
      column_name <-
          getFieldValue stmt 0 ("attname", SqlVarChar 0, False) ""
      data_type <- getFieldValue stmt 1 ("atttypid", SqlInteger, False) 0
      type_mod <- getFieldValue stmt 2 ("atttypmod", SqlInteger, False) 0
      notnull <- getFieldValue stmt 3 ("attnotnull", SqlBit, False) False
      let sqlType = mkSqlType (fromIntegral (data_type :: Int))
                              (fromIntegral (type_mod :: Int))
      return (column_name, sqlType, not notnull)