{-| Implementation of an SQLite3-based event store. -}
module Data.CQRS.EventStore.Sqlite3
       ( openSqliteEventStore
       , closeEventStore -- Re-export for convenience
       ) where

import           Control.Exception (catch, bracket, mask, onException, SomeException)
import           Control.Monad (when, forM_)
import           Data.ByteString (ByteString)
import           Data.CQRS.EventStore
import           Data.CQRS.GUID (GUID)
import qualified Data.CQRS.GUID as G
import qualified Database.SQLite3 as SQL
import           Database.SQLite3 (Database, Statement, SQLData(..), StepResult(..))
import           Prelude hiding (catch)

-- | Convenience class for turning Haskell values into SQLite parameter
-- values that can be bound to a prepared statement.
class ToSQLData a where
  toSQLData :: a -> SQLData

-- | A GUID is bound as a blob holding its byte representation.
instance ToSQLData (GUID a) where
  toSQLData guid = SQLBlob (G.toByteString guid)

-- | An 'Int' is converted with 'fromIntegral' into the integer payload of
-- 'SQLInteger'.
instance ToSQLData Int where
  toSQLData n = SQLInteger (fromIntegral n)

-- | A 'ByteString' is bound verbatim as a blob.
instance ToSQLData ByteString where
  toSQLData bytes = SQLBlob bytes

-- SQL

-- | Schema for the event log. Each row holds one serialized event
-- (@ev_data@) for the aggregate identified by @guid@. @version@ is the
-- per-aggregate sequence number and @gversion@ is a store-wide sequence
-- number (assigned by 'insertEventSql'). The primary key rejects a second
-- event with the same version for the same aggregate.
createEventsSql :: String
createEventsSql = "CREATE TABLE IF NOT EXISTS events ( guid BLOB , ev_data BLOB , version INTEGER , gversion INTEGER, PRIMARY KEY (guid, version) );"

-- | Events of one aggregate with a version strictly greater than the given
-- one, in ascending version order.
-- Parameters: guid, version lower bound (exclusive).
selectEventsSql :: String
selectEventsSql = "SELECT version, ev_data FROM events WHERE guid = ? AND version > ? ORDER BY version ASC;"

-- | Events of all aggregates whose global version lies in the half-open
-- range [lower, upper), in ascending global order.
-- Parameters: lower bound (inclusive), upper bound (exclusive).
-- The per-aggregate @version@ column is selected too; readers must match
-- on both columns.
selectAllEventsSql :: String
selectAllEventsSql = "SELECT version, ev_data FROM events WHERE gversion >= ? AND gversion < ? ORDER BY gversion ASC;"

-- | Append one event. Parameters: guid, version, ev_data.
-- @gversion@ is computed inside this same statement as one more than the
-- current maximum @gversion@ (1 when the table is empty). This gives every
-- event a store-wide insertion index.
insertEventSql :: String
insertEventSql = "INSERT INTO events ( guid, version, ev_data, gversion ) VALUES (?, ?, ?, COALESCE((SELECT MAX(gversion) FROM events ), 0) + 1);"

-- | Schema for the de-normalized "current version" of each aggregate.
-- @guid@ is the primary key, so there is at most one row per aggregate.
createAggregateVersionsSql :: String
createAggregateVersionsSql = "CREATE TABLE IF NOT EXISTS versions ( guid BLOB PRIMARY KEY , version INTEGER );"

-- | Look up the recorded current version of one aggregate.
-- Parameter: guid. Yields zero or one row with a single INTEGER column.
getCurrentVersionSql :: String
getCurrentVersionSql = "SELECT version FROM versions WHERE guid = ?;"

-- | Insert or overwrite the recorded current version of one aggregate.
-- Parameters: guid, version.
updateCurrentVersionSql :: String
updateCurrentVersionSql = "INSERT OR REPLACE INTO versions ( guid, version ) VALUES (?,?);"

-- | Schema for aggregate snapshots. @guid@ is the primary key, so each
-- aggregate has at most one stored snapshot. @version@ records which
-- version the snapshot data corresponds to.
createSnapshotSql :: String
createSnapshotSql = "CREATE TABLE IF NOT EXISTS snapshots ( guid BLOB PRIMARY KEY , data BLOB , version INTEGER );"

-- | Store a snapshot, replacing any existing one for the same aggregate.
-- Parameters: guid, data, version.
writeSnapshotSql :: String
writeSnapshotSql = "INSERT OR REPLACE INTO snapshots ( guid , data, version ) VALUES ( ?, ?, ? );"

-- | Fetch the stored snapshot of one aggregate, if any.
-- Parameter: guid. Yields columns (data BLOB, version INTEGER).
selectSnapshotSql :: String
selectSnapshotSql = "SELECT data, version FROM snapshots WHERE guid = ?;"

-- | Run a parameterless transaction-control statement on the database.
runTransactionControl :: String -> Database -> IO ()
runTransactionControl controlSql db = execSql db controlSql []

-- | Open a new transaction.
beginTransaction :: Database -> IO ()
beginTransaction = runTransactionControl "BEGIN TRANSACTION;"

-- | Commit the currently open transaction.
commitTransaction :: Database -> IO ()
commitTransaction = runTransactionControl "COMMIT TRANSACTION;"

-- | Roll back the currently open transaction.
rollbackTransaction :: Database -> IO ()
rollbackTransaction = runTransactionControl "ROLLBACK TRANSACTION;"

-- | Prepare @sql@, bind @parameters@ to it, and hand the statement to
-- @action@. The statement is finalized when @action@ finishes, whether it
-- returns normally or throws.
withSqlStatement :: Database -> String -> [SQLData] -> (Statement -> IO a) -> IO a
withSqlStatement database sql parameters action =
  bracket acquire SQL.finalize useStatement
  where
    acquire = SQL.prepare database sql
    useStatement stmt = SQL.bind stmt parameters >> action stmt

-- | Execute a statement whose result rows (if any) are not needed.
-- The statement is stepped exactly once and the step result is discarded.
execSql :: Database -> String -> [SQLData] -> IO ()
execSql database sql parameters =
  withSqlStatement database sql parameters (\stmt -> SQL.step stmt >> return ())

-- | Run a query and fold over its result rows in the order SQLite returns
-- them.
--
-- For each row, @reader@ receives the row's columns and the current
-- accumulator, and its result becomes the next accumulator. The final
-- accumulator is returned once stepping reports 'Done'. @a0@ is returned
-- unchanged when the query yields no rows.
querySql :: Database -> String -> [SQLData] -> a -> ([SQLData] -> a -> IO a) -> IO a
querySql database sql parameters a0 reader =
  withSqlStatement database sql parameters $ \statement ->
    let fetchFrom acc = SQL.step statement >>= \result ->
          case result of
            Done -> return acc
            Row  -> SQL.columns statement >>= \cols -> reader cols acc >>= fetchFrom
    in fetchFrom a0

-- | Fail (via 'fail' in 'IO') when a result row does not have the column
-- shape the caller expected. @params@ is a rendering of the query
-- parameters, included only for diagnostics.
badQueryResult :: [String] -> [SQLData] -> IO b
badQueryResult params columns =
  fail ("Invalid query result shape. Params: " ++ show params ++ ". Result columns: " ++ show columns)

-- | Fail (via 'fail' in 'IO') to report that the expected version @ov@ does
-- not match the version @cv@ actually found.
versionConflict :: (Show a, Show b) => a -> b -> IO c
versionConflict ov cv = fail message
  where
    message = "Version conflict detected (expected " ++ show ov ++ ", saw " ++ show cv ++ ")"

-- | Append events to an aggregate's stream, with an optimistic concurrency
-- check.
--
-- @originatingVersion@ must equal the aggregate's recorded version (0 when
-- no version row exists); otherwise 'versionConflict' fails. On success the
-- recorded version becomes @originatingVersion + length events@, and the
-- events are inserted with versions @originatingVersion + 1@,
-- @originatingVersion + 2@, ... in list order.
--
-- Several statements are issued and no transaction is opened here.
-- NOTE(review): presumably callers wrap this in 'withTransaction' — confirm.
storeEvents_ :: forall a. Database -> GUID a -> Int -> [ByteString] -> IO ()
storeEvents_ database guid originatingVersion events = do
  recorded <- querySql database getCurrentVersionSql [toSQLData guid] [] collectVersion
  let currentVersion = foldr max 0 recorded
  when (fromIntegral currentVersion /= originatingVersion) $
    versionConflict originatingVersion currentVersion
  execSql database updateCurrentVersionSql [toSQLData guid, toSQLData newVersion]
  forM_ numberedEvents insertEvent
  where
    newVersion = originatingVersion + length events
    numberedEvents = zip [originatingVersion + 1 ..] events
    -- Every row of the version lookup must be a single INTEGER column.
    collectVersion [SQLInteger v] acc = return (v : acc)
    collectVersion columns _ = badQueryResult [show guid] columns
    insertEvent (v, e) =
      execSql database insertEventSql [toSQLData guid, toSQLData v, toSQLData e]

-- | Fetch all events of an aggregate whose version is strictly greater than
-- @v0@, in ascending version order.
--
-- Returns the highest version seen (or @v0@ when there are no newer events)
-- together with the serialized event payloads.
retrieveEvents_ :: Database -> GUID a -> Int -> IO (Int,[ByteString])
retrieveEvents_ database guid v0 = do
  newestFirst <- querySql database selectEventsSql [toSQLData guid, toSQLData v0] [] prependRow
  -- Rows were prepended as they arrived, so reverse to restore ascending order.
  let rows = reverse newestFirst
      latest = foldr (max . fst) (fromIntegral v0) rows
  return (fromIntegral latest, map snd rows)
  where
    prependRow [SQLInteger version, SQLBlob eventData] acc = return ((version, eventData) : acc)
    prependRow columns _ = badQueryResult [show guid] columns

-- | Pass every event whose global version @gversion@ satisfies
-- @minVersion <= gversion < maxVersion@ to @handler@, in ascending global
-- order.
--
-- Fails (via 'badQueryResult') if a row does not have the expected
-- (INTEGER, BLOB) column shape.
readAllEvents_ :: Database -> Int -> Int -> (ByteString -> IO ()) -> IO ()
readAllEvents_ database minVersion maxVersion handler =
  querySql database selectAllEventsSql
    [toSQLData minVersion, toSQLData maxVersion]
    ()  -- no state is threaded between rows; use a real value, not 'undefined'
    $ \columns _ ->
      case columns of
        -- The per-aggregate version column is selected but not needed here.
        [SQLInteger _, SQLBlob eventData] -> handler eventData
        _ -> badQueryResult [] columns

-- | Store the snapshot @a@ taken at version @v@ for an aggregate, replacing
-- any snapshot previously stored for it.
writeSnapshot_ :: Database -> GUID a -> (Int, ByteString) -> IO ()
writeSnapshot_ database guid (v, a) =
  execSql database writeSnapshotSql boundValues
  where
    -- Order matches the placeholders: guid, data, version.
    boundValues = [toSQLData guid, toSQLData a, toSQLData v]

-- | Look up the stored snapshot of an aggregate.
--
-- Returns @Just (version, data)@ built from the last row returned, or
-- 'Nothing' when no row exists. Fails (via 'badQueryResult') on an
-- unexpected column shape.
getLatestSnapshot_ :: Database -> GUID a -> IO (Maybe (Int, ByteString))
getLatestSnapshot_ database guid =
  querySql database selectSnapshotSql [toSQLData guid] Nothing decodeRow
  where
    decodeRow [SQLBlob aggData, SQLInteger version] _ = return (Just (fromIntegral version, aggData))
    decodeRow columns _ = badQueryResult [show guid] columns

-- | Run @action@ inside an SQLite transaction.
--
-- The transaction is committed when @action@ returns normally. If @action@
-- or the commit throws, a rollback is attempted and the original exception
-- is re-thrown. Asynchronous exceptions are masked from @BEGIN@ until the
-- rollback handler is installed, and @action@ itself runs with the caller's
-- masking state restored. This way an exception delivered right after
-- @BEGIN@ cannot leave the transaction open.
withTransaction_ :: forall a . Database -> IO a -> IO a
withTransaction_ database action =
  mask (\restore -> do
    beginTransaction database
    r <- restore action `onException` tryRollback
    commitTransaction database `onException` tryRollback
    return r)
  where
    tryRollback =
      -- Try rollback while discarding exception; we want to preserve
      -- original exception.
      catch (rollbackTransaction database) (\(_::SomeException) -> return ())

-- | Open an SQLite3-based event store using the named SQLite database file.
-- The database file is created if it does not exist.
--
-- Missing tables (events, versions, snapshots) and an index on the events'
-- global version are created. If any schema statement fails, the database
-- handle is closed before the exception propagates, so it is not leaked.
openSqliteEventStore :: String -> IO EventStore
openSqliteEventStore databaseFileName = do
  -- Create the database.
  database <- SQL.open databaseFileName
  -- Set up tables; release the handle if anything goes wrong.
  mapM_ (\sql -> execSql database sql []) schemaSql
    `onException` closeQuietly database
  -- Return event store.
  return $ EventStore { storeEvents = storeEvents_ database
                      , retrieveEvents = retrieveEvents_ database
                      , readAllEvents = readAllEvents_ database
                      , writeSnapshot = writeSnapshot_ database
                      , getLatestSnapshot = getLatestSnapshot_ database
                      , withTransaction = withTransaction_ database
                      , closeEventStore = SQL.close database
                      }
  where
    schemaSql =
      [ createEventsSql
      , createEventsGVersionIndexSql
      , createAggregateVersionsSql
      , createSnapshotSql
      ]
    -- Every insert computes MAX(gversion) and readAllEvents range-scans
    -- gversion. The index lets SQLite answer both without scanning the whole
    -- events table. IF NOT EXISTS keeps this safe for existing databases.
    createEventsGVersionIndexSql =
      "CREATE INDEX IF NOT EXISTS events_gversion ON events ( gversion );"
    -- Closing is best-effort: the setup failure is the exception to report.
    closeQuietly db =
      catch (SQL.close db) (\(_ :: SomeException) -> return ())