{-# LANGUAGE ScopedTypeVariables #-}
{-| Implementation of an SQLite3-based event store. -}
module Data.CQRS.EventStore.Backend.Sqlite3
    ( openSqliteEventStore
    ) where

import           Control.Exception (catch, bracket, finally, onException, SomeException)
import           Control.Monad (when, forM_, liftM)
import           Data.ByteString (ByteString)
import           Data.CQRS.EventStore.Backend (EventStoreBackend(..))
import           Data.CQRS.GUID (GUID)
import           Data.CQRS.Serialize (decode')
import           Data.Enumerator (Enumerator, Iteratee(..), tryIO, continue, (>>==), ($=), run_, Stream(..))
import qualified Data.Enumerator.List as EL
import           Data.Enumerator.Internal (checkContinue0)
import           Data.Serialize (encode)
import qualified Database.SQLite3 as SQL
import           Database.SQLite3 (Database, Statement, SQLData(..), StepResult(..))
import           Prelude hiding (catch)

-- | Convenience class for converting values to 'SQLData' so they can
-- be bound as statement parameters.
class ToSQLData a where
  toSQLData :: a -> SQLData

-- | GUIDs are stored as blobs holding their 'encode'd form.
instance ToSQLData (GUID a) where
  toSQLData = SQLBlob . encode

-- | Ints are stored as SQL integers.
instance ToSQLData Int where
  toSQLData = SQLInteger . fromIntegral

-- | Byte strings are stored verbatim as blobs.
instance ToSQLData ByteString where
  toSQLData = SQLBlob

-- SQL

-- | Creates the @events@ table, keyed on @(guid, version)@.
createEventsSql :: String
createEventsSql = "CREATE TABLE IF NOT EXISTS events ( guid BLOB , ev_data BLOB , version INTEGER , gversion INTEGER, PRIMARY KEY (guid, version) );"

-- | Selects the events of one aggregate with a version strictly
-- greater than the given one, in ascending version order.
selectEventsSql :: String
selectEventsSql = "SELECT version, ev_data FROM events WHERE guid = ? AND version > ? ORDER BY version ASC;"

-- | Selects all events whose global version is at least the given
-- one, in ascending global version order.
enumerateAllEventsSql :: String
enumerateAllEventsSql = "SELECT gversion, guid, version, ev_data FROM events WHERE gversion >= ? ORDER BY gversion ASC;"

-- | Inserts a single event row.
insertEventSql :: String
insertEventSql = "INSERT INTO events ( guid, version, ev_data, gversion ) VALUES (?, ?, ?, ?);"

-- | Creates the @versions@ table holding the de-normalized current
-- version of each aggregate.
createAggregateVersionsSql :: String
createAggregateVersionsSql = "CREATE TABLE IF NOT EXISTS versions ( guid BLOB PRIMARY KEY , version INTEGER );"

-- | Looks up the current version of one aggregate.
getCurrentVersionSql :: String
getCurrentVersionSql = "SELECT version FROM versions WHERE guid = ?;"

-- | Selects the highest global version in the store, or 0 when the
-- @events@ table is empty.
getLatestVersionSql :: String
getLatestVersionSql = "SELECT COALESCE(MAX(gversion), 0) FROM events;"

-- | Inserts or overwrites the current version of one aggregate.
updateCurrentVersionSql :: String
updateCurrentVersionSql = "INSERT OR REPLACE INTO versions ( guid, version ) VALUES (?,?);"

-- | Creates the @snapshots@ table (one snapshot per aggregate).
createSnapshotSql :: String
createSnapshotSql = "CREATE TABLE IF NOT EXISTS snapshots ( guid BLOB PRIMARY KEY , data BLOB , version INTEGER );"

-- | Inserts or overwrites the snapshot of one aggregate.
writeSnapshotSql :: String
writeSnapshotSql = "INSERT OR REPLACE INTO snapshots ( guid , data, version ) VALUES ( ?, ?, ? );"

-- | Selects the snapshot of one aggregate.
selectSnapshotSql :: String
selectSnapshotSql = "SELECT data, version FROM snapshots WHERE guid = ?;"

-- | Begin a transaction on the given database.
beginTransaction :: Database -> IO ()
beginTransaction database = execSql database "BEGIN TRANSACTION;" []

-- | Commit the current transaction on the given database.
commitTransaction :: Database -> IO ()
commitTransaction database = execSql database "COMMIT TRANSACTION;" []

-- | Roll back the current transaction on the given database.
rollbackTransaction :: Database -> IO ()
rollbackTransaction database = execSql database "ROLLBACK TRANSACTION;" []

-- | Prepare a statement, bind the given parameters and run the action
-- on it. The statement is finalized when the action returns or
-- throws.
withSqlStatement :: Database -> String -> [SQLData] -> (Statement -> IO a) -> IO a
withSqlStatement database sql parameters action =
  bracket (SQL.prepare database sql) SQL.finalize $ \statement -> do
    SQL.bind statement parameters
    action statement

-- | Execute a statement for its side effect only. The statement is
-- stepped exactly once and the step result is discarded.
execSql :: Database -> String -> [SQLData] -> IO ()
execSql database sql parameters =
  withSqlStatement database sql parameters $ \stmt -> do
    _ <- SQL.step stmt
    return ()

-- | Enumerate the remaining rows of an already-bound statement, one
-- row (list of column values) per chunk. Enumeration stops when the
-- statement reports 'Done' or when the consumer stops asking for
-- input.
enumQueryResults :: Statement -> Enumerator [SQLData] IO b
enumQueryResults stmt = checkContinue0 $ \loop k -> do
  nextResult <- tryIO $ SQL.step stmt
  case nextResult of
    Done -> continue k
    Row -> do
      cols <- tryIO $ SQL.columns stmt
      k (Chunks [cols]) >>== loop

-- | Prepare a query, bind its parameters and enumerate its result
-- rows. The statement is finalized once enumeration has finished,
-- whether normally or through an exception.
enumQueryResult :: Database -> String -> [SQLData] -> Enumerator [SQLData] IO b
enumQueryResult database sql parameters step = do
  stmt <- tryIO $ SQL.prepare database sql
  Iteratee $ finally
    (do SQL.bind stmt parameters
        runIteratee $ enumQueryResults stmt step)
    (SQL.finalize stmt)

-- | Build the error message used when a result row does not have the
-- expected column shape.
badQueryResultMsg :: [String] -> [SQLData] -> String
badQueryResultMsg params columns =
  concat [ "Invalid query result shape. Params: ", show params
         , ". Result columns: ", show columns ]

-- | Fail (via 'fail' in 'IO') with a message naming the expected and
-- the observed version.
versionConflict :: (Show a, Show b) => a -> b -> IO c
versionConflict ov cv =
  fail $ concat [ "Version conflict detected (expected ", show ov
                , ", saw ", show cv, ")" ]

-- | Append events to an aggregate.
--
-- The aggregate's recorded current version must equal
-- @originatingVersion@ (an aggregate without a @versions@ row counts
-- as version 0); otherwise the call fails with a version conflict.
-- Events are numbered consecutively from @originatingVersion + 1@ and
-- each is stored with the global version it is paired with.
--
-- This function does not open a transaction itself.
-- NOTE(review): presumably callers wrap it in the backend's
-- transaction runner -- confirm.
storeEvents :: forall a. Database -> GUID a -> Int -> [(ByteString,Int)] -> IO ()
storeEvents database guid originatingVersion events = do
  -- Column unpacking.
  let unpackColumns [ SQLInteger v ] = v
      unpackColumns columns = error $ badQueryResultMsg [show guid] columns
  -- Get the current version number of the aggregate.
  curVer <- run_ $ EL.fold (\x -> max x . unpackColumns) 0 >>==
              (enumQueryResult database getCurrentVersionSql [toSQLData guid])
  -- Sanity check current version number.
  when (fromIntegral curVer /= originatingVersion) $
    versionConflict originatingVersion curVer
  -- Update de-normalized version number.
  execSql database updateCurrentVersionSql
    [ toSQLData guid
    , toSQLData $ originatingVersion + length events
    ]
  -- Store the supplied events.
  forM_ (zip [1 + originatingVersion..] events) $ \(v,(e,gv)) -> do
    execSql database insertEventSql
      [ toSQLData guid
      , toSQLData v
      , toSQLData e
      , toSQLData gv
      ]

-- | Retrieve the events of an aggregate with a version strictly
-- greater than @v0@, in ascending version order. The first component
-- of the result is the highest version seen, or @v0@ itself when
-- there are no newer events.
retrieveEvents :: Database -> GUID a -> Int -> IO (Int,[ByteString])
retrieveEvents database guid v0 = do
  -- Unpack the columns into tuples.
  let unpackColumns [SQLInteger version, SQLBlob eventData] = (version, eventData)
      unpackColumns columns = error $ badQueryResultMsg [show guid, show v0] columns
  -- Find events with version numbers.
  results <- run_ $ EL.consume >>==
               (enumQueryResult database selectEventsSql [toSQLData guid, toSQLData v0] $=
                  (EL.map unpackColumns))
  -- Find the max version number. Seeding the list with v0 keeps
  -- 'maximum' total when no rows were returned.
  let maxVersion = maximum $ (:) (fromIntegral v0) $ map fst results
  return (fromIntegral maxVersion, map snd results)

-- | Enumerate all events with a global version of at least
-- @minVersion@, in ascending global version order. Each element is
-- @(global version, (aggregate GUID, aggregate version, event data))@.
enumerateAllEvents :: forall a b. Database -> Int -> Enumerator (Int,(GUID a, Int, ByteString)) IO b
enumerateAllEvents database minVersion =
  enumQueryResult database enumerateAllEventsSql [toSQLData minVersion] $=
    EL.map unpackColumns
  where
    unpackColumns columns =
      case columns of
        [ SQLInteger gv, SQLBlob g, SQLInteger v, SQLBlob ed ] ->
          ( fromIntegral gv, (decode' g, fromIntegral v, ed) )
        _ ->
          error $ badQueryResultMsg [show minVersion] columns

-- | Store a snapshot @(version, data)@ for an aggregate, replacing
-- any snapshot previously stored for the same GUID.
writeSnapshot :: Database -> GUID a -> (Int, ByteString) -> IO ()
writeSnapshot database guid (v,a) = do
  execSql database writeSnapshotSql
    [ toSQLData guid
    , toSQLData a
    , toSQLData v
    ]

-- | Fetch the stored snapshot @(version, data)@ of an aggregate, or
-- 'Nothing' when no snapshot row exists for the GUID.
getLatestSnapshot :: Database -> GUID a -> IO (Maybe (Int, ByteString))
getLatestSnapshot database guid = do
  -- Unpack columns from result.
  let unpackColumns :: [SQLData] -> Maybe (Int,ByteString)
      unpackColumns [SQLBlob a, SQLInteger v] = Just (fromIntegral v, a)
      unpackColumns columns = error $ badQueryResultMsg [show guid] columns
  -- Run the query. The fold step receives the accumulator first and
  -- the row second, so it must return the row; returning the
  -- accumulator would leave the 'Nothing' seed in place and hide
  -- every stored snapshot.
  run_ $ EL.fold (\_ row -> row) Nothing >>==
    (enumQueryResult database selectSnapshotSql [toSQLData guid] $=
       (EL.map unpackColumns))

-- | Get the highest global version in the store (0 for an empty
-- store).
getLatestVersion :: Database -> IO Int
getLatestVersion database = do
  -- Unpack columns from result.
  let unpackColumns :: [SQLData] -> Int
      unpackColumns [SQLInteger v] = fromIntegral v
      unpackColumns columns = error $ badQueryResultMsg [] columns
  -- Run the query. The query is a bare aggregate (no GROUP BY), which
  -- produces exactly one row, so taking the head of the consumed rows
  -- is safe here.
  liftM head $ run_ $ EL.consume >>==
    (enumQueryResult database getLatestVersionSql [] $=
       (EL.map unpackColumns))

-- | Run an action inside a transaction. The transaction is committed
-- when the action returns; if the action or the commit throws, a
-- rollback is attempted and the original exception is re-thrown.
withTransaction :: forall a . Database -> IO a -> IO a
withTransaction database action = do
  beginTransaction database
  onException runAction tryRollback
  where
    runAction = do
      r <- action
      commitTransaction database
      return r
    tryRollback =
      -- Try rollback while discarding exception; we want to preserve
      -- original exception.
      catch (rollbackTransaction database) (\(_::SomeException) -> return ())

-- | Open an SQLite3-based event store using the named SQLite database file.
-- The database file is created if it does not exist.
--
-- If creating the tables fails, the database handle is closed again
-- before the exception propagates.
openSqliteEventStore :: String -> IO EventStoreBackend
openSqliteEventStore databaseFileName = do
  -- Create the database.
  database <- SQL.open databaseFileName
  -- Set up tables. Nothing else owns the handle yet, so release it
  -- here if any of the statements throws.
  let createTables =
        forM_ [ createEventsSql
              , createAggregateVersionsSql
              , createSnapshotSql
              ] $ \sql -> execSql database sql []
  createTables `onException` SQL.close database
  -- Return event store.
  return $ EventStoreBackend
    { esbStoreEvents = storeEvents database
    , esbRetrieveEvents = retrieveEvents database
    , esbEnumerateAllEvents = enumerateAllEvents database
    , esbWriteSnapshot = writeSnapshot database
    , esbGetLatestSnapshot = getLatestSnapshot database
    , esbGetLatestVersion = getLatestVersion database
    , esbWithTransaction = withTransaction database
    , esbCloseEventStoreBackend = SQL.close database
    }