{-| Implementation of an SQLite3-based event store. -}
module Data.CQRS.EventStore.Sqlite3
       ( openSqliteEventStore
       , closeEventStore -- Re-export for convenience
       ) where

import           Control.Exception (catch, bracket, onException, SomeException)
import           Control.Monad (when, forM_)
import           Data.Binary (encode, decode, Binary)
import qualified Data.ByteString.Char8 as B8
import qualified Data.ByteString.Lazy.Char8 as BSL
import           Data.CQRS.EventStore
import           Data.CQRS.Event (Event)
import           Data.CQRS.GUID
import qualified Database.SQLite3 as SQL
import           Database.SQLite3 (Database, Statement, SQLData(..), StepResult(..))
import           Prelude hiding (catch)

-- | Schema for the event log: one row per (aggregate GUID, version) pair,
-- holding the 'Binary'-encoded event payload in @ev_data@.
createEventsSql :: String
createEventsSql =
  "CREATE TABLE IF NOT EXISTS events \
  \( guid BLOB \
  \, ev_data BLOB \
  \, version INTEGER \
  \, PRIMARY KEY (guid, version) );"

-- | Fetch every stored event of one aggregate (bound to the single @?@
-- placeholder) together with its version number.
--
-- The @ORDER BY@ is required. Without it SQL gives no ordering guarantee,
-- and events must be replayed in the order they were stored.
selectEventsSql :: String
selectEventsSql = "SELECT version, ev_data FROM events WHERE guid = ? ORDER BY version ASC;"

-- | Append one event row. Placeholders are, in order: aggregate GUID,
-- event version, encoded event payload.
insertEventSql :: String
insertEventSql =
  "INSERT INTO events ( guid, version, ev_data ) \
  \VALUES (?, ?, ?);"

-- | Schema for the de-normalised "current version" table: at most one row
-- per aggregate GUID.
createAggregateVersionsSql :: String
createAggregateVersionsSql =
  "CREATE TABLE IF NOT EXISTS versions \
  \( guid BLOB PRIMARY KEY \
  \, version INTEGER );"

-- | Look up the recorded current version of one aggregate (GUID bound to
-- the placeholder).
getCurrentVersionSql :: String
getCurrentVersionSql =
  "SELECT version FROM versions \
  \WHERE guid = ?;"

-- | Insert or overwrite the current version of one aggregate. Placeholders
-- are the aggregate GUID, then the new version.
updateCurrentVersionSql :: String
updateCurrentVersionSql =
  "INSERT OR REPLACE INTO versions ( guid, version ) \
  \VALUES (?,?);"

-- | Start an explicit transaction on the given connection.
beginTransaction :: Database -> IO ()
beginTransaction db = execSql db beginSql []
  where
    beginSql = "BEGIN TRANSACTION;"

-- | Commit the transaction currently open on the given connection.
commitTransaction :: Database -> IO ()
commitTransaction db = execSql db commitSql []
  where
    commitSql = "COMMIT TRANSACTION;"

-- | Roll back the transaction currently open on the given connection.
rollbackTransaction :: Database -> IO ()
rollbackTransaction db = execSql db rollbackSql []
  where
    rollbackSql = "ROLLBACK TRANSACTION;"

-- | Serialise a value with its 'Binary' instance and wrap the resulting
-- bytes, flattened into a single strict ByteString, as an SQLite BLOB.
toSQLBlob :: (Binary a) => a -> SQLData
toSQLBlob = SQLBlob . B8.concat . BSL.toChunks . encode

-- | Prepare @sql@ on @database@, bind @parameters@ to its placeholders, and
-- run @action@ on the prepared statement.
--
-- 'bracket' guarantees the statement is finalized afterwards, even when
-- binding or the action throws.
withSqlStatement :: Database -> String -> [SQLData] -> (Statement -> IO a) -> IO a
withSqlStatement database sql parameters action =
  bracket acquire SQL.finalize useStatement
  where
    acquire = SQL.prepare database sql
    useStatement stmt = SQL.bind stmt parameters >> action stmt

-- | Execute a statement for its side effects. The statement is stepped
-- exactly once, and whatever that step reports is discarded.
execSql :: Database -> String -> [SQLData] -> IO ()
execSql database sql parameters =
  withSqlStatement database sql parameters runOnce
  where
    runOnce stmt = SQL.step stmt >> return ()

-- | Run a query and apply @reader@ to the column values of every result
-- row. Results are returned in the order the rows were produced.
querySql :: Database -> String -> [SQLData] -> ([SQLData] -> IO a) -> IO [a]
querySql database sql parameters reader =
  withSqlStatement database sql parameters (readAll [])
  where
    -- Results are accumulated newest-first and reversed once at the end.
    readAll rowsSoFar stmt = do
      outcome <- SQL.step stmt
      case outcome of
        Row -> do
          value <- SQL.columns stmt >>= reader
          readAll (value : rowsSoFar) stmt
        Done -> return (reverse rowsSoFar)

-- | Abort (via 'fail' in 'IO') when a query row does not have the expected
-- column shape. The message names the aggregate GUID and the offending
-- columns.
badQueryResult :: GUID a -> [SQLData] -> IO b
badQueryResult guid columns =
  fail ("Invalid query result for " ++ show guid ++ ": " ++ show columns)

-- | Abort (via 'fail' in 'IO') because the version the caller expected does
-- not match the version found in the store.
versionConflict :: (Show a, Show b) => a -> b -> IO c
versionConflict expected seen =
  fail ("Version conflict detected (expected " ++ show expected
        ++ ", saw " ++ show seen ++ ")")

-- | Append events to an aggregate's stream with an optimistic version check.
--
-- @originatingVersion@ is the version the caller last saw. If the version
-- recorded in the @versions@ table differs, the call fails via
-- 'versionConflict' before anything is written.
--
-- Otherwise the recorded version is set to
-- @originatingVersion + length events@. Each event is inserted with
-- consecutive versions @originatingVersion + 1@, @originatingVersion + 2@,
-- and so on.
--
-- NOTE(review): the read-check-write sequence is only atomic when run inside
-- 'withTransaction_'; confirm callers wrap it.
storeEvents_ :: (Event e a, Binary e) => Database -> GUID a -> Int -> [e] -> IO ()
storeEvents_ database guid originatingVersion events = do
  -- Current version of the aggregate (0 when no version row exists yet).
  versions <- querySql database getCurrentVersionSql [toSQLBlob guid] $ \columns ->
    case columns of
      [SQLInteger v] -> return v
      _ -> badQueryResult guid columns
  let curVer = maximum (0 : versions)

  -- Compare in the column's 64-bit integer type. Narrowing the stored value
  -- to 'Int' could wrap on 32-bit platforms and mask a conflict.
  when (curVer /= fromIntegral originatingVersion) $
    versionConflict originatingVersion curVer

  -- Update the de-normalised version number.
  execSql database updateCurrentVersionSql
    [ toSQLBlob guid
    , SQLInteger $ fromIntegral $ originatingVersion + length events
    ]

  -- Store the supplied events with consecutive version numbers.
  forM_ (zip [originatingVersion + 1 ..] events) $ \(v, e) ->
    execSql database insertEventSql
      [ toSQLBlob guid
      , SQLInteger $ fromIntegral v
      , toSQLBlob e
      ]

-- | Load all stored events of an aggregate, decoding each payload with its
-- 'Binary' instance.
--
-- Returns the highest version seen (0 when there are no events) and the
-- events in the order 'selectEventsSql' yields them.
--
-- NOTE(review): 'decode' is pure and is not forced here, so a corrupt
-- payload only fails once the event value is used.
retrieveEvents_ :: (Event e a, Binary e) => Database -> GUID a -> IO (Int,[e])
retrieveEvents_ database guid = do
  rows <- querySql database selectEventsSql [toSQLBlob guid] decodeRow
  let (versions, events) = unzip rows
  return (fromIntegral (maximum (0 : versions)), events)
  where
    decodeRow [SQLInteger version, SQLBlob eventData] =
      return (version, decode (BSL.fromChunks [eventData]))
    decodeRow columns = badQueryResult guid columns

-- | Run @action@ between BEGIN and COMMIT.
--
-- If the action or the commit throws, a rollback is attempted and the
-- original exception is re-thrown. Any failure of the rollback itself is
-- deliberately discarded so it cannot mask the original error.
withTransaction_ :: forall a . Database -> IO a -> IO a
withTransaction_ database action =
  beginTransaction database >> (commitAfter `onException` rollbackQuietly)
  where
    commitAfter = do
      result <- action
      commitTransaction database
      return result

    rollbackQuietly = rollbackTransaction database `catch` ignoreFailure

    ignoreFailure :: SomeException -> IO ()
    ignoreFailure _ = return ()

-- | Open an SQLite3-based event store using the named SQLite database file.
-- The database file is created if it does not exist.
--
-- The @events@ and @versions@ tables are created if missing. If that step
-- throws, the connection is closed before the exception propagates, so the
-- handle is not leaked.
openSqliteEventStore :: String -> IO EventStore
openSqliteEventStore databaseFileName = do
  -- Open the database.
  database <- SQL.open databaseFileName
  -- Set up tables, releasing the connection if that fails.
  createSchema database `onException` SQL.close database
  -- Return event store.
  return $ EventStore { storeEvents = storeEvents_ database
                      , retrieveEvents = retrieveEvents_ database
                      , withTransaction = withTransaction_ database
                      , closeEventStore = SQL.close database
                      }
  where
    createSchema db = do
      execSql db createEventsSql []
      execSql db createAggregateVersionsSql []