{-# LANGUAGE ScopedTypeVariables #-}
{-| Module      :  Database.HSQL.SQLite3
    Copyright   :  (c) Krasimir Angelov 2005
    License     :  BSD-style

    Maintainer  :  kr.angelov@gmail.com
    Stability   :  provisional
    Portability :  portable

    This module provides an HSQL interface to SQLite3 databases.
-}
module Database.HSQL.SQLite3(connect
                            ,module Database.HSQL) where

import Control.Concurrent.MVar(MVar,newMVar,modifyMVar,readMVar)
import Control.Exception(throw,throwIO)
import Control.Monad(when)
import Foreign(Ptr,alloca,peek,nullPtr,peekElemOff,nullFunPtr)
import Foreign.C(CString,CInt,peekCString,withCString)
import System.IO(IOMode(..))

import Database.HSQL
import Database.HSQL.Types(Connection(..),Statement(..))

import DB.HSQL.SQLite3.Functions

------------------------------------------------------------------------------
-- Connect
------------------------------------------------------------------------------
-- | Open the SQLite3 database stored at @fpath@ and wrap it in an HSQL
-- 'Connection'.
--
-- The 'IOMode' argument is accepted for interface compatibility but is
-- ignored: the file is opened with @sqlite3_open@, which takes no mode
-- flags. Transactions are implemented by executing plain
-- @BEGIN@/@COMMIT@/@ROLLBACK TRANSACTION@ statements.
--
-- Throws 'SqlError' (state @\"C\"@, native error 0) carrying SQLite's
-- message if the database cannot be opened.
connect :: FilePath -> IOMode -> IO Connection
connect fpath _mode =
  alloca $ \psqlite ->
    withCString fpath $ \pFPath -> do
      res <- sqlite3_open pFPath psqlite
      sqlite <- peek psqlite
      when (res /= sqliteOk) $ do
        pMsg <- sqlite3_errmsg sqlite
        msg <- peekCString pMsg
        -- Per the SQLite docs, sqlite3_open hands back a handle even when it
        -- fails, and that handle must still be released with sqlite3_close.
        -- Copy the message first: it is owned by the handle.
        sqlite3_close sqlite
        throwIO SqlError { seState = "C"
                         , seNativeError = 0
                         , seErrorMsg = msg }
      refFalse <- newMVar False
      -- 'connection' refers to itself so queries can record their owner.
      let connection =
            Connection
              { connDisconnect          = sqlite3_close sqlite
              , connClosed              = refFalse
              , connExecute             = sqlite3Execute sqlite
              , connQuery               = sqlite3Query connection sqlite
              , connTables              = sqlite3Tables connection sqlite
              , connDescribe            = sqlite3Describe connection sqlite
              , connBeginTransaction    =
                  sqlite3Execute sqlite "BEGIN TRANSACTION"
              , connCommitTransaction   =
                  sqlite3Execute sqlite "COMMIT TRANSACTION"
              , connRollbackTransaction =
                  sqlite3Execute sqlite "ROLLBACK TRANSACTION"
              }
      return connection

-- | List the names of the tables (and views) defined in the database.
--
-- @sqlite_master@ has one row per schema object. Indices and triggers also
-- carry the name of the table they belong to in @tbl_name@, so without
-- @DISTINCT@ a table would be reported once more for each index or trigger
-- attached to it.
sqlite3Tables :: Connection -> SQLite3 -> IO [String]
sqlite3Tables connection sqlite = do
  stmt <- sqlite3Query connection sqlite
            "SELECT DISTINCT tbl_name FROM sqlite_master"
  collectRows (\row -> getFieldValue row "tbl_name") stmt

-- | Describe the columns of @table@ by running SQLite's @table_info@ pragma
-- and converting each result row with 'getRow'.
--
-- The table name is spliced into the pragma text verbatim (no quoting or
-- escaping is applied).
sqlite3Describe :: Connection -> SQLite3 -> String -> IO [FieldDef]
sqlite3Describe connection sqlite table =
  sqlite3Query connection sqlite pragmaText >>= collectRows getRow
  where
    pragmaText = "pragma table_info(" ++ table ++ ")"

-- | Run @query@ with @sqlite3_get_table@ and wrap the fully materialised
-- result in an HSQL 'Statement'.
--
-- The result array holds the @columns@ column names first, followed by the
-- cell values row by row. 'getFieldDefs' reads the header, and 'getColValue'
-- indexes the data using the row cursor in @refIndex@, which 'sqlite3Fetch'
-- advances. Closing the statement frees the array with
-- @sqlite3_free_table@.
--
-- Throws 'SqlError' (via 'handleSqlResult') if SQLite reports an error.
--
-- The nested lambdas are indented strictly deeper at each level so the
-- block parses without the NondecreasingIndentation extension.
sqlite3Query :: Connection -> SQLite3 -> String -> IO Statement
sqlite3Query connection sqlite query =
  withCString query $ \pQuery ->
    alloca $ \ppResult ->
      alloca $ \pnRow ->
        alloca $ \pnColumn ->
          alloca $ \ppMsg -> do
            res <- sqlite3_get_table sqlite pQuery ppResult pnRow pnColumn ppMsg
            handleSqlResult res ppMsg
            pResult <- peek ppResult
            rows    <- fromIntegral <$> peek pnRow
            columns <- fromIntegral <$> peek pnColumn
            defs <- getFieldDefs pResult 0 columns
            refFalse <- newMVar False
            -- Row cursor: 0 means "before the first row".
            refIndex <- newMVar 0
            return Statement { stmtConn   = connection
                             , stmtClose  = sqlite3_free_table pResult
                             , stmtFetch  = sqlite3Fetch refIndex rows
                             , stmtGetCol = getColValue pResult refIndex columns rows
                             , stmtFields = defs
                             , stmtClosed = refFalse }

-- | Convert one row of @pragma table_info@ output into a 'FieldDef'.
--
-- Reads the @name@ and @notnull@ columns. Every column is reported as
-- 'SqlText', and it is marked nullable exactly when @notnull@ is @\"0\"@.
getRow :: Statement -> IO FieldDef
getRow stmt = do
  name <- getFieldValue stmt "name"
  notnull <- getFieldValue stmt "notnull"
  return (name, SqlText, notnull == "0")

-- | Build field definitions from the column-name header of a
-- @sqlite3_get_table@ result.
--
-- Reads the C strings stored at positions @index@ up to (but excluding)
-- @count@ of @pResult@, in order. Each becomes a nullable 'SqlText' field.
-- Yields @[]@ when @index >= count@.
getFieldDefs :: Ptr CString -> Int -> Int -> IO [FieldDef]
getFieldDefs pResult index count = mapM columnDef [index .. count - 1]
  where
    columnDef pos = do
      colName <- peekCString =<< peekElemOff pResult pos
      return (colName, SqlText, True)

-- | Advance a statement's row cursor.
--
-- The cursor starts at 0, meaning no current row. Each call moves it forward
-- by one. It returns 'True' while the row it moved onto is one of the
-- @countTuples@ result rows (rows are numbered from 1), and 'False' after
-- that. The new cursor value is forced before it is stored, so repeated
-- fetches do not build a chain of thunks inside the 'MVar'.
sqlite3Fetch :: MVar Int -> Int -> IO Bool
sqlite3Fetch tupleIndex countTuples =
  modifyMVar tupleIndex $ \index ->
    let next = index + 1
    in next `seq` return (next, index < countTuples)

-- | Read column @colNumber@ of the current row and pass it to @f@.
--
-- @pResult@ is the @sqlite3_get_table@ result array: @columns@ column
-- names, then @rows@ rows of @columns@ cells. @refIndex@ holds the 1-based
-- current row maintained by 'sqlite3Fetch'. @f@ receives the field
-- definition, the cell's C string, and its length in bytes. For a NULL cell,
-- @f@ gets 'nullPtr' and length 0.
--
-- Throws 'SqlNoData' when there is no current row. That covers calls made
-- before the first fetch (cursor 0, which would otherwise return the
-- column-name header as data) and calls made after the fetch that returned
-- 'False'. @colNumber@ is assumed to be within @[0, columns)@; it is not
-- bounds-checked here.
getColValue pResult refIndex columns rows colNumber fieldDef f = do
  index <- readMVar refIndex
  when (index < 1 || index > rows) (throwIO SqlNoData)
  pStr <- peekElemOff pResult (columns*index+colNumber)
  if pStr == nullPtr
    then f fieldDef pStr 0
    else do strLen <- strlen pStr
            f fieldDef pStr (fromIntegral strLen)

-- | Execute a statement that produces no result set (DDL, DML, transaction
-- control) via @sqlite3_exec@, with no row callback.
--
-- A non-OK result code is turned into a 'SqlError' by 'handleSqlResult'.
sqlite3Execute :: SQLite3 -> String -> IO ()
sqlite3Execute sqlite query =
  withCString query $ \cQuery ->
    alloca $ \errMsgPtr ->
      sqlite3_exec sqlite cQuery nullFunPtr nullPtr errMsgPtr
        >>= \rc -> handleSqlResult rc errMsgPtr


------------------------------------------------------------------------------
-- routines for handling exceptions
------------------------------------------------------------------------------
-- | Turn a SQLite result code into an exception.
--
-- Returns normally when @res@ equals 'sqliteOk'. Otherwise it reads the
-- error message SQLite stored through @ppMsg@, frees it with
-- 'sqlite3_free', and throws 'SqlError' with state @\"E\"@ and the result
-- code as the native error.
--
-- SQLite can leave the message pointer NULL, for example when it could not
-- allocate the message. In that case a generic message naming the result
-- code is used instead of dereferencing NULL.
handleSqlResult :: CInt -> Ptr CString -> IO ()
handleSqlResult res ppMsg
  | fromIntegral res == sqliteOk = return ()
  | otherwise = do
      pMsg <- peek ppMsg
      msg <- if pMsg == nullPtr
               then return ("SQLite error code " ++ show res)
               else do
                 text <- peekCString pMsg
                 sqlite3_free pMsg
                 return text
      throwIO (SqlError "E" (fromIntegral res) msg)