module Database.Sqlite3.Middle where

import Bindings.Sqlite3
import qualified Data.ByteString as B

import Codec.Binary.UTF8.String
import Control.Monad.Error
import Foreign.C hiding (newCString)
import Foreign.Marshal.Alloc
import Foreign.Marshal.Array
import Foreign.Ptr
import Foreign.Storable

-- | An action that runs in any monad able to hold SQLite handles
-- ('MonadDb') and to throw errors ('MonadError') whose type can wrap a
-- SQLite result code ('DbError').
--
-- NOTE(review): @e@ and @m@ are not parameters of this synonym, so it relies
-- on GHC implicitly quantifying them (rank-N / liberal type synonyms).
-- Newer GHCs reject free type variables in a synonym; confirm which
-- compiler version this module targets.
type Db a = (DbError e, MonadError e m, MonadDb m) => m a

-- | Conversion between raw SQLite result codes and a caller-chosen error
-- type. This module only calls 'makeErr', and only on unsuccessful codes.
class DbError e where
  -- | Wrap a SQLite result code as an error value.
  makeErr :: CInt -> e

  -- | Recover the SQLite result code from an error value, if it holds one.
  castErr :: e -> Maybe CInt

-- | A monad that stores the current SQLite connection handle and the
-- current prepared-statement handle.
--
-- NOTE(review): the meaning of 'isDbReady' / 'isStReady' is defined by the
-- instances. Presumably they report whether a handle is currently stored.
class MonadIO m => MonadDb m where
  -- Connection handle ------------------------------------------------------

  -- | Current connection handle.
  getDb :: m (Ptr C'sqlite3)
  -- | Store a connection handle (called by 'open').
  putDb :: Ptr C'sqlite3 -> m ()
  -- | Drop the stored connection handle (called by 'close').
  cleanDb :: m ()
  -- | Presumably: whether a connection handle is stored.
  isDbReady :: m Bool

  -- Statement handle -------------------------------------------------------

  -- | Current prepared-statement handle.
  getSt :: m (Ptr C'sqlite3_stmt)
  -- | Store a statement handle (called by 'prepare').
  putSt :: Ptr C'sqlite3_stmt -> m ()
  -- | Drop the stored statement handle (called by 'finalize').
  cleanSt :: m ()
  -- | Presumably: whether a statement handle is stored.
  isStReady :: m Bool

-- | Open the SQLite database at @path@ with @sqlite3_open@ and store the
-- returned connection handle via 'putDb'.
--
-- The handle is stored even when the result code is non-zero; the code is
-- then raised through 'err'. (Per the SQLite docs a handle is normally
-- returned even on failure and should still be closed.)
--
-- The path is encoded as UTF-8 into a temporary, NUL-terminated buffer.
-- SQLite does not keep that pointer, so the buffer is released as soon as
-- the call returns. The previous version leaked it on every call.
open :: String -> Db ()
open path = do
  (rc, db) <- liftIO $
    withArray0 0 (map fromIntegral (encode path)) $ \cpath ->
      -- 'alloca' provides the out-parameter and frees it even on exceptions.
      alloca $ \pdb -> do
        code <- c'sqlite3_open cpath pdb
        conn <- peek pdb
        return (code, conn)
  putDb db
  err rc

-- | Close the stored connection with @sqlite3_close@.
--
-- 'cleanDb' runs before the result code is inspected, so the stored handle
-- is dropped even when closing fails. A non-zero code is then raised via
-- 'err'.
close :: Db ()
close = do
  code <- getDb >>= liftIO . c'sqlite3_close
  cleanDb
  err code

-- | Compile @sql@ against the stored connection with @sqlite3_prepare@ and
-- store the resulting statement handle via 'putSt'.
--
-- The handle is stored before the result code is checked, and a non-zero
-- code is then raised via 'err'.
--
-- Fixes relative to the previous version:
--
-- * The SQL is UTF-8 encoded, so its length must be given in bytes. The old
--   code passed the character count, which truncated any SQL containing
--   non-ASCII characters.
-- * The temporary SQL buffer was never freed. SQLite does not retain the
--   pointer after preparing, so it is now released when the call returns.
prepare :: String -> Db ()
prepare sql = do
  db <- getDb
  (rc, st) <- liftIO $
    withArray0 0 (map fromIntegral utf8) $ \cstr ->
      alloca $ \pst -> do
        code <- c'sqlite3_prepare db cstr (toEnum (length utf8)) pst nullPtr
        stmt <- peek pst
        return (code, stmt)
  putSt st
  err rc
  where
    -- Encoded once so the buffer contents and the byte count always agree.
    utf8 = encode sql

-- | Destroy the stored statement with @sqlite3_finalize@.
--
-- 'cleanSt' runs before the result code is inspected, so the stored handle
-- is dropped regardless of outcome. A non-zero code is then raised via
-- 'err'.
finalize :: Db ()
finalize = do
  code <- getSt >>= liftIO . c'sqlite3_finalize
  cleanSt
  err code

-- | Reset the stored statement with @sqlite3_reset@, raising a non-zero
-- result code via 'err'.
reset :: Db ()
reset = do
  code <- getSt >>= liftIO . c'sqlite3_reset
  err code

-- | Advance the stored statement with @sqlite3_step@.
--
-- Returns 'False' when a result row is available (code 100, SQLITE_ROW) and
-- 'True' once the statement has finished (code 101, SQLITE_DONE). Any other
-- code is thrown via 'makeErr'.
step :: Db Bool
step = getSt >>= liftIO . c'sqlite3_step >>= classify
  where
    classify 100  = return False
    classify 101  = return True
    classify code = throwError (makeErr code)

-- | Bind the integer @val@ to parameter @num@ of the stored statement with
-- @sqlite3_bind_int@.
--
-- Both values go through 'toEnum', so they must fit in a 'CInt'. A non-zero
-- result code is raised via 'err'.
bindInt :: Int -> Int -> Db ()
bindInt num val = do
  stmt <- getSt
  let index = toEnum num
      value = toEnum val
  code <- liftIO (c'sqlite3_bind_int stmt index value)
  err code

-- | Bind the floating-point @val@ to parameter @num@ of the stored statement
-- with @sqlite3_bind_double@, raising a non-zero result code via 'err'.
bindDouble :: Int -> Double -> Db ()
bindDouble num val = do
  stmt <- getSt
  let index = toEnum num
      value = realToFrac val
  code <- liftIO (c'sqlite3_bind_double stmt index value)
  err code

-- | Bind the text @val@ (encoded as UTF-8) to parameter @num@ of the stored
-- statement with @sqlite3_bind_text@, raising a non-zero result code via
-- 'err'.
--
-- Fixes relative to the previous version:
--
-- * The length argument is now the UTF-8 byte count. The old code passed
--   the character count, which truncated non-ASCII text.
-- * The old code bound a freshly allocated buffer with a null destructor
--   (SQLITE_STATIC) and never freed it, leaking memory on every bind. The
--   value is now bound with SQLITE_TRANSIENT, so SQLite makes its own copy
--   before returning and the temporary buffer can be freed immediately.
bindText :: Int -> String -> Db ()
bindText num val = do
  st <- getSt
  rc <- liftIO $
    withArray0 0 (map fromIntegral utf8) $ \cstr ->
      c'sqlite3_bind_text st (toEnum num) cstr (toEnum (length utf8)) transient
  err rc
  where
    -- Encoded once so the buffer contents and the byte count always agree.
    utf8 = encode val
    -- SQLITE_TRANSIENT is defined in C as ((sqlite3_destructor_type)-1).
    transient = castPtrToFunPtr (intPtrToPtr (-1))

-- | Size in bytes of column @num@ in the current result row, as reported by
-- @sqlite3_column_bytes@.
columnBytes :: Int -> Db Int
columnBytes num = do
  stmt <- getSt
  liftIO (fromEnum `fmap` c'sqlite3_column_bytes stmt (toEnum num))

-- | Number of columns produced by the stored statement, as reported by
-- @sqlite3_column_count@.
columnCount :: Db Int
columnCount = do
  stmt <- getSt
  liftIO (fromEnum `fmap` c'sqlite3_column_count stmt)

-- | Value of column @num@ in the current result row, read with
-- @sqlite3_column_int@.
columnInt :: Int -> Db Int
columnInt num = do
  stmt <- getSt
  liftIO (fromEnum `fmap` c'sqlite3_column_int stmt (toEnum num))

-- | Raw type code of column @num@ in the current result row, as reported by
-- @sqlite3_column_type@.
columnType :: Int -> Db Int
columnType num = do
  stmt <- getSt
  liftIO (fromEnum `fmap` c'sqlite3_column_type stmt (toEnum num))

-- | Value of column @num@ in the current result row, read with
-- @sqlite3_column_double@ and converted to a Haskell 'Double'.
columnDouble :: Int -> Db Double
columnDouble num = do
  stmt <- getSt
  liftIO (realToFrac `fmap` c'sqlite3_column_double stmt (toEnum num))

-- | Copy column @num@ of the current result row into a strict 'B.ByteString'.
--
-- The blob pointer is fetched with @sqlite3_column_blob@ before the size is
-- queried. This is the order the SQLite documentation prescribes, so the
-- byte count describes the value actually returned. The previous version
-- queried the size first.
--
-- A NULL column or a zero-length blob yields a null pointer and/or zero
-- size. That case returns 'B.empty' instead of handing a null pointer to
-- 'B.packCStringLen'.
columnBlob :: Int -> Db B.ByteString
columnBlob num = do
  st   <- getSt
  ptr  <- liftIO $ c'sqlite3_column_blob st (toEnum num)
  size <- columnBytes num
  if ptr == nullPtr || size <= 0
    then return B.empty
    else liftIO $ B.packCStringLen (castPtr ptr, size)

-- | Run @action@. If it throws, finalize the statement and close the
-- connection on a best-effort basis, ignoring any errors from that cleanup,
-- then rethrow the original error.
errorHook action = action `catchError` \e -> do
  ignore finalize
  ignore close
  throwError e

-- | Run an action for its effects, discarding its result and deliberately
-- swallowing any error it throws (used for best-effort cleanup).
ignore action = discard action `catchError` swallow
  where
    discard m = m >> return ()
    swallow _ = return ()

-- | Allocate a NUL-terminated copy of a 'String' encoded as UTF-8.
--
-- This replaces the @Foreign.C@ version (hidden in the imports), which uses
-- the locale encoding. The buffer comes from 'newArray0', so the caller must
-- release it with 'free'.
newCString :: String -> IO CString
newCString str = newArray0 0 [fromIntegral byte | byte <- encode str]

-- | Treat a SQLite result code of 0 as success. Any other code is thrown as
-- an error built with 'makeErr'.
err :: CInt -> Db ()
err code
  | code == 0 = return ()
  | otherwise = throwError (makeErr code)