{-| SQLite3 helper functions (statement execution, query result enumeration and transactions) for the SQLite3-based event store backend. -}
module Data.CQRS.EventStore.Backend.Sqlite3Utils
       ( enumQueryResult
       , execSql
       , withTransaction
       ) where

import           Control.Exception (catch, bracket, finally, mask, onException, SomeException)
import           Data.Enumerator (Enumerator, Iteratee(..), tryIO, continue, (>>==), Stream(..))
import           Data.Enumerator.Internal (checkContinue0)
import qualified Database.SQLite3 as SQL
import           Database.SQLite3 (Database, Statement, SQLData(..), StepResult(..))
import           Prelude hiding (catch)

-- | Start a transaction on the given database.
beginTransaction :: Database -> IO ()
beginTransaction db = execSql db sql []
  where
    sql = "BEGIN TRANSACTION;"

-- | Commit the currently open transaction on the given database.
commitTransaction :: Database -> IO ()
commitTransaction db = execSql db sql []
  where
    sql = "COMMIT TRANSACTION;"

-- | Roll back the currently open transaction on the given database.
rollbackTransaction :: Database -> IO ()
rollbackTransaction db = execSql db sql []
  where
    sql = "ROLLBACK TRANSACTION;"

-- | Prepare @sql@, bind @parameters@ to it and hand the resulting
-- statement to @action@. The statement is finalized once the
-- binding and the action are over, whether they return normally or
-- throw.
withSqlStatement :: Database -> String -> [SQLData] -> (Statement -> IO a) -> IO a
withSqlStatement database sql parameters action =
  bracket acquire SQL.finalize use
  where
    acquire  = SQL.prepare database sql
    use stmt = SQL.bind stmt parameters >> action stmt

-- | Execute an SQL statement for which no result is expected.
--
-- The prepared statement is stepped exactly once and whatever
-- 'SQL.step' reports is discarded.
execSql :: Database -> String -> [SQLData] -> IO ()
execSql database sql parameters =
  withSqlStatement database sql parameters runOnce
  where
    runOnce stmt = SQL.step stmt >> return ()

-- | Enumerate the rows produced by a statement that has already been
-- prepared and bound. Each row (its list of columns) is fed to the
-- iteratee as a single-element chunk. When SQLite reports 'Done', the
-- iteratee's continuation is returned unchanged so later enumerators
-- can keep feeding it. Errors from stepping or reading columns are
-- raised in the iteratee via 'tryIO'.
enumQueryResults :: Statement -> Enumerator [SQLData] IO b
enumQueryResults stmt = checkContinue0 go
  where
    go again k = tryIO (SQL.step stmt) >>= onStep
      where
        onStep Done = continue k
        onStep Row  = do
          row <- tryIO (SQL.columns stmt)
          k (Chunks [row]) >>== again

-- | Perform a query and enumerate the results. Each result is a list
-- of returned columns.
--
-- The statement is prepared, bound to @parameters@, then stepped
-- until SQLite reports 'Done' or the iteratee stops consuming input.
-- It is finalized afterwards regardless of outcome. Failures from
-- preparing, binding and stepping are all reported through the
-- iteratee's error channel via 'tryIO', so callers see them in one
-- uniform way.
enumQueryResult :: Database -> String -> [SQLData] -> Enumerator [SQLData] IO b
enumQueryResult database sql parameters step = do
  stmt <- tryIO $ SQL.prepare database sql
  -- Finalize the statement even if binding or enumeration fails.
  Iteratee $ finally
    (runIteratee $ do
        -- Wrapped in 'tryIO' like 'SQL.prepare' above, so a bind
        -- failure becomes an iteratee error rather than a raw IO
        -- exception.
        tryIO $ SQL.bind stmt parameters
        enumQueryResults stmt step)
    (SQL.finalize stmt)

-- | Execute an IO action with an active transaction.
--
-- A transaction is begun before the action runs and committed after
-- it returns. If the action or the commit throws, a rollback is
-- attempted and the original exception is rethrown. Any exception
-- raised by the rollback itself is discarded.
--
-- Asynchronous exceptions are masked from the moment the transaction
-- is begun until the rollback handler is in place. Without the mask,
-- an exception delivered in that window would leave the transaction
-- open with no rollback.
withTransaction :: Database -> IO a -> IO a
withTransaction database action =
  mask $ \restore -> do
    beginTransaction database
    -- Only the action and the commit run with the caller's original
    -- masking state.
    restore runAction `onException` tryRollback
  where
    runAction = do
      r <- action
      commitTransaction database
      return r

    tryRollback =
      -- Try rollback while discarding exception; we want to preserve
      -- original exception.
      catch (rollbackTransaction database) (\(_::SomeException) -> return ())