module Control.Concurrent.AdvSTM(
MonadAdvSTM( onCommit
, onRetry
, orElse
, retry
, check
, alwaysSucceeds
, always
, catchSTM
, liftAdv
, newTVar
, readTVar
, writeTVar
)
, AdvSTM
, retryWith
, atomically
, unsafeIOToAdvSTM
, debugAdvSTM
, debugMode
) where
import Prelude hiding (catch)
import Control.Monad.AdvSTM.Def(AdvSTM(..),Env(..),TVarValue(..))
import Control.Monad.AdvSTM.Class( MonadAdvSTM(..), TVar( TVar ), onCommitLock, currentTid, valueTVar )
import Control.Exception(throw,catch,SomeException,fromException,try,block,Deadlock(..))
import Control.Monad(mplus,when,liftM,ap,unless)
import Control.Monad.Reader(MonadReader(..),ReaderT,runReaderT,lift,asks)
import Control.Concurrent(threadDelay,forkIO,ThreadId,myThreadId)
import Control.Concurrent.Chan(Chan,newChan,readChan,writeChan)
import Control.Concurrent.STM.TMVar(TMVar,putTMVar,takeTMVar,newTMVar,tryTakeTMVar)
import Control.Concurrent.MVar(MVar,newEmptyMVar,takeMVar,tryTakeMVar,putMVar)
import qualified Control.Concurrent.STM as S (STM,orElse,retry,catchSTM,atomically,check,always,alwaysSucceeds)
import qualified Control.Concurrent.STM.TVar as STVar
import GHC.Conc(unsafeIOToSTM)
import Data.IORef(newIORef,readIORef,writeIORef)
import Data.Maybe(isJust,Maybe,fromJust)
-- | 'AdvSTM' implementation of the 'MonadAdvSTM' interface.
--
-- Each AdvSTM 'TVar' bundles three pieces of STM state:
--
--   * 'valueTVar'    – the current value;
--   * 'onCommitLock' – a 'TMVar' that is full while the variable is free and
--                      is emptied while a committing transaction runs its
--                      onCommit actions;
--   * 'currentTid'   – the id of the thread whose transaction last recorded
--                      a write to the variable (used for deadlock detection).
instance MonadAdvSTM AdvSTM where
  -- Register an IO action to run after the transaction commits. The stored
  -- function is composed so that actions run in registration order.
  onCommit ioaction = do
    commitVar <- AdvSTM $ asks commitTVar
    commitFun <- liftAdv $ STVar.readTVar commitVar
    liftAdv $ STVar.writeTVar commitVar $ commitFun . (ioaction >>)

  -- Register an IO action to run if the transaction retries. The actions are
  -- accumulated (in registration order) in an MVar via 'unsafeIOToSTM', so
  -- they are not rolled back when the STM body retries. That lets the retry
  -- handler in 'atomically' still find them.
  -- NOTE(review): for the same reason, an STM re-execution for another cause
  -- (validation failure) presumably keeps earlier registrations too.
  onRetry ioaction = do
    retryVar <- AdvSTM $ asks retryMVar
    liftAdv . unsafeIOToSTM $ do
      may'retryFun <- tryTakeMVar retryVar
      let retryFun = maybe (ioaction >>) (. (ioaction >>)) may'retryFun
      putMVar retryVar $! retryFun

  orElse = mplus
  retry = liftAdv S.retry
  check = liftAdv . S.check
  alwaysSucceeds inv = unlift inv >>= liftAdv . S.alwaysSucceeds
  always inv = unlift inv >>= liftAdv . S.always

  -- Run the handler only for exceptions of the type it accepts; any other
  -- exception is rethrown unchanged.
  catchSTM action handler = do
    action' <- unlift action
    handler' <- unlift1 handler
    let handler'' = \e -> case fromException e of
          Nothing -> throw e
          Just e' -> handler' e'
    liftAdv $ S.catchSTM action' handler''

  liftAdv = AdvSTM . lift

  -- A fresh variable starts unlocked (full commit lock) with no writer.
  newTVar a = TVar `liftM` (liftAdv $ STVar.newTVar a)
                   `ap` (liftAdv $ newTMVar ())
                   `ap` (liftAdv $ STVar.newTVar Nothing)

  -- Write a variable.
  --
  -- If the commit lock is currently taken:
  --   * and the recorded writer is this very thread, the lock can only be
  --     held by this thread's own commit phase (e.g. we are inside one of its
  --     onCommit actions). Waiting would never end, so 'Deadlock' is thrown;
  --   * otherwise the transaction retries until the lock is released.
  --
  -- On the first write within the current transaction the old value is pushed
  -- onto the transaction's 'listeners' list. 'atomically' uses that list to
  -- lock the variable during the onCommit phase and to restore the old value
  -- if an onCommit action throws.
  writeTVar tvar a = do
    let lock = onCommitLock tvar
    commitLock <- liftAdv $ tryTakeTMVar lock
    curTid <- AdvSTM $ asks transThreadId
    case commitLock of
      Nothing -> do
        storedTid <- liftAdv $ STVar.readTVar (currentTid tvar)
        if storedTid == Just curTid
          then throw Deadlock
          else retry
      Just _ -> do
        lsTVar <- AdvSTM $ asks listeners
        ls <- liftAdv $ STVar.readTVar lsTVar
        -- "First write" must be judged per transaction. 'currentTid' persists
        -- after a transaction finishes, so comparing it with this thread's id
        -- would skip recording in every later transaction run by the same
        -- thread. Check this transaction's own listener list instead.
        unless (any ((== lock) . fst) ls) $ do
          liftAdv $ STVar.writeTVar (currentTid tvar) $ Just curTid
          oldval <- liftAdv $ STVar.readTVar (valueTVar tvar)
          liftAdv $ STVar.writeTVar lsTVar $
            (lock, TVarValue (valueTVar tvar, oldval)) : ls
        liftAdv $ STVar.writeTVar (valueTVar tvar) a
        liftAdv $ putTMVar lock ()

  -- Read a variable, honouring the same commit lock (and the same
  -- self-deadlock detection) as 'writeTVar'.
  readTVar tvar = do
    commitLock <- liftAdv $ tryTakeTMVar (onCommitLock tvar)
    case commitLock of
      Nothing -> do
        storedTid <- liftAdv $ STVar.readTVar (currentTid tvar)
        curTid <- AdvSTM $ asks transThreadId
        if storedTid == Just curTid
          then throw Deadlock
          else retry
      Just _ -> do
        result <- liftAdv $ STVar.readTVar $ valueTVar tvar
        liftAdv $ putTMVar (onCommitLock tvar) ()
        return result
-- | Register @io@ as a retry action (see 'onRetry') and then retry the
-- current transaction.
retryWith :: (Monad m, MonadAdvSTM m) => IO () -> m b
retryWith io = do
  onRetry io
  retry
-- | Run an 'AdvSTM' transaction and return its result.
--
-- Phases, as implemented below:
--
-- 1. The transaction body runs inside 'S.atomically'. If it retries, the
--    @check'retry@ alternative hands the actions registered with 'onRetry'
--    to a helper thread (forked on first use), then blocks with 'S.retry'.
-- 2. On success, the same STM transaction takes the commit lock of every
--    variable recorded in the listener list by 'writeTVar'. Other
--    transactions touching those variables therefore retry while the
--    onCommit actions run.
-- 3. The accumulated 'onCommit' actions run in 'IO'. If one throws, the
--    helper thread is drained, every recorded variable is restored to its
--    old value, the locks are released and the exception is rethrown.
-- 4. Otherwise the locks are released, the helper thread is drained, and the
--    result is returned.
atomically :: AdvSTM a -> IO a
atomically (AdvSTM action) = do
  -- Debug output of this driver itself is hard-wired off; 'debugMode' only
  -- affects 'debugAdvSTM' calls made inside the transaction.
  let debugging = False
  debugTVar <- STVar.newTVarIO debugging
  tid <- myThreadId
  debug "***************************************************************" 0 debugging
  debug (show (tid,"Starting transaction...")) 1000000 debugging
  commitVar <- STVar.newTVarIO id
  retryVar <- newEmptyMVar
  commitListeners <- STVar.newTVarIO []
  let env = Env { commitTVar = commitVar :: STVar.TVar (IO () -> IO ())
                , retryMVar = retryVar :: MVar (IO () -> IO ())
                , transThreadId = tid :: ThreadId
                , listeners = commitListeners :: STVar.TVar [(TMVar (),TVarValue)]
                , debugModeVar = debugTVar
                }
  -- Channel to the retry helper thread; created lazily on the first retry.
  retryChanVar <- newIORef (Nothing :: Maybe (Chan (Maybe (IO ()))))
  -- Filled by the helper thread once it has stopped.
  retryEndVar <- newEmptyMVar
  -- Runs when the transaction body retries: queue the registered retry
  -- actions (if any) for the helper thread, then block via 'S.retry'.
  let check'retry = do
        unsafeIOToSTM $ do
          may'todo <- tryTakeMVar retryVar
          case may'todo of
            Nothing -> return ()
            Just retryFun -> do
              may'chan <- readIORef retryChanVar
              chan <- case may'chan of
                Nothing -> do
                  chan <- newChan
                  writeIORef retryChanVar (Just chan)
                  spawn'retry'thread (readChan chan) (putMVar retryEndVar ())
                  return chan
                Just chan -> return chan
              writeChan chan $ Just (retryFun (return()))
        debugSTM (show (tid,"Calling retry now...")) 0 debugging
        debugSTM (show (tid,"*********************")) 1000000 debugging
        S.retry
  -- Tell the helper thread (if one was started) to stop once its queue is
  -- empty, and wait until it has stopped.
  let wait'retry'finished = do
        may'chan <- readIORef retryChanVar
        case may'chan of
          Nothing -> return ()
          Just chan -> do
            writeChan chan Nothing
            takeMVar retryEndVar
  let wrappedAction = runReaderT action env `S.orElse` (check'retry)
  block $ do
    result <- S.atomically $ do
      debugSTM (show (tid,"wrappedAction: Running S.STM action...")) 0 debugging
      result <- wrappedAction
      debugSTM (show (tid,"wrappedAction: Finished S.STM action...")) 0 debugging
      ls <- STVar.readTVar commitListeners
      debugSTM (show (tid,"wrappedAction: Notifying TVars that we're about to run onCommit...")) 0 debugging
      -- Lock every written variable for the duration of the onCommit phase.
      mapM_ (\(l,_) -> takeTMVar l) ls
      return result
    -- Undo the committed writes and unlock, used when an onCommit action fails.
    let rollbackOnCommit = do
          debug "rollbackOnCommit: rolling back modified TVar values!" 0 debugging
          wait'retry'finished
          S.atomically $ do
            ls <- STVar.readTVar commitListeners
            mapM_ (\(l,TVarValue (oldValTVar,oldVal)) -> do
                      STVar.writeTVar oldValTVar oldVal
                      putTMVar l ()
                  ) ls
    commitFun <- S.atomically $ STVar.readTVar commitVar
    commitFun (return ()) `catch` (\(e::SomeException) -> rollbackOnCommit >> throw e)
    debug (show (tid,"*********************")) 1000000 debugging
    debug (show (tid,"Notifying TPVars that we're done:")) 0 debugging
    -- Release the commit locks taken above.
    S.atomically $ do
      ls <- STVar.readTVar commitListeners
      mapM_ (\(l,_) -> putTMVar l ()) ls
    wait'retry'finished
    debug (show (tid,"Transaction done; retry thread finished")) 1000000 debugging
    debug "************************************************************" 0 debugging
    return result
  where
    -- Fork a thread that runs the jobs produced by @nextJob@ one after
    -- another. 'Nothing' ends the loop; @atEndAction@ signals that the
    -- thread has stopped.
    spawn'retry'thread :: IO (Maybe (IO ())) -> IO () -> IO ThreadId
    spawn'retry'thread nextJob atEndAction = forkIO loop
      where
        loop = do
          may'job <- nextJob
          case may'job of
            Nothing -> atEndAction
            Just job -> do
              res <- try job
              case res of
                -- A failing retry action terminates this thread. Signal the
                -- end first: without that, 'wait'retry'finished' in the
                -- transaction thread would wait on 'retryEndVar' forever.
                Left (e::SomeException) -> atEndAction >> throw e
                Right _ -> loop
-- | Embed an arbitrary 'IO' action into an 'AdvSTM' transaction via
-- 'unsafeIOToSTM'. The action is not rolled back if the transaction retries
-- or is re-executed.
unsafeIOToAdvSTM :: IO a -> AdvSTM a
unsafeIOToAdvSTM io = liftAdv (unsafeIOToSTM io)
-- | Turn debug output of 'debugAdvSTM' on ('True') or off ('False') for the
-- current transaction by setting its debug-mode variable.
debugMode :: Bool -> AdvSTM ()
debugMode switch =
  AdvSTM (asks debugModeVar) >>= \debVar -> liftAdv (STVar.writeTVar debVar switch)
-- | Print @msg@, tagged with the transaction's thread id, and then sleep for
-- @delay@ microseconds. Does nothing unless debug mode was switched on with
-- 'debugMode'.
debugAdvSTM :: String -> Int -> AdvSTM ()
debugAdvSTM msg delay = do
  enabled <- AdvSTM (asks debugModeVar) >>= liftAdv . STVar.readTVar
  threadId <- AdvSTM (asks transThreadId)
  liftAdv (debugSTM (show (threadId, msg)) delay enabled)
-- | Unwrap an 'AdvSTM' computation into a plain STM action by supplying its
-- environment.
runWith :: Env -> AdvSTM t -> S.STM t
runWith env m = case m of
  AdvSTM readerT -> runReaderT readerT env
-- | Capture the current environment and return a function that runs any
-- 'AdvSTM' computation as plain STM in that environment.
unlifter :: AdvSTM (AdvSTM a -> S.STM a)
unlifter = liftM runWith (AdvSTM ask)
-- | Convert an 'AdvSTM' computation into plain STM bound to the current
-- environment.
unlift :: AdvSTM a -> AdvSTM (S.STM a)
unlift f = liftM ($ f) unlifter
-- | Like 'unlift', but for a one-argument function returning 'AdvSTM'.
unlift1 :: (t -> AdvSTM a) -> AdvSTM (t -> S.STM a)
unlift1 f = liftM (. f) unlifter
-- | When @debugging@ is set, print @msg@ and then sleep for @delay@
-- microseconds, from inside STM via 'unsafeIOToSTM'. Otherwise do nothing.
debugSTM :: String -> Int -> Bool -> S.STM ()
debugSTM msg delay debugging
  | debugging = unsafeIOToSTM (putStrLn msg >> threadDelay delay)
  | otherwise = return ()
-- | When @debugging@ is set, print @msg@ and then sleep for @delay@
-- microseconds. Otherwise do nothing.
debug :: String -> Int -> Bool -> IO ()
debug msg delay debugging
  | debugging = putStrLn msg >> threadDelay delay
  | otherwise = return ()