{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

-- | Monad transformers that accumulate a monoidal buffer in a mutable
-- reference, so that the buffer's content is still available after the
-- wrapped computation throws an exception.
module SafeBuffer
  ( -- * SafeBufferMonad
    SafeBufferMonad(..)
    -- * SafeBufferT
  , SafeBufferT(..)
  , runBuffer
  , tryRunBuffer
  , execBuffer
    -- * SafeBufferConcurrentT
  , SafeBufferConcurrentT(..)
  , runBufferConcurrently
  , tryRunBufferConcurrently
  , execBufferConcurrently
  ) where

import Control.Applicative
import Control.Concurrent.STM
import Control.Exception.Safe
import Control.Monad.Cont
import Control.Monad.Except
import Control.Monad.Fail
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Control.Monad.Zip
import Data.IORef

-- | Operations on a buffer of type @s@ carried by the monad @m@.
-- The functional dependency means the monad determines the buffer type.
class Monad m => SafeBufferMonad s m | m -> s where
  -- | Retrieves the buffer's current content.
  readBuffer :: m s

  -- | Appends a message to the buffer.
  writeBuffer :: s -> m ()

  -- | Retrieves the buffer's current content before clearing it.
  clearBuffer :: m s

  -- | Applies a given function to the buffer's content.
  modifyBuffer :: (s -> s) -> m ()

--------------------------------------------------------------------------------
-- SafeBufferT
--------------------------------------------------------------------------------

-- | A buffer stored in an 'IORef' that is threaded through the computation
-- with 'ReaderT'.
newtype SafeBufferT s m a =
  SafeBufferT { runBufferT :: ReaderT (IORef s) m a }
  deriving
    ( Functor
    , Applicative
    , Alternative
    , Monad
    , MonadTrans
    , MonadIO
    , MonadReader (IORef s)
    , MonadWriter r
      -- The state type is deliberately a fresh variable: it belongs to the
      -- base monad and is unrelated to the buffer type @s@.
    , MonadState st
    , MonadZip
    , MonadThrow
    , MonadCatch
    , MonadMask
    , MonadError e
    , MonadFail
    , MonadPlus
    , MonadCont
    , MonadFix
    )

instance (MonadIO m, Monoid s) => SafeBufferMonad s (SafeBufferT s m) where
  readBuffer :: SafeBufferT s m s
  readBuffer =
    SafeBufferT $ ReaderT $ \ref ->
      liftIO $ readIORef ref

  -- New content goes on the right-hand side of the existing buffer.
  writeBuffer :: s -> SafeBufferT s m ()
  writeBuffer msg =
    modifyBuffer (`mappend` msg)

  -- Reads and then resets the reference in two separate steps.
  clearBuffer :: SafeBufferT s m s
  clearBuffer =
    SafeBufferT $ ReaderT $ \ref ->
      liftIO $ do
        buffer <- readIORef ref
        writeIORef ref mempty
        pure buffer

  -- The strict variant avoids piling up thunks inside the reference.
  modifyBuffer :: (s -> s) -> SafeBufferT s m ()
  modifyBuffer f =
    SafeBufferT $ ReaderT $ \ref ->
      liftIO $ modifyIORef' ref f

-- | Runs a buffer and applies a given function to it.
-- If any exception occurs while running the buffer,
-- the function still runs before the exception is rethrown.
--
-- The buffer starts out as 'mempty'. The value produced by the given
-- function is discarded; only the computation's own result is returned.
runBuffer
  :: forall s m a b
   . (MonadIO m, MonadMask m, Monoid s)
  => (s -> m b)
  -> SafeBufferT s m a
  -> m a
runBuffer finalize sb =
  bracket
    (liftIO $ newIORef mempty)
    (\ref -> liftIO (readIORef ref) >>= finalize)
    (\ref -> runReaderT (runBufferT sb) ref)

-- | Runs a buffer and returns it, along with either an exception
-- or the computation's result.
--
-- It purposefully does NOT catch async exceptions.
-- To understand why, see the documentation of "Control.Exception.Safe",
-- whose 'try' is used here.
tryRunBuffer
  :: forall e s m a
   . (MonadIO m, MonadCatch m, Monoid s, Exception e)
  => SafeBufferT s m a
  -> m (s, Either e a)
tryRunBuffer sb = do
  ref <- liftIO $ newIORef mempty
  result <- try $ runReaderT (runBufferT sb) ref
  buffer <- liftIO $ readIORef ref
  pure (buffer, result)

-- | Runs a buffer and swallows exceptions of type @e@, returning whatever
-- the buffer contained when the computation finished or threw.
--
-- The exception type does not appear in the result, so it has to be
-- chosen with a type application, e.g. @execBuffer \@SomeException@.
--
-- It purposefully does NOT catch async exceptions.
-- To understand why, see the documentation of "Control.Exception.Safe".
execBuffer
  :: forall e s m a
   . (MonadIO m, MonadCatch m, Monoid s, Exception e)
  => SafeBufferT s m a
  -> m s
execBuffer sb =
  -- Same as 'tryRunBuffer', except that the outcome is thrown away.
  fst <$> tryRunBuffer @e sb

--------------------------------------------------------------------------------
-- SafeBufferConcurrentT
--------------------------------------------------------------------------------

-- | A buffer stored in a 'TVar' that is threaded through the computation
-- with 'ReaderT'.
newtype SafeBufferConcurrentT s m a =
  SafeBufferConcurrentT { runBufferConcurrentT :: ReaderT (TVar s) m a }
  deriving
    ( Functor
    , Applicative
    , Alternative
    , Monad
    , MonadTrans
    , MonadIO
    , MonadReader (TVar s)
    , MonadWriter r
      -- The state type is deliberately a fresh variable: it belongs to the
      -- base monad and is unrelated to the buffer type @s@.
    , MonadState st
    , MonadZip
    , MonadThrow
    , MonadCatch
    , MonadMask
    , MonadError e
    , MonadFail
    , MonadPlus
    , MonadCont
    , MonadFix
    )

instance (MonadIO m, Monoid s) => SafeBufferMonad s (SafeBufferConcurrentT s m) where
  readBuffer :: SafeBufferConcurrentT s m s
  readBuffer =
    SafeBufferConcurrentT $ ReaderT $ \tvar ->
      liftIO $ readTVarIO tvar

  -- New content goes on the right-hand side of the existing buffer.
  writeBuffer :: s -> SafeBufferConcurrentT s m ()
  writeBuffer msg =
    modifyBuffer (`mappend` msg)

  -- Swaps in 'mempty' and returns the previous content in one transaction.
  clearBuffer :: SafeBufferConcurrentT s m s
  clearBuffer =
    SafeBufferConcurrentT $ ReaderT $ \tvar ->
      liftIO $ atomically $ swapTVar tvar mempty

  -- The strict variant avoids piling up thunks inside the 'TVar'.
  modifyBuffer :: (s -> s) -> SafeBufferConcurrentT s m ()
  modifyBuffer f =
    SafeBufferConcurrentT $ ReaderT $ \tvar ->
      liftIO $ atomically $ modifyTVar' tvar f

-- | Runs a buffer that can be safely shared across threads and applies a
-- given function to it.
-- If an exception occurs while running the buffer,
-- the function still runs before the exception is rethrown.
--
-- The buffer starts out as 'mempty'. The value produced by the given
-- function is discarded; only the computation's own result is returned.
runBufferConcurrently
  :: forall s m a b
   . (MonadIO m, MonadMask m, Monoid s)
  => (s -> m b)
  -> SafeBufferConcurrentT s m a
  -> m a
runBufferConcurrently finalize sb =
  bracket
    (liftIO $ newTVarIO mempty)
    (\tvar -> liftIO (readTVarIO tvar) >>= finalize)
    (\tvar -> runReaderT (runBufferConcurrentT sb) tvar)

-- | Runs a buffer that can be safely shared across threads and returns it,
-- along with either an exception or the computation's result.
--
-- It purposefully does NOT catch async exceptions.
-- To understand why, see the documentation of "Control.Exception.Safe",
-- whose 'try' is used here.
tryRunBufferConcurrently
  :: forall e s m a
   . (MonadIO m, MonadCatch m, Monoid s, Exception e)
  => SafeBufferConcurrentT s m a
  -> m (s, Either e a)
tryRunBufferConcurrently sb = do
  tvar <- liftIO $ newTVarIO mempty
  result <- try $ runReaderT (runBufferConcurrentT sb) tvar
  buffer <- liftIO $ readTVarIO tvar
  pure (buffer, result)

-- | Runs a buffer that can be safely shared across threads, and swallows
-- exceptions of type @e@, returning whatever the buffer contained when the
-- computation finished or threw.
--
-- The exception type does not appear in the result, so it has to be
-- chosen with a type application, e.g.
-- @execBufferConcurrently \@SomeException@.
--
-- It purposefully does NOT catch async exceptions.
-- To understand why, see the documentation of "Control.Exception.Safe".
execBufferConcurrently
  :: forall e s m a
   . (MonadIO m, MonadCatch m, Monoid s, Exception e)
  => SafeBufferConcurrentT s m a
  -> m s
execBufferConcurrently sb =
  -- Same as 'tryRunBufferConcurrently', except that the outcome is thrown away.
  fst <$> tryRunBufferConcurrently @e sb