{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} module Database.CQRS.ReadModel.AggregateStore ( AggregateStore , makeAggregateStore , Response(..) ) where import Control.Monad.Trans (MonadIO(..)) import Data.Hashable (Hashable) import qualified Control.Monad.Except as Exc import qualified Control.Concurrent.STM as STM import qualified Data.HashPSQ as HashPSQ import qualified Data.Time as T import qualified Database.CQRS as CQRS data Response eventId aggregate = Response { lastEventId :: Maybe eventId , aggregate :: aggregate , eventCount :: Int -- ^ Number of events processed in this fetch. , totalEventCount :: Int -- ^ Total number of events making this aggregate. } data AggregateStore streamFamily aggregate = AggregateStore { streamFamily :: streamFamily , aggregator :: CQRS.Aggregator (CQRS.EventWithContext' (CQRS.StreamType streamFamily)) aggregate , initialAggregate :: CQRS.StreamIdentifier streamFamily -> aggregate , cache :: Cache (CQRS.StreamIdentifier streamFamily) (CQRS.EventIdentifier (CQRS.StreamType streamFamily)) aggregate , lagTolerance :: T.NominalDiffTime } makeAggregateStore :: MonadIO m => streamFamily -> CQRS.Aggregator (CQRS.EventWithContext' (CQRS.StreamType streamFamily)) aggregate -> (CQRS.StreamIdentifier streamFamily -> aggregate) -> T.NominalDiffTime -- ^ Lag tolerance. -> Int -- ^ Maximum number of elements in the cache. -> m (AggregateStore streamFamily aggregate) makeAggregateStore streamFamily aggregator initialAggregate lagTolerance maxSize = do cache <- liftIO . STM.atomically $ do cachedValues <- STM.newTVar HashPSQ.empty size <- STM.newTVar 0 pure Cache{..} pure AggregateStore{..} instance ( CQRS.StreamFamily m streamFamily , CQRS.Stream m (CQRS.StreamType streamFamily) , Exc.MonadError CQRS.Error m , Hashable (CQRS.StreamIdentifier streamFamily) , MonadIO m , Ord (CQRS.EventIdentifier (CQRS.StreamType streamFamily)) , Ord (CQRS.StreamIdentifier streamFamily) , Show (CQRS.EventIdentifier (CQRS.StreamType streamFamily)) ) => CQRS.ReadModel m (AggregateStore streamFamily aggregate) where type ReadModelQuery (AggregateStore streamFamily aggregate) = CQRS.StreamIdentifier streamFamily type ReadModelResponse (AggregateStore streamFamily aggregate) = Response (CQRS.EventIdentifier (CQRS.StreamType streamFamily)) aggregate query = aggregateStoreQuery aggregateStoreQuery :: forall m streamFamily aggregate. ( CQRS.StreamFamily m streamFamily , CQRS.Stream m (CQRS.StreamType streamFamily) , Exc.MonadError CQRS.Error m , Hashable (CQRS.StreamIdentifier streamFamily) , MonadIO m , Ord (CQRS.EventIdentifier (CQRS.StreamType streamFamily)) , Ord (CQRS.StreamIdentifier streamFamily) , Show (CQRS.EventIdentifier (CQRS.StreamType streamFamily)) ) => AggregateStore streamFamily aggregate -> CQRS.StreamIdentifier streamFamily -> m (Response (CQRS.EventIdentifier (CQRS.StreamType streamFamily)) aggregate) aggregateStoreQuery AggregateStore{..} streamId = do hpsq <- liftIO . STM.atomically . STM.readTVar . cachedValues $ cache now <- liftIO T.getCurrentTime case HashPSQ.lookup streamId hpsq of Just (lastUpToDateTime, item@CacheItem{..}) -> do if now < T.addUTCTime lagTolerance lastUpToDateTime then pure Response { lastEventId = Just cachedLastEventId , aggregate = cachedAggregate , eventCount = 0 , totalEventCount = cachedEventCount } else getAggregate now (Just item) Nothing -> getAggregate now Nothing where getAggregate :: T.UTCTime -> Maybe (CacheItem (CQRS.EventIdentifier (CQRS.StreamType streamFamily)) aggregate) -> m (Response (CQRS.EventIdentifier (CQRS.StreamType streamFamily)) aggregate) getAggregate now mPrevious = do let (initAggregate, bounds, eventCount) = case mPrevious of Just CacheItem{..} -> ( cachedAggregate , CQRS.afterEvent cachedLastEventId , cachedEventCount ) Nothing -> (initialAggregate streamId, mempty, 0) stream <- CQRS.getStream streamFamily streamId (aggregate, mEventId, processedEventCount) <- CQRS.runAggregator aggregator stream bounds initAggregate let totalEventCount = eventCount + processedEventCount mkCacheItem lastEventId = CacheItem { cachedAggregate = aggregate , cachedLastEventId = lastEventId , cachedEventCount = totalEventCount } lastEventId <- liftIO $ case (mEventId, mPrevious) of (Nothing, Nothing) -> pure Nothing (Nothing, Just CacheItem{..}) -> do addValueToCache cache streamId now (mkCacheItem cachedLastEventId) pure $ Just cachedLastEventId (Just lastEventId, _) -> do addValueToCache cache streamId now (mkCacheItem lastEventId) pure $ Just lastEventId pure Response { lastEventId , aggregate , eventCount = processedEventCount , totalEventCount } data CacheItem eventId aggregate = CacheItem { cachedAggregate :: aggregate , cachedLastEventId :: eventId , cachedEventCount :: Int } data Cache streamId eventId aggregate = Cache { cachedValues :: STM.TVar (HashPSQ.HashPSQ streamId T.UTCTime (CacheItem eventId aggregate)) , size :: STM.TVar Int , maxSize :: Int } addValueToCache :: ( Hashable streamId , Ord eventId , Ord streamId ) => Cache streamId eventId aggregate -> streamId -> T.UTCTime -> CacheItem eventId aggregate -> IO () addValueToCache Cache{..} streamId now item = STM.atomically $ do hpsq <- STM.readTVar cachedValues currentSize <- STM.readTVar size let (newSize, hpsq') = (\f -> HashPSQ.alter f streamId hpsq) $ \case Nothing -> (currentSize + 1, Just (now, item)) Just current@(_, currentItem) | cachedLastEventId currentItem > cachedLastEventId item -> (currentSize, Just current) | otherwise -> (currentSize, Just (now, item)) (newSize', hpsq'') | newSize > maxSize = (newSize - 1, HashPSQ.deleteMin hpsq') | otherwise = (newSize, hpsq') STM.writeTVar size newSize' STM.writeTVar cachedValues hpsq''