----------------------------------------------------------------------------- -- | -- Module : Control.Concurrent.AdvSTM -- Copyright : (c) Chris Kuklewicz 2006, Peter Robinson 2009 -- License : BSD3 -- -- Maintainer : Peter Robinson -- Stability : experimental -- Portability : non-portable (requires STM) -- -- Extends Control.Concurrent.STM with IO hooks -- ----------------------------------------------------------------------------- -- TODO: remove type class or add readTVar/writeTVar to type class -- Try to change type of onCommit/retryWith to MonadIO! -- module Control.Concurrent.AdvSTM( -- * Class MonadAdvSTM MonadAdvSTM( onCommit , onRetry , orElse , retry , check , alwaysSucceeds , always -- , runAtomic , catchSTM , liftAdv , newTVar , readTVar , writeTVar ) -- * Monad AdvSTM , AdvSTM , retryWith , atomically , unsafeIOToAdvSTM , debugAdvSTM , debugMode ) where import Prelude hiding (catch) import Control.Monad.AdvSTM.Def(AdvSTM(..),Env(..),TVarValue(..)) import Control.Monad.AdvSTM.Class( MonadAdvSTM(..), TVar( TVar ), onCommitLock, currentTid, valueTVar ) -- import Control.Monad.Reader(ReaderT(ReaderT),mapReaderT,runReaderT) import Control.Exception(throw,catch,SomeException,fromException,try,block,Deadlock(..)) import Control.Monad(mplus,when,liftM,ap,unless) import Control.Monad.Reader(MonadReader(..),ReaderT,runReaderT,lift,asks) import Control.Concurrent(threadDelay,forkIO,ThreadId,myThreadId) import Control.Concurrent.Chan(Chan,newChan,readChan,writeChan) import Control.Concurrent.STM.TMVar(TMVar,putTMVar,takeTMVar,newTMVar,tryTakeTMVar) -- import Control.Concurrent.STM.TChan(TChan,writeTChan) import Control.Concurrent.MVar(MVar,newEmptyMVar,takeMVar,tryTakeMVar,putMVar) import qualified Control.Concurrent.STM as S (STM,orElse,retry,catchSTM,atomically,check,always,alwaysSucceeds) import qualified Control.Concurrent.STM.TVar as STVar -- (TVar,newTVarIO,readTVar,writeTVar) import GHC.Conc(unsafeIOToSTM) import Data.IORef(newIORef,readIORef,writeIORef) 
import Data.Maybe(isJust,Maybe,fromJust)

--------------------------------------------------------------------------------

instance MonadAdvSTM AdvSTM where
  -- Composes the IO action onto the transaction's commit continuation (kept
  -- in a TVar, so it is discarded if the transaction retries). 'atomically'
  -- runs the continuation after the STM transaction has committed; actions
  -- run in the order they were registered.
  onCommit ioaction = do
    commitVar <- AdvSTM $ asks commitTVar
    commitFun <- liftAdv $ STVar.readTVar commitVar
    liftAdv $ STVar.writeTVar commitVar $ commitFun . (ioaction >>)

  -- Composes the IO action onto the retry continuation. It is stored in an
  -- MVar via 'unsafeIOToSTM' so that it survives the STM rollback;
  -- 'atomically' hands it to a helper thread when the transaction retries.
  onRetry ioaction = do
    retryVar <- AdvSTM $ asks retryMVar
    liftAdv . unsafeIOToSTM $ do
      may'retryFun <- tryTakeMVar retryVar
      let retryFun = maybe (ioaction >>) (. (ioaction >>)) may'retryFun
      putMVar retryVar $! retryFun

  orElse = mplus

  retry = liftAdv S.retry

  check = liftAdv . S.check

  alwaysSucceeds inv = unlift inv >>= liftAdv . S.alwaysSucceeds

  always inv = unlift inv >>= liftAdv . S.always

  -- runAtomic = atomically

  -- Runs the handler only for exceptions of the handler's type; any other
  -- exception is re-thrown unchanged.
  catchSTM action handler = do
    action'  <- unlift action
    handler' <- unlift1 handler
    let handler'' = \e -> case fromException e of
          Nothing -> throw e
          Just e' -> handler' e'
    liftAdv $ S.catchSTM action' handler''

  liftAdv = AdvSTM . lift

  -- | See 'STVar.newTVar'
  newTVar a = TVar `liftM` (liftAdv $ STVar.newTVar a)
                   `ap` (liftAdv $ newTMVar ())
                   `ap` (liftAdv $ STVar.newTVar Nothing)

  -- | Writes a value to a TVar. Blocks until the onCommit IO-action(s) are
  -- complete. See 'onCommit' for details.
  -- If the TVar is locked by the onCommit phase of a transaction on this very
  -- thread, waiting could never succeed, so 'Deadlock' is thrown instead.
  writeTVar tvar a = do
    commitLock <- liftAdv $ tryTakeTMVar (onCommitLock tvar)
    -- Get ThreadID of current transaction:
    curTid <- AdvSTM $ asks transThreadId
    case commitLock of
      Nothing -> do
        storedTid <- liftAdv $ STVar.readTVar (currentTid tvar)
        if storedTid == Just curTid
          then throw Deadlock -- No transaction during onCommit-phase!
          else retry
      Just _ -> do
        lsTVar <- AdvSTM $ asks listeners
        ls     <- liftAdv $ STVar.readTVar lsTVar
        -- Register each TVar at most once per transaction. The check is done
        -- against this transaction's own listener list. Nothing in this module
        -- resets 'currentTid' after a commit, so comparing thread ids would
        -- make a later transaction on the same thread skip the registration.
        -- That TVar would then be neither locked during the onCommit phase
        -- nor rolled back.
        unless (any ((== onCommitLock tvar) . fst) ls) $ do
          -- First write access, update the ThreadId:
          liftAdv $ STVar.writeTVar (currentTid tvar) $ Just curTid
          -- Remember the old value for rollback and add this TVar to the
          -- onCommit-listener list:
          oldval <- liftAdv $ STVar.readTVar (valueTVar tvar)
          liftAdv $ STVar.writeTVar lsTVar $
            (onCommitLock tvar, TVarValue (valueTVar tvar, oldval)) : ls
        liftAdv $ STVar.writeTVar (valueTVar tvar) a
        liftAdv $ putTMVar (onCommitLock tvar) ()

  -- | Reads a value from a TVar. Blocks until the IO onCommit action(s) of
  -- the corresponding transaction are complete.
  -- See 'onCommit' for a more detailed description of this behaviour.
  -- Throws 'Deadlock' if the lock is held by this thread's own onCommit phase.
  readTVar tvar = do
    commitLock <- liftAdv $ tryTakeTMVar (onCommitLock tvar)
    case commitLock of
      Nothing -> do
        storedTid <- liftAdv $ STVar.readTVar (currentTid tvar)
        curTid    <- AdvSTM $ asks transThreadId
        if storedTid == Just curTid
          then throw Deadlock
          else retry
      Just _ -> do
        result <- liftAdv $ STVar.readTVar $ valueTVar tvar
        liftAdv $ putTMVar (onCommitLock tvar) ()
        return result

--------------------------------------------------------------------------------
-- | Adds the IO action to the retry queue and then retries the transaction
retryWith :: (Monad m, MonadAdvSTM m) => IO () -> m b
retryWith io = onRetry io >> retry

--------------------------------------------------------------------------------
-- | See 'S.atomically'
--
-- Runs the transaction, then runs the registered onCommit actions while every
-- TVar written by the transaction is locked. If an onCommit action throws,
-- the written TVars are restored to their old values and unlocked, and the
-- exception is re-thrown. Retry actions run on a helper thread, and
-- 'atomically' waits for that thread to finish before returning.
atomically :: AdvSTM a -> IO a
atomically (AdvSTM action) = do
  let debugging = False -- Switching this to True occasionally causes deadlocks (b/c unsafeIO)!
  debugTVar <- STVar.newTVarIO debugging
  tid <- myThreadId -- ThreadId for controlling TVar access
  debug "***************************************************************" 0 debugging
  debug (show (tid,"Starting transaction...")) 1000000 debugging
  -- Building the Reader monad environment
  commitVar <- STVar.newTVarIO id -- IO actions to be run if the transaction commits
  retryVar <- newEmptyMVar        -- full if there's something todo before retrying
  commitListeners <- STVar.newTVarIO []
  let env = Env { commitTVar    = commitVar :: STVar.TVar (IO () -> IO ())
                , retryMVar     = retryVar :: MVar (IO () -> IO ())
                , transThreadId = tid :: ThreadId
                , listeners     = commitListeners :: STVar.TVar [(TMVar (),TVarValue)]
                , debugModeVar  = debugTVar
                }
  -- Setting up communication for the retry-helper thread:
  retryChanVar <- newIORef (Nothing :: Maybe (Chan (Maybe (IO ()))))
  retryEndVar <- newEmptyMVar -- Termination signal for the retry-helper thread
  -- Runs when the user's action retries: hand any queued retry IO to the
  -- helper thread (spawning it on first use), then retry for real.
  let check'retry = do
        unsafeIOToSTM $ do
          may'todo <- tryTakeMVar retryVar
          case may'todo of
            Nothing -> return ()
            Just retryFun -> do
              may'chan <- readIORef retryChanVar
              chan <- case may'chan of
                Nothing -> do
                  newCh <- newChan
                  writeIORef retryChanVar (Just newCh)
                  _ <- spawn'retry'thread (readChan newCh) (putMVar retryEndVar ())
                  return newCh
                Just existing -> return existing
              writeChan chan $ Just (retryFun (return ()))
        debugSTM (show (tid,"Calling retry now...")) 0 debugging
        debugSTM (show (tid,"*********************")) 1000000 debugging
        S.retry
  let wait'retry'finished = do
        may'chan <- readIORef retryChanVar
        case may'chan of
          Nothing -> return ()
          Just chan -> do
            -- Write an "EOF" on the channel and block until the helper thread is done:
            writeChan chan Nothing
            takeMVar retryEndVar
  let wrappedAction = runReaderT action env `S.orElse` check'retry
  -- Block interruptions from other threads for the rest of 'atomically'
  block $ do
    result <- S.atomically $ do
      debugSTM (show (tid,"wrappedAction: Running S.STM action...")) 0 debugging
      res <- wrappedAction
      debugSTM (show (tid,"wrappedAction: Finished S.STM action...")) 0 debugging
      ls <- STVar.readTVar commitListeners
      -- Notify the TPVars that we're entering onCommit mode:
      debugSTM (show (tid,"wrappedAction: Notifying TVars that we're about to run onCommit...")) 0 debugging
      mapM_ (\(l,_) -> takeTMVar l) ls -- tell the TVar that we're going into onCommit mode
      return res
    let rollbackOnCommit = do
          debug "rollbackOnCommit: rolling back modified TVar values!" 0 debugging
          wait'retry'finished
          S.atomically $ do
            ls <- STVar.readTVar commitListeners
            mapM_ (\(l, TVarValue (oldValTVar, oldVal)) -> do
                      STVar.writeTVar oldValTVar oldVal -- smthg went wrong, restore the old value
                      putTMVar l ()                      -- ...and unblock the TVar
                  ) ls
    -- Now try to run the onCommit IO actions:
    commitFun <- S.atomically $ STVar.readTVar commitVar
    commitFun (return ()) `catch` (\(e::SomeException) -> rollbackOnCommit >> throw e)
    -- Everything's ok, so notify and unblock the TVars:
    debug (show (tid,"*********************")) 1000000 debugging
    debug (show (tid,"Notifying TPVars that we're done:")) 0 debugging
    S.atomically $ do
      ls <- STVar.readTVar commitListeners
      mapM_ (\(l,_) -> putTMVar l ()) ls
    wait'retry'finished
    debug (show (tid,"Transaction done; retry thread finished")) 1000000 debugging
    debug "************************************************************" 0 debugging
    return result
  where
    -- Helper thread for the retry IO-actions. It runs the queued jobs in
    -- order. It runs @atEndAction@ exactly once: either when it reads the
    -- "EOF" marker ('Nothing') or when a job throws. Signalling on failure
    -- matters because 'wait'retry'finished' blocks in 'takeMVar' until
    -- @atEndAction@ has run. A helper that died silently would leave
    -- 'atomically' stuck after the transaction had already committed.
    -- Jobs queued after a failing job are not run.
    spawn'retry'thread :: IO (Maybe (IO ())) -> IO () -> IO ThreadId
    spawn'retry'thread nextJob atEndAction = forkIO loop
      where
        loop = do
          may'job <- nextJob
          case may'job of
            Nothing  -> atEndAction
            Just job -> do
              res <- try job
              case res of
                -- Signal termination first, then re-raise so the failure is
                -- still reported by the forked thread as before.
                Left (e::SomeException) -> atEndAction >> throw e
                Right _                 -> loop

--------------------------------------------------------------------------------
-- | See 'unsafeIOToSTM'
unsafeIOToAdvSTM :: IO a -> AdvSTM a
unsafeIOToAdvSTM = liftAdv . unsafeIOToSTM

--------------------------------------------------------------------------------
-- | Switches the debug mode on or off.
-- /WARNING:/ Can lead to deadlocks!
debugMode :: Bool -> AdvSTM ()
debugMode switch =
  AdvSTM (asks debugModeVar) >>= \flagVar -> liftAdv (STVar.writeTVar flagVar switch)

-- | When debug mode is on, prints the transaction's thread id paired with the
-- message (through 'unsafeIOToSTM') and then sleeps for @delay@ microseconds.
-- When debug mode is off, it does nothing.
-- /WARNING:/ Can lead to deadlocks!
debugAdvSTM :: String -> Int -> AdvSTM ()
debugAdvSTM msg delay = do
  enabled <- AdvSTM (asks debugModeVar) >>= liftAdv . STVar.readTVar
  me      <- AdvSTM (asks transThreadId)
  liftAdv (debugSTM (show (me, msg)) delay enabled)

--------------------------------------------------------------------------------
-- The following functions are not exported as they aren't needed by
-- the enduser

-- Runs an 'AdvSTM' computation as plain STM in the given environment.
runWith :: Env -> AdvSTM t -> S.STM t
runWith env (AdvSTM m) = m `runReaderT` env

-- Captures the current environment and returns a runner bound to it.
unlifter :: AdvSTM (AdvSTM a -> S.STM a)
unlifter = liftM runWith (AdvSTM ask)

-- Converts an 'AdvSTM' action to STM under the current environment.
unlift :: AdvSTM a -> AdvSTM (S.STM a)
unlift f = liftM ($ f) unlifter

-- Like 'unlift', but for a function that returns an 'AdvSTM' action.
unlift1 :: (t -> AdvSTM a) -> AdvSTM (t -> S.STM a)
unlift1 f = liftM (. f) unlifter

-- Prints the message and sleeps for @delay@ microseconds when the flag is
-- set; otherwise does nothing.
-- WARNING: Can lead to deadlocks!
debugSTM :: String -> Int -> Bool -> S.STM ()
debugSTM msg delay debugging
  | debugging = unsafeIOToSTM (putStrLn msg >> threadDelay delay)
  | otherwise = return ()

-- IO counterpart of 'debugSTM'.
-- WARNING: Can lead to deadlocks!
debug :: String -> Int -> Bool -> IO ()
debug msg delay debugging
  | debugging = putStrLn msg >> threadDelay delay
  | otherwise = return ()