{-# LANGUAGE ScopedTypeVariables, FlexibleInstances, MultiParamTypeClasses, UndecidableInstances #-}
module Control.Concurrent.AdvSTM(
MonadAdvSTM( onCommitWith
, onCommit
, unsafeRetryWith
, orElse
, retry
, check
, catchSTM
, liftAdv
, newTVar
, readTVarAsync
, writeTVarAsync
, readTVar
, writeTVar
)
, AdvSTM
, atomically
, unsafeIOToSTM
, handleSTM
, debugAdvSTM
, debugMode
) where
import Prelude hiding (catch)
import Control.Monad.AdvSTM.Def(AdvSTM(..),Env(..),TVarValue(..))
import Control.Monad.AdvSTM.Class(
  MonadAdvSTM(..), handleSTM, TVar( TVar ), onCommitLock, currentTid, valueTVar )
import Control.Exception(Exception,throw,catch,SomeException,fromException,try,mask_,Deadlock(..),finally,onException)
import Control.Monad(mplus,when,liftM,ap,unless)
import Control.Monad.Except
import Control.Concurrent(threadDelay,forkIO,ThreadId,myThreadId,throwTo)
import Control.Concurrent.Chan(Chan,newChan,readChan,writeChan)
import Control.Concurrent.STM.TMVar(TMVar,putTMVar,takeTMVar,newTMVar,tryTakeTMVar)
import Control.Concurrent.MVar(MVar,newMVar,takeMVar,tryTakeMVar,putMVar,tryPutMVar,swapMVar)
import qualified Control.Concurrent.STM as S (STM,orElse,retry,catchSTM,atomically,check)
import qualified Control.Concurrent.STM.TVar as STVar
import qualified GHC.Conc as G
import Data.IORef(newIORef,readIORef,writeIORef)
import Data.Maybe(isJust,Maybe,fromJust)
import Control.Monad.Reader(MonadReader(..),ReaderT(ReaderT),runReaderT,lift,asks)
-- | 'AdvSTM' adds commit hooks, retry hooks and per-'TVar' commit locks on
-- top of plain 'S.STM'.
--
-- Every 'TVar' bundles three pieces of state (see 'newTVar'): its value
-- ('valueTVar'), a commit lock ('onCommitLock', a 'TMVar' that is full while
-- the 'TVar' is unlocked) and the 'ThreadId' of the last thread that wrote it
-- via 'writeTVar' ('currentTid'). While 'atomically' runs the onCommit
-- actions of a transaction, the locks of all 'TVar's that transaction wrote
-- are empty; 'readTVar' / 'writeTVar' then 'retry', or throw 'Deadlock' if
-- the recorded writer is the current transaction's own thread.
instance MonadAdvSTM AdvSTM where
  -- Replace the closure that 'atomically' uses to run the collected
  -- onCommit actions.
  onCommitWith ioclosure = do
    commitCl <- AdvSTM $ asks commitClosure
    liftAdv $ STVar.writeTVar commitCl ioclosure

  -- Queue an IO action to run after commit. Actions are prepended here;
  -- 'atomically' reverses the list before running them.
  onCommit ioaction = do
    commitVar <- AdvSTM $ asks commitTVar
    commitActions <- liftAdv $ STVar.readTVar commitVar
    liftAdv $ STVar.writeTVar commitVar $ ioaction : commitActions

  orElse = mplus

  retry = liftAdv S.retry

  check = liftAdv . S.check

  -- Run the action in the current environment; exceptions of the handler's
  -- type go to the handler, any other exception is rethrown unchanged.
  catchSTM action handler = do
    action' <- unlift action
    handler' <- unlift1 handler
    let handler'' e = case fromException e of
          Nothing -> throw e
          Just e' -> handler' e'
    liftAdv $ S.catchSTM action' handler''

  liftAdv = AdvSTM . lift

  -- A fresh 'TVar': the value, a full (unlocked) commit lock, and no
  -- recorded writer.
  newTVar a = TVar <$> liftAdv (STVar.newTVar a)
                   <*> liftAdv (newTMVar ())
                   <*> liftAdv (STVar.newTVar Nothing)

  -- Write through the commit lock. On the first write of this 'TVar' in the
  -- current transaction, its lock and old value are recorded in the
  -- transaction's listener list, so 'atomically' can lock it during the
  -- commit phase and restore the old value if an onCommit action throws.
  writeTVar tvar a = do
    commitLock <- liftAdv $ tryTakeTMVar (onCommitLock tvar)
    curTid <- AdvSTM $ asks transThreadId
    storedTid <- liftAdv $ STVar.readTVar (currentTid tvar)
    case commitLock of
      Nothing
        | storedTid == Just curTid -> throw Deadlock
        | otherwise                -> retry
      Just _ -> do
        lsTVar <- AdvSTM $ asks listeners
        ls <- liftAdv $ STVar.readTVar lsTVar
        -- Deduplicate by the lock itself rather than by 'currentTid':
        -- 'currentTid' is never cleared after commit, so it can still hold
        -- this thread's id from an earlier transaction, which would wrongly
        -- skip registration here.
        let alreadyRegistered = any ((== onCommitLock tvar) . fst) ls
        unless alreadyRegistered $ do
          oldval <- liftAdv $ STVar.readTVar (valueTVar tvar)
          liftAdv $ STVar.writeTVar lsTVar $
            (onCommitLock tvar, TVarValue (valueTVar tvar, oldval)) : ls
        -- Only write when it changes, to avoid needless STM conflicts.
        when (storedTid /= Just curTid) $
          liftAdv $ STVar.writeTVar (currentTid tvar) (Just curTid)
        liftAdv $ STVar.writeTVar (valueTVar tvar) a
        liftAdv $ putTMVar (onCommitLock tvar) ()

  -- Write the value directly, bypassing the commit lock and the listener
  -- bookkeeping.
  writeTVarAsync tvar = liftAdv . STVar.writeTVar (valueTVar tvar)

  -- Read through the commit lock: 'retry' while another transaction holds
  -- it, throw 'Deadlock' if the holder is this thread.
  readTVar tvar = do
    commitLock <- liftAdv $ tryTakeTMVar (onCommitLock tvar)
    case commitLock of
      Nothing -> do
        storedTid <- liftAdv $ STVar.readTVar (currentTid tvar)
        curTid <- AdvSTM $ asks transThreadId
        if storedTid == Just curTid
          then throw Deadlock
          else retry
      Just _ -> do
        result <- liftAdv $ STVar.readTVar (valueTVar tvar)
        liftAdv $ putTMVar (onCommitLock tvar) ()
        return result

  -- Read the value directly, bypassing the commit lock.
  readTVarAsync = liftAdv . STVar.readTVar . valueTVar

  -- Fork a thread that runs @io@ unless the transaction has already been
  -- marked done (the done-MVar holds 'Nothing'), then 'retry'.
  unsafeRetryWith io = do
    doneMVar <- AdvSTM $ asks retryDoneMVar
    _ <- unsafeIOToSTM $ forkIO $ do
      val <- takeMVar doneMVar
      -- Put back exactly what was taken: restoring 'Just ()' after taking
      -- the 'Nothing' sentinel would let later queued retry threads run
      -- their action after the transaction has committed.
      maybe (return ()) (const io) val `finally` tryPutMVar doneMVar val
    retry
-- | Embed an arbitrary 'IO' action in an 'AdvSTM' transaction via
-- 'G.unsafeIOToSTM'.
--
-- The same caveats apply as for 'G.unsafeIOToSTM' (see its documentation):
-- the action's effects are not rolled back if the transaction retries or
-- aborts.
unsafeIOToSTM :: IO a -> AdvSTM a
unsafeIOToSTM = liftAdv . G.unsafeIOToSTM
-- | Run an 'AdvSTM' transaction and then its onCommit actions.
--
-- 1. The transaction runs inside 'S.atomically'. In the same STM
--    transaction, the commit lock of every entry in the listener list is
--    taken.
-- 2. Forked 'unsafeRetryWith' threads are signalled to stop (the shared
--    done-MVar is set to 'Nothing'), and the onCommit actions are passed,
--    in registration order, to the commit closure. The closure starts as
--    'sequence_' and can be replaced with 'onCommitWith'.
-- 3. The commit locks are released.
--
-- If step 2 throws, each listener's 'TVar' is reset to the old value saved
-- in the listener list, its lock is released, and the exception is
-- rethrown.
--
-- Everything runs under 'mask_', so the onCommit actions also run with
-- asynchronous exceptions masked.
--
-- NOTE(review): debug output here is controlled by the local @debugging@
-- constant, not by 'debugMode'; 'debugMode' only affects 'debugAdvSTM'.
-- NOTE(review): if the transaction itself throws, retry threads forked by
-- earlier attempts are not signalled to stop.
atomically :: AdvSTM a -> IO a
atomically (AdvSTM action) = do
  let debugging = False
  debugTVar <- STVar.newTVarIO debugging
  tid <- myThreadId
  debug "***************************************************************" 0 debugging
  debug (show (tid,"Starting transaction...")) 1000000 debugging
  commitVar <- STVar.newTVarIO []
  commitCl <- STVar.newTVarIO sequence_
  doneMVar <- newMVar $ Just ()
  commitListeners <- STVar.newTVarIO []
  let env = Env { commitTVar = commitVar :: STVar.TVar [IO ()]
                , commitClosure = commitCl
                , retryDoneMVar = doneMVar :: MVar (Maybe ())
                , transThreadId = tid :: ThreadId
                , listeners = commitListeners :: STVar.TVar [(TMVar (),TVarValue)]
                , debugModeVar = debugTVar
                }
      -- Signal forked unsafeRetryWith threads that the transaction is done.
      stopRetryWith = swapMVar doneMVar Nothing
      -- Step 2. Everything here is covered by the rollback: swapMVar is
      -- interruptible even under mask_. If it escaped the handler, the
      -- locks taken in step 1 would never be released.
      runCommitActions = do
        _ <- stopRetryWith
        commitAcs <- fmap reverse $ S.atomically $ STVar.readTVar commitVar
        commitClAc <- S.atomically $ STVar.readTVar commitCl
        commitClAc commitAcs
      -- Restore the saved old values and release the commit locks.
      rollbackOnCommit = do
        debug "rollbackOnCommit: rolling back modified TVar values!" 0 debugging
        S.atomically $ do
          ls <- STVar.readTVar commitListeners
          mapM_ (\(l,TVarValue (oldValTVar,oldVal)) -> do
                    STVar.writeTVar oldValTVar oldVal
                    putTMVar l ())
                ls
  mask_ $ do
    -- Step 1: run the transaction and take the commit locks atomically.
    result <- S.atomically $ do
      debugSTM (show (tid,"wrappedAction: Running S.STM action...")) 0 debugging
      r <- runReaderT action env
      debugSTM (show (tid,"wrappedAction: Finished S.STM action...")) 0 debugging
      ls <- STVar.readTVar commitListeners
      debugSTM (show (tid,"wrappedAction: Notifying TVars that we're about to run onCommit...")) 0 debugging
      mapM_ (takeTMVar . fst) ls
      return r
    -- Step 2, rolling back and rethrowing on failure.
    runCommitActions `onException` rollbackOnCommit
    debug (show (tid,"*********************")) 1000000 debugging
    debug (show (tid,"Notifying TPVars that we're done:")) 0 debugging
    -- Step 3: release the commit locks.
    S.atomically $ do
      ls <- STVar.readTVar commitListeners
      mapM_ (\(l,_) -> putTMVar l ()) ls
    debug (show (tid,"Transaction done; retry thread finished")) 1000000 debugging
    debug "************************************************************" 0 debugging
    return result
-- | Turn 'debugAdvSTM' output on or off for the current transaction.
--
-- The flag is a 'STVar.TVar' in the transaction's environment. 'atomically'
-- creates a fresh flag for every run, so the setting does not carry over to
-- later transactions.
debugMode :: Bool -> AdvSTM ()
debugMode switch =
  AdvSTM (asks debugModeVar) >>= \flagVar ->
    liftAdv (STVar.writeTVar flagVar switch)
-- | Print @msg@, tagged with the transaction's thread id, and then pause for
-- @delay@ microseconds. Does nothing unless 'debugMode' has enabled
-- debugging for this transaction.
debugAdvSTM :: String -> Int -> AdvSTM ()
debugAdvSTM msg delay = do
  env <- AdvSTM ask
  enabled <- liftAdv (STVar.readTVar (debugModeVar env))
  let tagged = show (transThreadId env, msg)
  liftAdv (debugSTM tagged delay enabled)
-- | Run an 'AdvSTM' computation as plain STM against the given environment.
runWith :: Env -> AdvSTM t -> S.STM t
runWith env (AdvSTM reader) = reader `runReaderT` env
-- | Capture the current environment and return a function that runs any
-- 'AdvSTM' computation in it as plain 'S.STM'.
unlifter :: AdvSTM (AdvSTM a -> S.STM a)
unlifter = AdvSTM (asks runWith)
-- | Turn an 'AdvSTM' computation into plain 'S.STM' bound to the current
-- environment.
unlift :: AdvSTM a -> AdvSTM (S.STM a)
unlift f = fmap ($ f) unlifter
-- | Like 'unlift', but for a one-argument function returning 'AdvSTM'
-- (for example an exception handler).
unlift1 :: (t -> AdvSTM a) -> AdvSTM (t -> S.STM a)
unlift1 f = fmap (. f) unlifter
-- | From inside STM, print @msg@ and pause for @delay@ microseconds, but only
-- when @debugging@ is set. Uses 'G.unsafeIOToSTM'; intended for debugging
-- output only.
debugSTM :: String -> Int -> Bool -> S.STM ()
debugSTM msg delay debugging
  | debugging = G.unsafeIOToSTM (putStrLn msg >> threadDelay delay)
  | otherwise = return ()
-- | Print @msg@ and pause for @delay@ microseconds, but only when
-- @debugging@ is set.
debug :: String -> Int -> Bool -> IO ()
debug msg delay debugging =
  if debugging
    then putStrLn msg >> threadDelay delay
    else return ()