{-# LANGUAGE ScopedTypeVariables #-}
{-| Implementation of an SQLite3-based event store. -}
module Data.CQRS.EventStore.Sqlite3
    ( openSqliteEventStore
    , closeEventStore -- Re-export for convenience
    ) where

import Control.Exception (catch, bracket, onException, SomeException)
import Control.Monad (when, forM_)
import Data.ByteString (ByteString)
import Data.CQRS.EventStore
import Data.CQRS.GUID (GUID)
import qualified Data.CQRS.GUID as G
import qualified Database.SQLite3 as SQL
import Database.SQLite3 (Database, Statement, SQLData(..), StepResult(..))
import Prelude hiding (catch)

-- | Convenience class for converting values to 'SQLData'.
class ToSQLData a where
    toSQLData :: a -> SQLData

-- GUIDs are stored as blobs holding their byte representation.
instance ToSQLData (GUID a) where
    toSQLData = SQLBlob . G.toByteString

instance ToSQLData Int where
    toSQLData = SQLInteger . fromIntegral

instance ToSQLData ByteString where
    toSQLData = SQLBlob

-- SQL

createEventsSql :: String
createEventsSql =
    "CREATE TABLE IF NOT EXISTS events \
    \( guid BLOB \
    \, ev_data BLOB \
    \, version INTEGER \
    \, gversion INTEGER\
    \, PRIMARY KEY (guid, version) \
    \);"

selectEventsSql :: String
selectEventsSql =
    "SELECT version, ev_data FROM events \
    \WHERE guid = ? AND version > ? ORDER BY version ASC;"

-- Selects events whose global version lies in the half-open range [?, ?).
selectAllEventsSql :: String
selectAllEventsSql =
    "SELECT version, ev_data FROM events \
    \WHERE gversion >= ? AND gversion < ? ORDER BY gversion ASC;"

-- Each inserted event gets a global version ('gversion') one greater than
-- the current maximum over the whole table (1 when the table is empty).
insertEventSql :: String
insertEventSql =
    "INSERT INTO events ( guid, version, ev_data, gversion ) \
    \VALUES (?, ?, ?, COALESCE((SELECT MAX(gversion) FROM events ), 0) + 1);"

createAggregateVersionsSql :: String
createAggregateVersionsSql =
    "CREATE TABLE IF NOT EXISTS versions \
    \( guid BLOB PRIMARY KEY \
    \, version INTEGER \
    \);"

getCurrentVersionSql :: String
getCurrentVersionSql = "SELECT version FROM versions WHERE guid = ?;"

updateCurrentVersionSql :: String
updateCurrentVersionSql =
    "INSERT OR REPLACE INTO versions ( guid, version ) VALUES (?,?);"

createSnapshotSql :: String
createSnapshotSql =
    "CREATE TABLE IF NOT EXISTS snapshots \
    \( guid BLOB PRIMARY KEY \
    \, data BLOB \
    \, version INTEGER \
    \);"

writeSnapshotSql :: String
writeSnapshotSql =
    "INSERT OR REPLACE INTO snapshots ( guid , data, version ) VALUES ( ?, ?, ? );"

selectSnapshotSql :: String
selectSnapshotSql = "SELECT data, version FROM snapshots WHERE guid = ?;"

-- Transaction primitives.

beginTransaction :: Database -> IO ()
beginTransaction database = execSql database "BEGIN TRANSACTION;" []

commitTransaction :: Database -> IO ()
commitTransaction database = execSql database "COMMIT TRANSACTION;" []

rollbackTransaction :: Database -> IO ()
rollbackTransaction database = execSql database "ROLLBACK TRANSACTION;" []

-- | Prepare @sql@, bind @parameters@, and run @action@ on the statement.
-- The statement is always finalized, even if @action@ throws.
withSqlStatement :: Database -> String -> [SQLData] -> (Statement -> IO a) -> IO a
withSqlStatement database sql parameters action =
    bracket (SQL.prepare database sql) SQL.finalize $ \statement -> do
        SQL.bind statement parameters
        action statement

-- | Execute a statement by stepping it once and discarding the step result.
-- Intended for DDL/DML; any result rows beyond the first are not visited.
execSql :: Database -> String -> [SQLData] -> IO ()
execSql database sql parameters =
    withSqlStatement database sql parameters $ \stmt -> do
        _ <- SQL.step stmt
        return ()

-- | Run a query and fold every result row, in the order SQLite returns
-- them, into an accumulator. The accumulator starts at @a0@, and @reader@
-- receives each row's columns together with the current accumulator.
querySql :: Database -> String -> [SQLData] -> a -> ([SQLData] -> a -> IO a) -> IO a
querySql database sql parameters a0 reader =
    withSqlStatement database sql parameters go
  where
    go statement = loop a0
      where
        loop acc = do
            res <- SQL.step statement
            case res of
                Done -> return acc
                Row -> do
                    cols <- SQL.columns statement
                    acc' <- reader cols acc
                    loop acc'

-- | Fail (via 'fail' in IO) because a result row had an unexpected shape.
-- The query parameters and offending columns are included in the message.
badQueryResult :: [String] -> [SQLData] -> IO b
badQueryResult params columns =
    fail $ concat
        [ "Invalid query result shape. Params: ", show params
        , ". Result columns: ", show columns ]

-- | Fail (via 'fail' in IO) reporting the expected and the observed version.
versionConflict :: (Show a, Show b) => a -> b -> IO c
versionConflict ov cv =
    fail $ concat
        [ "Version conflict detected (expected ", show ov
        , ", saw ", show cv, ")" ]

-- | Append @events@ to the aggregate identified by @guid@.
--
-- The aggregate's current version (0 when it has no row in @versions@) must
-- equal @originatingVersion@, otherwise 'versionConflict' is raised. The
-- events receive versions @originatingVersion + 1@ onwards.
--
-- NOTE(review): the check-then-write is only atomic if the caller runs this
-- inside 'withTransaction'; confirm callers do.
storeEvents_ :: Database -> GUID a -> Int -> [ByteString] -> IO ()
storeEvents_ database guid originatingVersion events = do
    -- Get the current version number of the aggregate.
    versions <- querySql database getCurrentVersionSql [toSQLData guid] [] $ \columns acc ->
        case columns of
            [SQLInteger v] -> return (v : acc)
            _ -> badQueryResult [show guid] columns
    let curVer = maximum (0 : versions)
    -- Sanity check current version number.
    when (fromIntegral curVer /= originatingVersion) $
        versionConflict originatingVersion curVer
    -- Update de-normalized version number.
    execSql database updateCurrentVersionSql
        [ toSQLData guid
        , toSQLData (originatingVersion + length events) ]
    -- Store the supplied events.
    forM_ (zip [originatingVersion + 1 ..] events) $ \(v, e) ->
        execSql database insertEventSql
            [ toSQLData guid
            , toSQLData v
            , toSQLData e ]

-- | Fetch all events of @guid@ with version strictly greater than @v0@,
-- in ascending version order. Also returns the highest version seen, which
-- is @v0@ when no such events exist.
retrieveEvents_ :: Database -> GUID a -> Int -> IO (Int, [ByteString])
retrieveEvents_ database guid v0 = do
    -- Rows are consed onto the accumulator, so reverse to restore ORDER BY.
    results <- fmap reverse $
        querySql database selectEventsSql [toSQLData guid, toSQLData v0] [] $ \columns acc ->
            case columns of
                [SQLInteger version, SQLBlob eventData] ->
                    return ((version, eventData) : acc)
                _ -> badQueryResult [show guid] columns
    let maxVersion = maximum (fromIntegral v0 : map fst results)
    return (fromIntegral maxVersion, map snd results)

-- | Pass the data of every event whose global version lies in
-- @[minVersion, maxVersion)@ to @handler@, in global-version order.
readAllEvents_ :: Database -> Int -> Int -> (ByteString -> IO ()) -> IO ()
readAllEvents_ database minVersion maxVersion handler =
    -- The fold carries no state, so the accumulator is simply unit
    -- (previously 'undefined', which leaked out when no rows matched).
    querySql database selectAllEventsSql params () $ \columns _ ->
        case columns of
            [SQLInteger _, SQLBlob eventData] -> handler eventData
            _ -> badQueryResult [show minVersion, show maxVersion] columns
  where
    params = [toSQLData minVersion, toSQLData maxVersion]

-- | Store the snapshot @a@ at version @v@ for @guid@, replacing any
-- existing snapshot row for that GUID.
writeSnapshot_ :: Database -> GUID a -> (Int, ByteString) -> IO ()
writeSnapshot_ database guid (v, a) =
    execSql database writeSnapshotSql
        [ toSQLData guid
        , toSQLData a
        , toSQLData v ]

-- | Fetch the stored snapshot (version, data) for @guid@, or 'Nothing' if
-- there is none. If several rows came back, the last one would win.
getLatestSnapshot_ :: Database -> GUID a -> IO (Maybe (Int, ByteString))
getLatestSnapshot_ database guid =
    querySql database selectSnapshotSql [toSQLData guid] Nothing $ \columns _ ->
        case columns of
            [SQLBlob aggData, SQLInteger version] ->
                return $ Just (fromIntegral version, aggData)
            _ -> badQueryResult [show guid] columns

-- | Run @action@ between BEGIN and COMMIT. If the action or the commit
-- throws, a ROLLBACK is attempted and the original exception is rethrown.
withTransaction_ :: Database -> IO a -> IO a
withTransaction_ database action = do
    beginTransaction database
    runAction `onException` tryRollback
  where
    runAction = do
        r <- action
        commitTransaction database
        return r
    tryRollback =
        -- Try rollback while discarding exception; we want to preserve
        -- original exception.
        catch (rollbackTransaction database) (\(_ :: SomeException) -> return ())

-- | Open an SQLite3-based event store using the named SQLite database file.
-- The database file is created if it does not exist.
openSqliteEventStore :: String -> IO EventStore
openSqliteEventStore databaseFileName = do
    -- Create the database.
    database <- SQL.open databaseFileName
    -- Set up tables. If this fails, the caller never receives the handle,
    -- so close it here before propagating the error.
    createTables database `onException` SQL.close database
    -- Return event store.
    return $ EventStore
        { storeEvents = storeEvents_ database
        , retrieveEvents = retrieveEvents_ database
        , readAllEvents = readAllEvents_ database
        , writeSnapshot = writeSnapshot_ database
        , getLatestSnapshot = getLatestSnapshot_ database
        , withTransaction = withTransaction_ database
        , closeEventStore = SQL.close database
        }
  where
    createTables db = do
        execSql db createEventsSql []
        execSql db createAggregateVersionsSql []
        execSql db createSnapshotSql []