{-# LANGUAGE ScopedTypeVariables, FlexibleInstances, MultiParamTypeClasses, UndecidableInstances #-} ----------------------------------------------------------------------------- -- | -- Module : Control.Concurrent.AdvSTM -- Copyright : (c) Chris Kuklewicz 2006, Peter Robinson 2009 -- License : BSD3 -- -- Maintainer : Peter Robinson -- Stability : experimental -- Portability : non-portable (requires STM) -- -- Extends Control.Concurrent.STM with IO hooks -- ----------------------------------------------------------------------------- module Control.Concurrent.AdvSTM( -- * Class MonadAdvSTM MonadAdvSTM( onCommitWith , onCommit , unsafeRetryWith , orElse , retry , check , alwaysSucceeds , always , catchSTM , liftAdv , newTVar , readTVarAsync , writeTVarAsync , readTVar , writeTVar ) -- * Monad AdvSTM , AdvSTM , atomically , unsafeIOToSTM , handleSTM , debugAdvSTM , debugMode ) where import Prelude hiding (catch) import Control.Monad.AdvSTM.Def(AdvSTM(..),Env(..),TVarValue(..)) import Control.Monad.AdvSTM.Class( MonadAdvSTM(..), handleSTM, TVar( TVar ), onCommitLock, currentTid, valueTVar ) -- import Control.Monad.Reader(ReaderT(ReaderT),mapReaderT,runReaderT) import Control.Exception(Exception,throw,catch,SomeException,fromException,try,block,Deadlock(..),finally) import Control.Monad(mplus,when,liftM,ap,unless) import Control.Monad.Error(MonadError(..)) import Control.Concurrent(threadDelay,forkIO,ThreadId,myThreadId,throwTo) import Control.Concurrent.Chan(Chan,newChan,readChan,writeChan) import Control.Concurrent.STM.TMVar(TMVar,putTMVar,takeTMVar,newTMVar,tryTakeTMVar) -- import Control.Concurrent.STM.TChan(TChan,writeTChan) import Control.Concurrent.MVar(MVar,newMVar,takeMVar,tryTakeMVar,putMVar,tryPutMVar,swapMVar) import qualified Control.Concurrent.STM as S (STM,orElse,retry,catchSTM,atomically,check,always,alwaysSucceeds) import qualified Control.Concurrent.STM.TVar as STVar -- (TVar,newTVarIO,readTVar,writeTVar) import qualified GHC.Conc as G -- 
-- (GHC.Conc, imported above as G, supplies unsafeIOToSTM.)
import Data.IORef(newIORef,readIORef,writeIORef)
import Data.Maybe(isJust,Maybe,fromJust)
import Control.Monad.Reader(MonadReader(..),ReaderT(ReaderT),runReaderT,lift,asks)

--------------------------------------------------------------------------------

instance MonadAdvSTM (AdvSTM) where
  -- Replaces the closure that 'atomically' applies to the (ordered) list of
  -- onCommit actions after the STM part commits. 'atomically' initialises it
  -- to 'sequence_'.
  onCommitWith ioclosure = do
    commitCl <- AdvSTM $ asks commitClosure
    liftAdv $ STVar.writeTVar commitCl ioclosure

  -- Queues an IO action to run after the transaction commits. Actions are
  -- prepended here; 'atomically' reverses the list before running it, so
  -- they run in registration order.
  onCommit ioaction = do
    commitVar <- AdvSTM $ asks commitTVar
    commitActions <- liftAdv $ STVar.readTVar commitVar
    liftAdv $ STVar.writeTVar commitVar $ ioaction : commitActions

  -- unsafeOnRetry ioaction = do
  --   retryVar <- AdvSTM $ asks retryMVar
  --   liftAdv . unsafeIOToSTM $ do
  --     may'retryFun <- tryTakeMVar retryVar
  --     let retryFun = maybe (ioaction >>) (. (ioaction >>)) may'retryFun
  --     putMVar retryVar $! retryFun

  orElse = mplus

  retry = liftAdv S.retry

  check = liftAdv . S.check

  alwaysSucceeds inv = unlift inv >>= liftAdv . S.alwaysSucceeds

  always inv = unlift inv >>= liftAdv . S.always

  -- Runs the handler for exceptions of the handler's type; any other
  -- exception is re-thrown unchanged.
  catchSTM action handler = do
    action' <- unlift action
    handler' <- unlift1 handler
    let handler'' e = case fromException e of
          Nothing -> throw e
          Just e' -> handler' e'
    liftAdv $ S.catchSTM action' handler''

  liftAdv = AdvSTM . lift

  -- See 'STVar.newTVar'. Builds the TVar from a fresh value cell, a full
  -- commit-lock TMVar and a thread-id cell initialised to 'Nothing'.
  newTVar a = TVar `liftM` liftAdv (STVar.newTVar a)
                   `ap` liftAdv (newTMVar ())
                   `ap` liftAdv (STVar.newTVar Nothing)

  -- Writes a value to a TVar. Retries while the TVar's commit lock is taken
  -- (another transaction is running its onCommit IO action(s)); throws
  -- 'Deadlock' if that transaction belongs to the current thread, i.e. the
  -- write happens from inside our own onCommit phase. See 'onCommit'.
  writeTVar tvar a = do
    let lock = onCommitLock tvar
    commitLock <- liftAdv $ tryTakeTMVar lock
    -- ThreadId of the thread running this transaction:
    curTid <- AdvSTM $ asks transThreadId
    storedTid <- liftAdv $ STVar.readTVar (currentTid tvar)
    case commitLock of
      Nothing
        -- Can't write the TVar in the onCommit phase of our own transaction.
        | storedTid == Just curTid -> throw Deadlock
        | otherwise                -> retry
      Just _ -> do
        lsTVar <- AdvSTM $ asks listeners
        ls <- liftAdv $ STVar.readTVar lsTVar
        -- Register the TVar once per transaction. The stored ThreadId is not
        -- cleared after a commit, so it cannot tell whether *this*
        -- transaction registered the TVar (a later transaction on the same
        -- thread would wrongly skip registration, losing the commit lock and
        -- the rollback value). Consult this transaction's listener list.
        unless (any ((== lock) . fst) ls) $ do
          -- Record the writing thread so onCommit-phase access can be
          -- detected as a deadlock:
          liftAdv $ STVar.writeTVar (currentTid tvar) $ Just curTid
          -- Remember the old value for rollback:
          oldval <- liftAdv $ STVar.readTVar (valueTVar tvar)
          liftAdv $ STVar.writeTVar lsTVar $
            (lock, TVarValue (valueTVar tvar, oldval)) : ls
        liftAdv $ STVar.writeTVar (valueTVar tvar) a
        liftAdv $ putTMVar lock ()

  -- Writes the value cell directly, ignoring the commit lock.
  writeTVarAsync tvar = liftAdv . STVar.writeTVar (valueTVar tvar)

  -- Reads a value from a TVar. Retries while the TVar's onCommit IO
  -- action(s) are running in another transaction; throws 'Deadlock' if the
  -- lock is held by the current thread's own onCommit phase.
  -- See 'onCommit' for a more detailed description of this behaviour.
  readTVar tvar = do
    commitLock <- liftAdv $ tryTakeTMVar (onCommitLock tvar)
    case commitLock of
      Nothing -> do
        storedTid <- liftAdv $ STVar.readTVar (currentTid tvar)
        curTid <- AdvSTM $ asks transThreadId
        if storedTid == Just curTid
          then throw Deadlock
          else retry
      Just _ -> do
        result <- liftAdv $ STVar.readTVar $ valueTVar tvar
        liftAdv $ putTMVar (onCommitLock tvar) ()
        return result

  -- Reads the value cell directly, ignoring the commit lock.
  readTVarAsync = liftAdv . STVar.readTVar . valueTVar

  -- Forks a separate thread to run the IO action and then retries the
  -- transaction. The forked thread runs the action only while the
  -- retryDoneMVar holds 'Just ()'; 'atomically' swaps in 'Nothing' once the
  -- transaction has committed.
  unsafeRetryWith io = do
    doneMVar <- AdvSTM $ asks retryDoneMVar
    unsafeIOToSTM $ forkIO $ do
      val <- takeMVar doneMVar
      -- Put back exactly what was taken. Restoring 'Just ()' unconditionally
      -- would undo 'atomically's stop signal and let still-pending retry
      -- threads run their action after the commit.
      maybe (return ()) (const io) val `finally` putMVar doneMVar val
    retry

  unsafeIOToSTM = liftAdv . G.unsafeIOToSTM

--------------------------------------------------------------------------------

-- | See 'S.atomically'. Runs the transaction; if it commits, the TVars it
-- wrote stay locked (their commit TMVars are taken in the same STM commit)
-- while the registered onCommit IO actions run, so readers and writers of
-- those TVars retry meanwhile. If the onCommit closure throws, the old
-- values are restored, the locks released and the exception re-thrown.
atomically :: AdvSTM a -> IO a
atomically (AdvSTM action) = do
  let debugging = False -- Switching this to True occasionally causes deadlocks (b/c unsafeIO)!
  debugTVar <- STVar.newTVarIO debugging
  tid <- myThreadId -- ThreadId for controlling TVar access
  debug "***************************************************************" 0 debugging
  debug (show (tid,"Starting transaction...")) 1000000 debugging
  -- Building the Reader monad environment
  commitVar <- STVar.newTVarIO [] -- IO actions to be run if the transaction commits
  commitCl <- STVar.newTVarIO (sequence_) -- IO action that "contains" the commit IO actions.
  retryDone <- newMVar $ Just ()
  commitListeners <- STVar.newTVarIO []
  let env = Env { commitTVar = commitVar :: STVar.TVar [IO ()]
                , commitClosure = commitCl
                , retryDoneMVar = retryDone :: MVar (Maybe ())
                , transThreadId = tid :: ThreadId
                , listeners = commitListeners :: STVar.TVar [(TMVar (),TVarValue)]
                , debugModeVar = debugTVar
                }
  -- Tells pending unsafeRetryWith threads not to run their action any more;
  -- blocks while one of them is currently running it.
  let stopRetryWith = swapMVar retryDone Nothing
  let wrappedAction = runReaderT action env
  -- Block exceptions from other threads for the rest of 'atomically'
  block $ do
    result <- S.atomically $ do
      debugSTM (show (tid,"wrappedAction: Running S.STM action...")) 0 debugging
      res <- wrappedAction
      debugSTM (show (tid,"wrappedAction: Finished S.STM action...")) 0 debugging
      ls <- STVar.readTVar commitListeners
      -- Notify the TVars that we're entering onCommit mode:
      debugSTM (show (tid,"wrappedAction: Notifying TVars that we're about to run onCommit...")) 0 debugging
      -- Taking each written TVar's lock makes other accesses retry until
      -- the onCommit phase is over.
      mapM_ (takeTMVar . fst) ls
      return res
    let rollbackOnCommit = do
          debug "rollbackOnCommit: rolling back modified TVar values!" 0 debugging
          S.atomically $ do
            ls <- STVar.readTVar commitListeners
            -- Something went wrong: restore each old value and unblock its TVar.
            mapM_ (\(l, TVarValue (oldValTVar, oldVal)) ->
                     STVar.writeTVar oldValTVar oldVal >> putTMVar l ()) ls
    -- Wait for the retryWith thread(s) to be done before running the
    -- onCommit actions.
    stopRetryWith
    -- Now try to run the onCommit IO actions (in registration order):
    commitAcs <- liftM reverse $ S.atomically $ STVar.readTVar commitVar
    commitClAc <- S.atomically $ STVar.readTVar commitCl
    commitClAc commitAcs
      `catch` (\(e::SomeException) -> rollbackOnCommit >> throw e)
    -- Everything's ok, so notify and unblock the TVars:
    debug (show (tid,"*********************")) 1000000 debugging
    debug (show (tid,"Notifying TPVars that we're done:")) 0 debugging
    S.atomically $ do
      ls <- STVar.readTVar commitListeners
      mapM_ (\(l,_) -> putTMVar l ()) ls
    debug (show (tid,"Transaction done; retry thread finished")) 1000000 debugging
    debug "************************************************************" 0 debugging
    return result

--------------------------------------------------------------------------------

-- | Switches the debug mode on or off.
-- /WARNING:/ Can lead to deadlocks!
debugMode :: Bool -> AdvSTM ()
debugMode switch = do
  debVar <- AdvSTM $ asks debugModeVar
  liftAdv $ STVar.writeTVar debVar switch

-- | Uses unsafeIOToSTM to output the Thread Id and a message and then delays
-- for the given number of microseconds (only when debug mode is on).
-- /WARNING:/ Can lead to deadlocks!
debugAdvSTM :: String -> Int -> AdvSTM ()
debugAdvSTM msg delay = do
  debVar <- AdvSTM $ asks debugModeVar
  debugging <- liftAdv $ STVar.readTVar debVar
  tid <- AdvSTM $ asks transThreadId
  liftAdv $ debugSTM (show (tid,msg)) delay debugging

--------------------------------------------------------------------------------
-- The following functions are not exported as they aren't needed by
-- the enduser

-- Runs an AdvSTM computation as plain STM in the given environment.
runWith :: Env -> AdvSTM t -> S.STM t
runWith env (AdvSTM action) = runReaderT action env

-- Captures the current environment as an AdvSTM-to-STM converter.
unlifter :: AdvSTM (AdvSTM a -> S.STM a)
unlifter = do
  env <- AdvSTM ask
  return (runWith env)

-- Converts an AdvSTM computation to STM using the current environment.
unlift :: AdvSTM a -> AdvSTM (S.STM a)
unlift f = do
  u <- unlifter
  return (u f)

-- Like 'unlift', for a function returning an AdvSTM computation.
unlift1 :: (t -> AdvSTM a) -> AdvSTM (t -> S.STM a)
unlift1 f = do
  u <- unlifter
  return (u . f)

-- WARNING: Can lead to deadlocks!
-- | STM flavour of 'debug': when the flag is 'True', prints the message and
-- sleeps for @delay@ microseconds from inside the transaction via
-- 'G.unsafeIOToSTM'; otherwise it is a no-op.
--
-- /WARNING:/ Can lead to deadlocks!
debugSTM :: String -> Int -> Bool -> S.STM ()
debugSTM msg delay debugging
  | debugging = G.unsafeIOToSTM (debug msg delay True)
  | otherwise = return ()

-- | When the flag is 'True', prints the message and then sleeps for
-- @delay@ microseconds; otherwise does nothing.
--
-- /WARNING:/ Can lead to deadlocks!
debug :: String -> Int -> Bool -> IO ()
debug msg delay debugging
  | not debugging = return ()
  | otherwise     = do
      putStrLn msg
      threadDelay delay

-- NOTE(review): disabled draft. As written the instance head looks
-- ill-kinded (MonadError expects a monad, 'AdvSTM a' is a type) and '(.., a)'
-- uses a type variable as a constraint -- fix before re-enabling.
{- instance forall a. (Exception e, a) => MonadError e (AdvSTM a) where throwError = throw catchError = catchSTM -}