{-| Helpers for the SQLite3-based event store backend: executing statements, streaming query results, and running actions inside transactions. -}
module Data.CQRS.EventStore.Backend.Sqlite3Utils
       ( sourceQuery
       , execSql
       , withTransaction
       ) where

import           Control.Exception (catch, bracket, mask, onException, SomeException)
import           Control.Monad (liftM, when)
import           Data.Conduit (Source)
import qualified Data.Conduit as C
import           Data.IORef (newIORef, readIORef, writeIORef)
import qualified Database.SQLite3 as SQL
import           Database.SQLite3 (Database, Statement, SQLData(..), StepResult(..))
import           Prelude hiding (catch)

-- | Start a transaction on the given connection by issuing
-- @BEGIN TRANSACTION;@ with no bound parameters.
beginTransaction :: Database -> IO ()
beginTransaction db = execSql db beginSql []
  where
    beginSql = "BEGIN TRANSACTION;"

-- | Commit the currently open transaction by issuing
-- @COMMIT TRANSACTION;@ with no bound parameters.
commitTransaction :: Database -> IO ()
commitTransaction db = execSql db commitSql []
  where
    commitSql = "COMMIT TRANSACTION;"

-- | Abandon the currently open transaction by issuing
-- @ROLLBACK TRANSACTION;@ with no bound parameters.
rollbackTransaction :: Database -> IO ()
rollbackTransaction db = execSql db rollbackSql []
  where
    rollbackSql = "ROLLBACK TRANSACTION;"

-- | Prepare @sql@, bind @parameters@ to the resulting statement and run
-- @action@ on it. The statement is finalized through 'bracket', so
-- finalization also happens when binding or the action throws.
withSqlStatement :: Database -> String -> [SQLData] -> (Statement -> IO a) -> IO a
withSqlStatement database sql parameters action =
  bracket acquire SQL.finalize useStatement
  where
    acquire = SQL.prepare database sql
    useStatement statement = SQL.bind statement parameters >> action statement

-- | Execute an SQL statement for which no result is expected.
--
-- The prepared statement is stepped exactly once and the resulting
-- 'StepResult' is discarded; any rows it would produce are ignored.
execSql :: Database -> String -> [SQLData] -> IO ()
execSql database sql parameters =
  withSqlStatement database sql parameters stepOnce
  where
    stepOnce statement = SQL.step statement >> return ()

-- | Tracks whether a streaming query's parameters have already been bound
-- to its prepared statement.
data State
  = Unbound
  | Bound
  deriving Eq

-- | Stream the result rows of a query as a conduit 'Source'.
--
-- Opening the source allocates a bind-state flag and prepares the statement;
-- the source's cleanup action finalizes it. Parameters are bound on the first
-- pull rather than during preparation. Each pull steps the statement once,
-- yielding that row's columns, and the source closes when SQLite reports
-- 'Done'.
sourceQuery :: Database -> String -> [SQLData] -> Source IO [SQLData]
sourceQuery database sql parameters =
  C.sourceIO acquire release pull
  where
    acquire = do
      boundRef <- newIORef Unbound
      statement <- SQL.prepare database sql
      return (boundRef, statement)

    release (_, statement) = SQL.finalize statement

    pull (boundRef, statement) = do
      -- Parameters are bound exactly once, on the first pull.
      bindState <- readIORef boundRef
      when (bindState /= Bound) $ do
        SQL.bind statement parameters
        writeIORef boundRef Bound
      stepResult <- SQL.step statement
      case stepResult of
        Row  -> liftM C.Open (SQL.columns statement)
        Done -> return C.Closed

-- | Execute an IO action with an active transaction.
--
-- A transaction is begun on @database@ before @action@ runs.
--
-- * If @action@ returns normally, the transaction is committed and its
--   result is returned.
-- * If @action@ or the commit throws, a rollback is attempted and the
--   original exception is re-thrown.
--
-- Any exception raised by the rollback itself is deliberately discarded, so
-- that it cannot replace the original error.
--
-- Asynchronous exceptions are masked from the moment the transaction is
-- begun until the rollback handler is in place. This prevents an exception
-- arriving in that window from leaving an open transaction on the
-- connection. @action@ itself runs with the caller's masking state restored.
withTransaction :: Database -> IO a -> IO a
withTransaction database action =
  mask $ \restore -> do
    beginTransaction database
    runAction restore `onException` tryRollback
  where
    -- Only the user action is unmasked; the commit runs masked so it is not
    -- abandoned halfway by an asynchronous exception.
    runAction restore = do
      r <- restore action
      commitTransaction database
      return r

    -- Try rollback while discarding any exception; we want to preserve the
    -- original exception that triggered the rollback.
    tryRollback = catch (rollbackTransaction database) ignoreException

    -- Typed by signature so no ScopedTypeVariables pattern is needed.
    ignoreException :: SomeException -> IO ()
    ignoreException _ = return ()