{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} module Database.CQRS.PostgreSQL.StreamFamily ( StreamFamily(..) , makeStreamFamily ) where import Control.Concurrent import Control.Concurrent.MVar (MVar, newEmptyMVar, putMVar, takeMVar) import Control.Exception import Control.Monad (void) import Control.Monad.Trans (MonadIO(..)) import Data.List (intersperse) import Data.Proxy (Proxy(..)) import Database.PostgreSQL.Simple ((:.)(..)) import System.Mem.Weak (Weak, deRefWeak, mkWeak) import qualified Control.Concurrent.STM as STM import qualified Control.Monad.Except as Exc import qualified Data.ByteString as BS import qualified Database.PostgreSQL.Simple as PG import qualified Database.PostgreSQL.Simple.FromField as PG.From import qualified Database.PostgreSQL.Simple.FromRow as PG.From import qualified Database.PostgreSQL.Simple.Notification as PG import qualified Database.PostgreSQL.Simple.ToField as PG.To import qualified Database.PostgreSQL.Simple.ToRow as PG.To import qualified Pipes import Database.CQRS.PostgreSQL.Internal (SomeParams(..), handleError) import Database.CQRS.PostgreSQL.Stream import qualified Database.CQRS as CQRS -- | Family of event streams stored in a PostgreSQL relation. -- -- Each stream should have a unique stream identifier and event identifiers must -- be unique within a stream, but not necessarily across them. -- -- 'allNewEvents' starts a new thread which reads notifications on the given -- channel and writes them to a transactional bounded queue (a 'TBQueue') which -- is then consumed by the returned 'Producer'. The maximum size of this queue -- is hard-coded to 100. Should an exception be raised in the listening thread, -- it is thrown back by the producer. data StreamFamily streamId eventId metadata event = StreamFamily { connectionPool :: forall a. (PG.Connection -> IO a) -> IO a , relation :: PG.Query , notificationChannel :: PG.Query , parseNotification :: BS.ByteString -> Either String (streamId, eventId) , streamIdentifierColumn :: PG.Query , eventIdentifierColumn :: PG.Query , metadataColumns :: [PG.Query] , eventColumn :: PG.Query } makeStreamFamily :: (forall a. (PG.Connection -> IO a) -> IO a) -> PG.Query -> PG.Query -> (BS.ByteString -> Either String (streamId, eventId)) -> PG.Query -> PG.Query -> [PG.Query] -> PG.Query -> StreamFamily streamId eventId metadata event makeStreamFamily = StreamFamily instance ( CQRS.Event event , Exc.MonadError CQRS.Error m , MonadIO m , PG.From.FromField eventId , PG.From.FromField streamId , PG.From.FromField (CQRS.EncodingFormat event) , PG.From.FromRow metadata , PG.To.ToField eventId , PG.To.ToField streamId ) => CQRS.StreamFamily m (StreamFamily streamId eventId metadata event) where type StreamType (StreamFamily streamId eventId metadata event) = Stream eventId metadata event type StreamIdentifier (StreamFamily streamId eventId metadata event) = streamId getStream = streamFamilyGetStream allNewEvents = streamFamilyAllNewEvents latestEventIdentifiers = streamFamilyLastEventIdentifiers streamFamilyGetStream :: forall streamId eventId metadata event m. ( MonadIO m , PG.To.ToField streamId , PG.To.ToField eventId ) => StreamFamily streamId eventId metadata event -> streamId -> m (Stream eventId metadata event) streamFamilyGetStream StreamFamily{..} streamId = pure $ makeStream' connectionPool selectQuery insertQuery eventIdentifierColumn where selectQuery :: (PG.Query, PG.Only streamId) selectQuery = (selectQueryTpl, PG.Only streamId) selectQueryTpl :: PG.Query selectQueryTpl = "SELECT " <> eventIdentifierColumn <> ", " <> metadataList <> ", " <> eventColumn <> " FROM " <> relation <> " WHERE " <> streamIdentifierColumn <> " = ? ORDER BY " <> eventIdentifierColumn <> " ASC" insertQuery :: (PG.To.ToField encEvent, PG.To.ToRow metadata) => encEvent -> metadata -> CQRS.ConsistencyCheck eventId -> (PG.Query, SomeParams) insertQuery encEvent metadata cc = let baseParams = PG.Only streamId :. metadata :. PG.Only encEvent (cond, params) = case cc of CQRS.NoConsistencyCheck -> ("", SomeParams baseParams) CQRS.CheckNoEvents -> ( " WHERE NOT EXISTS (SELECT 1 FROM " <> relation <> " WHERE " <> streamIdentifierColumn <> " = ?)" , SomeParams (baseParams :. PG.Only streamId) ) CQRS.CheckLastEvent eventId -> ( " WHERE NOT EXISTS (SELECT 1 FROM " <> relation <> " WHERE " <> streamIdentifierColumn <> " = ? AND " <> eventIdentifierColumn <> " > ?)" , SomeParams (baseParams :. (streamId, eventId)) ) query = "INSERT INTO " <> relation <> "(" <> streamIdentifierColumn <> ", " <> metadataList <> ", " <> eventColumn <> ") SELECT ?, " <> metadataMarks <> ", ?" <> cond <> " RETURNING " <> eventIdentifierColumn in (query, params) metadataList :: PG.Query metadataList = mconcat . intersperse "," $ metadataColumns metadataMarks :: PG.Query metadataMarks = mconcat . intersperse "," . map (const "?") $ metadataColumns data GCKey = GCKey streamFamilyAllNewEvents :: forall streamId eventId metadata event m a. ( CQRS.Event event , Exc.MonadError CQRS.Error m , MonadIO m , PG.From.FromField eventId , PG.From.FromField streamId , PG.From.FromField (CQRS.EncodingFormat event) , PG.From.FromRow metadata , PG.To.ToField eventId , PG.To.ToField streamId ) => StreamFamily streamId eventId metadata event -> m (Pipes.Producer [ ( streamId , Either (eventId, String) (CQRS.EventWithContext eventId metadata event) ) ] m a) streamFamilyAllNewEvents StreamFamily{..} = liftIO $ do let gcKey = GCKey queue <- STM.newTBQueueIO 100 queueWeak <- mkWeak gcKey queue Nothing mvar <- startListeningThread queueWeak pure $ producer gcKey mvar queue where -- Start the listening thread and return an 'MVar' of potential error of the -- thread once it has started. startListeningThread :: Weak (STM.TBQueue (streamId, eventId)) -> IO (MVar CQRS.Error) startListeningThread queueWeak = do errorMVar <- newEmptyMVar startedMVar <- newEmptyMVar _ <- forkFinally (listen startedMVar queueWeak) $ \eErr -> do putMVar startedMVar () -- Unblock it. putMVar errorMVar $ case eErr of Left err -> CQRS.NewEventsStreamingError . show $ err Right err -> err takeMVar startedMVar -- Wait to be sure it started to listen. pure errorMVar -- Entry point of the thread. It gets a connection, starts listening for -- notifications and signals it has started to the parent thread. listen :: MVar () -> Weak (STM.TBQueue (streamId, eventId)) -> IO CQRS.Error listen startedMVar queueWeak = connectionPool $ \conn -> do eThreadRes <- Exc.runExceptT $ do eRes <- liftIO $ (Right <$> PG.execute_ conn ("LISTEN " <> notificationChannel)) `catches` [ handleError (Proxy @PG.FormatError) CQRS.NewEventsStreamingError , handleError (Proxy @PG.SqlError) CQRS.NewEventsStreamingError ] either Exc.throwError (\_ -> pure ()) eRes liftIO $ putMVar startedMVar () handleNotifications conn queueWeak case eThreadRes of Left err -> do _ <- try @SomeException $ PG.execute_ conn "UNLISTEN *" pure err Right () -> pure $ CQRS.NewEventsStreamingError "listening thread terminated" -- Once the listening thread has initialised itself, it runs this code over -- and over again. handleNotifications :: PG.Connection -> Weak (STM.TBQueue (streamId, eventId)) -> Exc.ExceptT CQRS.Error IO () handleNotifications conn queueWeak = do -- Get notifications first even if the queue might not exist anymore. -- Otherwise, this thread would *not* have a pointer to the queue only -- between the recursion call and the first line of the function, i.e. -- maybe not enough time for the garbage collector to run. notif <- liftIO $ PG.getNotification conn mQueue <- liftIO $ deRefWeak queueWeak case mQueue of Nothing -> pure () -- The queue has been garbage collected. Just queue -> do case parseNotification (PG.notificationData notif) of Left err -> Exc.throwError . CQRS.NewEventsStreamingError $ "error decoding notification: " ++ err Right pair -> liftIO . STM.atomically . STM.writeTBQueue queue $ pair handleNotifications conn queueWeak -- Producer that repeatedly checks the listening thread is still alive and -- fetch corresponding events to the notifications. producer :: GCKey -> MVar CQRS.Error -- Error from listening thread. -> STM.TBQueue (streamId, eventId) -> Pipes.Producer [ ( streamId , Either (eventId, String) (CQRS.EventWithContext eventId metadata event) ) ] m a producer gcKey mvar queue = do -- Check the listening thread is still running. mErr <- liftIO $ tryTakeMVar mvar maybe (pure ()) Exc.throwError mErr -- Get some events or none after timeout just in case the listening -- thread died in the meantime (we don't want to block forever.) events <- liftIO $ race (\tmvar -> STM.atomically $ do events <- (:) <$> STM.readTBQueue queue <*> STM.flushTBQueue queue STM.putTMVar tmvar . Right $ events) (\tmvar -> do threadDelay 1000000 -- 1 second STM.atomically . STM.putTMVar tmvar . Right $ []) fetchEvents events producer gcKey mvar queue fetchEvents :: [(streamId, eventId)] -> Pipes.Producer [ ( streamId , Either (eventId, String) (CQRS.EventWithContext eventId metadata event) ) ] m () fetchEvents [] = Pipes.yield [] fetchEvents pairs = do let pairs' = map (uncurry Pair) pairs eRows <- liftIO . connectionPool $ \conn -> (Right <$> PG.query conn fetchQuery (PG.Only (PG.In pairs'))) `catches` [ handleError (Proxy @PG.FormatError) CQRS.EventRetrievalError , handleError (Proxy @PG.QueryError) CQRS.EventRetrievalError , handleError (Proxy @PG.ResultError) CQRS.EventRetrievalError , handleError (Proxy @PG.SqlError) CQRS.EventRetrievalError ] rows <- either Exc.throwError pure eRows Pipes.yield $ map (\(PG.Only streamId :. PG.Only identifier :. metadata :. PG.Only encEvent) -> case CQRS.decodeEvent encEvent of Left err -> (streamId, Left (identifier, err)) Right event -> ( streamId , Right (CQRS.EventWithContext identifier metadata event) ) ) rows fetchQuery :: PG.Query fetchQuery = "SELECT " <> streamIdentifierColumn <> ", " <> eventIdentifierColumn <> ", " <> metadataList <> ", " <> eventColumn <> " FROM " <> relation <> " WHERE (" <> streamIdentifierColumn <> ", " <> eventIdentifierColumn <> ") IN ? ORDER BY " <> eventIdentifierColumn <> " ASC" metadataList :: PG.Query metadataList = mconcat . intersperse "," $ metadataColumns streamFamilyLastEventIdentifiers :: ( Exc.MonadError CQRS.Error m , MonadIO m , PG.From.FromField eventId , PG.From.FromField streamId ) => StreamFamily streamId eventId metadata event -> Pipes.Producer (streamId, eventId) m () streamFamilyLastEventIdentifiers StreamFamily{..} = do eIds <- liftIO . connectionPool $ \conn -> do (Right <$> PG.query_ conn query) `catches` [ handleError (Proxy @PG.FormatError) CQRS.EventRetrievalError , handleError (Proxy @PG.QueryError) CQRS.EventRetrievalError , handleError (Proxy @PG.ResultError) CQRS.EventRetrievalError , handleError (Proxy @PG.SqlError) CQRS.EventRetrievalError ] either Exc.throwError Pipes.each eIds where query :: PG.Query query = "SELECT " <> streamIdentifierColumn <> ", " <> "max(" <> eventIdentifierColumn <> ") FROM " <> relation <> " GROUP BY " <> streamIdentifierColumn -- | Run two threads concurrently and return the result of the first one to -- write to the givem 'TMVar'. If the first thread to finish does so because -- it throws an exception, the exception is rethrown in the main process. race :: (STM.TMVar (Either SomeException a) -> IO ()) -> (STM.TMVar (Either SomeException a) -> IO ()) -> IO a race f g = do tmvar <- STM.newEmptyTMVarIO (tid, tid') <- mask $ \restore -> (,) <$> run f tmvar restore <*> run g tmvar restore eRes <- STM.atomically . STM.readTMVar $ tmvar killThread tid killThread tid' either throw pure eRes where run :: (STM.TMVar (Either SomeException a) -> IO ()) -> STM.TMVar (Either SomeException a) -> (forall b. IO b -> IO b) -> IO ThreadId run h tmvar restore = forkIO $ restore (h tmvar) `catch` (void . STM.atomically . STM.tryPutTMVar tmvar . Left) data Pair a b = Pair a b instance (PG.To.ToField a, PG.To.ToField b) => PG.To.ToField (Pair a b) where toField (Pair x y) = PG.To.Many [ PG.To.Plain "ROW(" , PG.To.toField x , PG.To.Plain "," , PG.To.toField y , PG.To.Plain ")" ]