module STM
(STM, TVar, newTVar, newTVarIO, readTVar, writeTVar,
atomically,retry, Atomic, Twi, Safe, gbind, gret,twilight,getInconsistencies, isInconsistent,
reload,safeTwiIO, rewriteTVar, rereadTVar,unsafeTwiIO,newRegion,tryCommit) where
import Prelude hiding (mapM_)
import Counter
import Semaphore
import STMHelpers
import Control.Concurrent.MVar
import Data.IORef
import qualified Data.IntMap as Map
import Data.Maybe
import Unsafe.Coerce
import Monad hiding (mapM_)
import Data.Foldable (mapM_)
import Control.VarStateM hiding (unlock)
-- | Identifier used for TVars (from 'getUniqueId') and for regions
-- (from the per-transaction 'regCounter').
type Id = Int
-- | A value of the global clock, as returned by 'sampleClock'.
type Timestamp = Int
-- | Access mode flag (write, write-then-read, read).
-- NOTE(review): not referenced anywhere else in this module.
data AccessFlag = W | WR | R deriving (Show,Eq)
-- | Phase markers used as the phantom @p@/@q@ indices of 'STM'.
-- In this module, 'readTVar'/'writeTVar'/'newRegion' are 'Atomic' -> 'Atomic',
-- 'twilight' is 'Atomic' -> 'Twi', 'reload'/'tryCommit' are 'Twi' -> 'Safe',
-- and 'safeTwiIO'/'getInconsistencies' are 'Safe' -> 'Safe'.
data Twi = Twi
data Atomic = Atomic
data Safe = Safe
-- | Twilight status of a transaction.
-- 'Ok' and 'Bad' carry the release actions for the write-set locks that are
-- held from 'twilight' onwards ('Ok' = read set validated, 'Bad' = not).
-- 'NotChecked' is the initial status: no write locks are held yet.
data Status = Ok [IO ()] | Bad [IO ()] | NotChecked
-- | Region consistency flag. Only 'Cons' is constructed in this module
-- (by 'newRegion'), and neither constructor is inspected here.
data Checked = Incons | Cons
-- | A transaction step: a state-passing IO action over 'StmState'.
-- @r@ is the transaction scope ('atomically' quantifies it with @forall s@,
-- presumably so regions cannot escape a transaction). @p@ and @q@ are the
-- phantom phase indices ('Atomic', 'Twi', 'Safe') before and after the step.
data STM r p q a = STM ((StmState r) -> IO (STMResult r a))
-- | Phase-preserving sequencing: the state is threaded through both steps,
-- and a 'Retry' or 'Error' from the first step short-circuits the second.
instance Monad (STM r p q) where
  return value = STM $ \st -> return (Success st value)
  STM first >>= continue = STM $ \st -> first st >>= step
    where
      step (Success st' value) = let STM next = continue value in next st'
      step (Retry st') = return (Retry st')
      step (Error msg) = return (Error msg)
-- | Phase-changing sequencing (parameterised monad).
--
-- The phase indices of 'STM' are phantom: they do not occur in the wrapped
-- function. That means the bind can be written directly with the right types,
-- rather than by 'unsafeCoerce'-ing the phase-preserving '>>='. Runtime
-- behaviour is the same as '>>=': the state is threaded, and a 'Retry' or
-- 'Error' short-circuits the continuation.
instance Monadish (STM r) where
  gret = return
  gbind (STM first) continue = STM $ \st -> do
    outcome <- first st
    case outcome of
      Success st' value -> let STM next = continue value in next st'
      Retry st' -> return (Retry st')
      Error msg -> return (Error msg)
-- | Per-transaction bookkeeping threaded through every 'STM' step.
data StmState x = StmState {
    -- | Start timestamp, sampled in 'initialState'; 'readTVar' retries when a
    -- TVar's stored timestamp is newer than this.
    tstamp :: Timestamp,
    -- | Write set: entries created by 'writeTVar', keyed by TVar id.
    wset :: Heap x,
    -- | Read set: entries 'readTVar' fetched from the TVar itself (i.e. not
    -- served from the write set), keyed by TVar id.
    rset :: Heap x,
    -- | Next region id to be handed out by 'newRegion'.
    regCounter :: Id,
    -- | Twilight status; 'NotChecked' until 'twilight' (or 'exposed') sets it.
    status :: Status }
-- | Outcome of running one 'STM' step.
-- NOTE(review): 'Error' carries no state, so 'atomically'' cannot release
-- locks held at that point; whoever produces an 'Error' must do so.
data STMResult x a
  = Retry (StmState x)       -- ^ abort; 'atomically'' restarts the transaction
  | Success (StmState x) a   -- ^ continue with the updated state and a result
  | Error String             -- ^ misuse detected; 'atomically'' raises it via 'error'
-- | Run a transaction whose scope parameter is universally quantified, so
-- that scope-indexed values cannot leak out of it. Delegates to 'atomically''.
-- (Kept eta-expanded: a rank-2 argument cannot be eta-reduced away portably.)
atomically :: (forall s. STM s p q a) -> IO a
atomically txn = atomically' txn
-- | Run a transaction to completion. On every retry, failed validation, or
-- 'Bad' twilight status, any held locks are released and the transaction is
-- restarted from a fresh 'initialState'.
--
-- How a successful run is committed depends on the status it left behind:
--
-- * 'NotChecked' (no 'twilight' ran): lock the write set, validate the read
--   set, then publish, or release the locks and restart.
-- * 'Bad': release the held write locks and restart.
-- * 'Ok': the write locks are already held; publish (which releases them).
--
-- An 'Error' result is raised with 'error'.
atomically' :: STM s p q a -> IO a
atomically' action = initialState >>= attempt
  where
    -- Each restart samples a new start timestamp via 'initialState'.
    restart = atomically' action

    attempt st = do
      outcome <- runSTM action st
      case outcome of
        Error msg -> error msg
        Retry st' -> do
          releaseHeld (status st')
          restart
        Success st' res -> commit st' res

    commit st' res = case status st' of
      NotChecked -> do
        wlocks <- enterTwilight st'
        (valid, _) <- validateReadSet st'
        if valid
          then publishAndUnlock (wset st') >> return res
          else unlockLocks wlocks >> restart
      Bad unlocks -> unlockLocks unlocks >> restart
      Ok _ -> publishAndUnlock (wset st') >> return res

    -- Write locks are only held once a twilight status has been recorded.
    releaseHeld NotChecked = return ()
    releaseHeld (Bad unlocks) = unlockLocks unlocks
    releaseHeld (Ok unlocks) = unlockLocks unlocks
-- | Fresh transaction state: empty read and write sets, region counter at 0,
-- status 'NotChecked', and a start timestamp obtained from 'sampleClock'.
initialState :: IO (StmState r)
initialState = do
  startedAt <- sampleClock
  return StmState
    { tstamp = startedAt
    , rset = Map.empty
    , wset = Map.empty
    , status = NotChecked
    , regCounter = 0
    }
-- | Unwrap an 'STM' step and apply it to a transaction state.
runSTM :: STM r p q a -> (StmState r) -> IO (STMResult r a)
runSTM (STM body) st = body st
-- | Abort the current attempt; 'atomically'' releases held locks and restarts.
retry :: STM r p q a
retry = STM (return . Retry)
-- | Allocate a TVar inside a transaction. The TVar is created immediately
-- (not buffered in the write set). It gets a fresh unique id, an unlocked lock,
-- and the transaction's start timestamp as its version.
newTVar :: a -> STM r p p (TVar a)
newTVar initial = STM $ \stmState -> do
  valueRef <- newIORef initial
  tvarId <- getUniqueId
  stampRef <- newIORef (tstamp stmState)
  lock <- newMVar ()
  return (Success stmState (TVar valueRef tvarId stampRef lock))
-- | Buffer a write to a TVar in the transaction's write set; the TVar
-- itself is only updated at commit ('publishAndUnlock').
-- Returns a handle that 'rewriteTVar' can use after 'twilight'.
--
-- A previous write to the same TVar is simply replaced. The former
-- 'Map.insertWith' combiner rebuilt exactly the new entry, via 'unsafeCoerce',
-- and discarded the old one. That is the same as a plain 'Map.insert'.
writeTVar :: Show a => TVar a -> a -> STM r Atomic Atomic (WTwiVar a)
writeTVar tvar@(TVar _ ident _ _) x = STM $ \stmState ->
  let entry = MkHeapEntry tvar x ident Nothing (tstamp stmState) Nothing
      wset' = Map.insert ident entry (wset stmState)
  in return $ Success stmState { wset = wset' } (WTwiVar ident)
-- | Read a TVar, tagging the read with a region.
--
-- Lookup order: the write set (read-your-own-writes), then the read set,
-- then the TVar itself. The final read takes the TVar's lock without
-- blocking; a busy lock, or a version newer than the transaction's start
-- timestamp, causes a 'Retry'. A value fetched from the TVar is recorded in
-- the read set together with the version it had.
readTVar :: Show a => TVar a -> Region r a -> STM r Atomic Atomic (a, RTwiVar a)
readTVar tvar@(TVar valueRef ident stampRef lock) reg = STM $ \stmState ->
  case alreadySeen stmState of
    Just (MkHeapEntry _ cached _ _ _ _) ->
      return $ Success stmState (unsafeCoerce cached, RTwiVar ident)
    Nothing -> fetch stmState
  where
    alreadySeen st = case Map.lookup ident (wset st) of
      Nothing -> Map.lookup ident (rset st)
      found -> found

    fetch st = do
      grabbed <- tryTakeMVar lock
      case grabbed of
        Nothing -> return $ Retry st
        Just _ -> do
          stamp <- readIORef stampRef
          current <- readIORef valueRef
          putMVar lock ()
          if stamp > tstamp st
            then return $ Retry st
            else do
              let entry = MkHeapEntry tvar current ident (Just reg) stamp Nothing
              return $ Success st { rset = Map.insert ident entry (rset st) } (current, RTwiVar ident)
-- | Does this read handle refer to the given TVar? (Compares TVar ids.)
is :: RTwiVar a -> TVar a -> Bool
is (RTwiVar lhs) (TVar _ rhs _ _) = lhs == rhs
-- | Enter the twilight phase. Lock the whole write set (held until commit or
-- restart), then validate the read set. The outcome is recorded as 'Ok' or
-- 'Bad' (both carrying the lock-release actions). The read set is replaced by
-- its validated copy, and the function returns whether validation succeeded.
twilight :: STM r Atomic Twi Bool
twilight = STM $ \stmState -> do
  held <- enterTwilight stmState
  (ok, checkedReads) <- validateReadSet stmState
  let verdict = if ok then Ok held else Bad held
  return $ Success (stmState { status = verdict, rset = checkedReads }) ok
-- | Replace the buffered value of a TVar that this transaction has already
-- written. Region and timestamps are kept.
--
-- If the TVar has not been written, the transaction fails with an 'Error'.
-- 'Error' carries no state and aborts the whole transaction, so 'atomically''
-- has no way to release write locks taken by 'twilight'. They are released
-- here before failing; otherwise they would stay held forever and block every
-- other transaction touching those TVars.
rewriteTVar :: Show a => WTwiVar a -> a -> STM r p p ()
rewriteTVar (WTwiVar ident) v = STM $ \stmState ->
  case Map.lookup ident (wset stmState) of
    Just (MkHeapEntry tvar _ ident' r t t') ->
      let entry' = MkHeapEntry (unsafeCoerce tvar) v ident' (unsafeCoerce r) t t'
      in return $ Success stmState { wset = Map.insert ident entry' (wset stmState) } ()
    Nothing -> do
      releaseHeld (status stmState)
      return $ Error "You tried to rewrite a TVar which has not been written in the atomic before"
  where
    releaseHeld (Ok unlocks) = unlockLocks unlocks
    releaseHeld (Bad unlocks) = unlockLocks unlocks
    releaseHeld NotChecked = return ()
-- | Return the value recorded in the read set for a previously read TVar.
-- After 'reload' this is the re-read value.
-- NOTE(review): reads served from the write set are never recorded in the
-- read set, so rereading such a handle fails.
--
-- On failure the transaction ends with an 'Error', which carries no state, so
-- any write locks held since 'twilight' are released here first.
rereadTVar :: RTwiVar a -> STM r p p a
rereadTVar (RTwiVar ident) = STM $ \stmState ->
  case Map.lookup ident (rset stmState) of
    Just (MkHeapEntry _ x _ _ _ _) -> return $ Success stmState (unsafeCoerce x)
    Nothing -> do
      releaseHeld (status stmState)
      return $ Error "You tried to read a TVar which has not been read in the atomic before"
  where
    releaseHeld (Ok unlocks) = unlockLocks unlocks
    releaseHeld (Bad unlocks) = unlockLocks unlocks
    releaseHeld NotChecked = return ()
-- | Lock every TVar in the union of the write and read sets, blocking on
-- each lock. Locks are taken in ascending TVar-id order (presumably to avoid
-- lock-order deadlocks).
--
-- Returns @(writeUnlocks, readUnlocks)@: release actions for the write-set
-- TVars and for the read-only TVars, each list in reverse acquisition order.
lockReadAndWriteSet :: StmState r -> IO ([IO()], [IO ()])
lockReadAndWriteSet state = foldM grab ([], []) (Map.toAscList everything)
  where
    writes = wset state
    -- Left-biased union: write-set entries win on shared ids.
    everything = Map.union writes (rset state)
    grab (ws, rs) (key, MkHeapEntry (TVar _ _ _ lock) _ _ _ _ _) = do
      takeMVar lock
      let release = putMVar lock ()
      return $ if Map.member key writes
        then (release : ws, rs)
        else (ws, release : rs)
-- | Refresh every read-set entry with the TVar's current value and record
-- its current timestamp as the "latest seen" stamp. The original read stamp
-- is kept. Then run the given unlock actions.
-- Assumes the caller holds the locks of all these TVars.
rereadAndUnlock :: [IO ()] -> Heap r -> IO (Heap r)
rereadAndUnlock rlocks rset = do
  refreshed <- mapM refresh (Map.toList rset)
  sequence_ rlocks
  return (Map.fromList refreshed)
  where
    refresh (key, MkHeapEntry tv@(TVar valueRef tvId stampRef _) _ _ reg readStamp _) = do
      current <- readIORef valueRef
      now <- readIORef stampRef
      return (key, MkHeapEntry tv current tvId reg readStamp (Just now))
-- | Run an IO action with this transaction's write locks released, then
-- reacquire them. All of this happens while holding the positive side of
-- the global semaphore.
--
-- After the action, the read set is validated:
--
-- * consistent: only the write set is locked again;
-- * inconsistent: write and read sets are locked, every read-set entry is
--   re-read ('rereadAndUnlock'), and the read locks are dropped again.
--
-- In both cases the resulting status is 'Ok' with the write locks held.
-- The @case@ on the incoming status now also covers 'NotChecked' (no locks
-- to release) instead of crashing on a pattern-match failure.
--
-- NOTE(review): validation runs after the write locks were released, so the
-- write-set timestamps it reads are not protected by those locks here.
exposed :: IO a -> StmState r -> IO (StmState r,a)
exposed action stmState = do
  getPos semaphore
  releaseHeld (status stmState)
  result <- action
  (consistent, _) <- validateReadSet stmState
  if consistent
    then do
      writelocks <- lockWriteSet (wset stmState)
      putPos semaphore
      return (stmState { status = Ok writelocks }, result)
    else do
      (writelocks, readlocks) <- lockReadAndWriteSet stmState
      putPos semaphore
      rset' <- rereadAndUnlock readlocks (rset stmState)
      return (stmState { rset = rset', status = Ok writelocks }, result)
  where
    releaseHeld (Ok unlocks) = unlockLocks unlocks
    releaseHeld (Bad unlocks) = unlockLocks unlocks
    releaseHeld NotChecked = return ()
-- | Make the transaction's view safe to act on.
--
-- * 'Ok': nothing to do.
-- * 'Bad': re-read the read set via 'exposed'; the status becomes 'Ok'.
-- * 'NotChecked': no twilight validation has happened. This is reachable,
--   e.g. @atomically reload@ typechecks because 'atomically' accepts any
--   phase indices. Nothing is held or validated yet, and the commit path in
--   'atomically'' validates 'NotChecked' transactions itself, so we just
--   continue (previously a pattern-match crash).
reload :: STM r Twi Safe ()
reload = STM $ \stmState ->
  case status stmState of
    Ok _ -> return $ Success stmState ()
    Bad _ -> do
      (newState, _) <- exposed (return ()) stmState
      return $ Success newState ()
    NotChecked -> return $ Success stmState ()
-- | Proceed only if twilight validation succeeded.
--
-- * 'Ok': continue.
-- * 'Bad': 'Retry'; 'atomically'' releases the locks and restarts.
-- * 'NotChecked': no twilight ran (reachable, e.g. via @atomically tryCommit@).
--   Continue; the commit path in 'atomically'' validates 'NotChecked'
--   transactions and restarts them if invalid (previously a pattern-match crash).
tryCommit :: STM r Twi Safe ()
tryCommit = STM $ \stmState ->
  case status stmState of
    Ok _ -> return $ Success stmState ()
    Bad _ -> return $ Retry stmState
    NotChecked -> return $ Success stmState ()
-- | Run an IO action in the 'Safe' phase; the transaction state is unchanged.
-- The action runs again if the transaction restarts.
safeTwiIO :: IO a -> STM r Safe Safe a
safeTwiIO action = STM $ \stmState -> fmap (Success stmState) action
-- | Run an IO action in any phase; the transaction state is unchanged.
-- Nothing guards consistency here, and the action runs again if the
-- transaction restarts.
unsafeTwiIO :: IO a -> STM r p p a
unsafeTwiIO action = STM $ \stmState -> action >>= return . Success stmState
-- | List the read-set entries of a region whose version changed since they
-- were read, i.e. the latest stamp recorded by validation/reload differs from
-- the read stamp. Entries never re-stamped count as unchanged.
--
-- Each entry is paired with a write handle if the transaction also wrote
-- that TVar. Entries without a region tag are skipped, rather than
-- crashing on an incomplete pattern as before.
getInconsistencies :: Region r a -> STM r Safe Safe [(RTwiVar a,Maybe (WTwiVar a))]
getInconsistencies (Region regionId _) = STM $ \stmState ->
  let writes = wset stmState
      stale =
        [ key
        | (key, entry) <- Map.toList (rset stmState)
        , inRegion entry
        , changed entry
        ]
      classify key
        | Map.member key writes = (RTwiVar key, Just (WTwiVar key))
        | otherwise = (RTwiVar key, Nothing)
  in return $ Success stmState (map classify stale)
  where
    inRegion (MkHeapEntry _ _ _ reg _ _) = case reg of
      Just (Region regionId' _) -> regionId' == regionId
      Nothing -> False
    changed (MkHeapEntry _ _ _ _ readStamp latest) = maybe False (/= readStamp) latest
-- | Does any read-set entry of this region have a latest recorded stamp that
-- differs from its read stamp? Entries never re-stamped count as unchanged.
-- Entries without a region tag are skipped, rather than crashing on an
-- incomplete pattern as before.
isInconsistent :: Region r a -> STM r p p Bool
isInconsistent (Region regionId _) = STM $ \stmState ->
  let stale = any (\entry -> inRegion entry && changed entry) (Map.elems (rset stmState))
  in return $ Success stmState stale
  where
    inRegion (MkHeapEntry _ _ _ reg _ _) = case reg of
      Just (Region regionId' _) -> regionId' == regionId
      Nothing -> False
    changed (MkHeapEntry _ _ _ _ readStamp latest) = maybe False (/= readStamp) latest
-- | Spin until the whole write set is locked, returning the release actions.
-- Each attempt ('tryLockWriteset') is bracketed by 'getNeg'/'putNeg' on the
-- global semaphore.
enterTwilight :: StmState r -> IO [IO()]
enterTwilight stmState = loop
  where
    loop = do
      getNeg semaphore
      attempt <- tryLockWriteset (wset stmState)
      putNeg semaphore
      maybe loop return attempt
-- | Spin until every TVar of the given write set is locked; returns the
-- release actions.
lockWriteSet :: Heap r -> IO [IO ()]
lockWriteSet heap = tryLockWriteset heap >>= maybe (lockWriteSet heap) return
-- | Try to lock every TVar of a write set without blocking, in ascending
-- TVar-id order.
--
-- Returns @Just unlocks@ (release actions, in reverse acquisition order) when
-- all locks were taken. If a lock is busy, the locks taken so far are
-- released and @Nothing@ is returned. The walk now stops at the first busy
-- lock; the old fold visited the remaining entries without effect. Rollback
-- now uses 'sequence_' instead of building a discarded result list with 'mapM'.
tryLockWriteset :: Heap r -> IO (Maybe [IO ()])
tryLockWriteset heap = go [] (Map.toAscList heap)
  where
    go held [] = return (Just held)
    go held ((_, MkHeapEntry (TVar _ _ _ lck) _ _ _ _ _) : rest) = do
      got <- tryTakeMVar lck
      case got of
        Nothing -> sequence_ held >> return Nothing
        Just _ -> go (putMVar lck () : held) rest
-- | Check every read-set entry against the TVar's current timestamp.
--
-- Returns whether all read stamps still match, plus a copy of the read set
-- in which each entry's latest-stamp field holds the observed timestamp.
-- For TVars that are also in the write set, the timestamp is read without
-- taking the lock (the caller is assumed to hold it; NOTE(review): 'exposed'
-- calls this after releasing its write locks). For other TVars the lock is
-- tried without blocking; a busy lock counts as invalid and leaves the
-- latest-stamp field as 'Nothing'.
validateReadSet :: StmState r -> IO (Bool,Heap r)
validateReadSet state = foldM check (True, Map.empty) (Map.toAscList (rset state))
  where
    check (ok, acc) (_, MkHeapEntry tv@(TVar _ _ stampRef lck) v tvId reg readStamp _)
      | Map.member tvId (wset state) = do
          now <- readIORef stampRef
          return (ok && now == readStamp, Map.insert tvId (MkHeapEntry tv v tvId reg readStamp (Just now)) acc)
      | otherwise = do
          grabbed <- tryTakeMVar lck
          case grabbed of
            Nothing ->
              return (False, Map.insert tvId (MkHeapEntry tv v tvId reg readStamp Nothing) acc)
            Just _ -> do
              now <- readIORef stampRef
              putMVar lck ()
              return (ok && now == readStamp, Map.insert tvId (MkHeapEntry tv v tvId reg readStamp (Just now)) acc)
-- | Commit a write set. Sample one commit timestamp, then for each entry
-- store the value, stamp the TVar with the commit time, and release its lock
-- (which the caller must hold).
publishAndUnlock :: Heap r -> IO ()
publishAndUnlock heap = do
  commitTime <- sampleClock
  mapM_ (publish commitTime) heap
  where
    publish stamp (MkHeapEntry (TVar valueRef _ stampRef lock) v _ _ _ _) = do
      writeIORef valueRef v
      writeIORef stampRef stamp
      putMVar lock ()
-- | Run a list of release actions in order.
unlockLocks :: [IO ()] -> IO ()
unlockLocks = sequence_
-- | A transactional variable. Its fields are the current value, a unique id,
-- its version timestamp (set by 'publishAndUnlock' at commit), and a lock.
-- The lock is an 'MVar' that is full while the TVar is unlocked.
data TVar a = TVar (IORef a) Id (IORef Timestamp) (MVar ())

-- | Shows only the TVar's id, in parentheses.
instance Show (TVar a) where
  show (TVar _ tvId _ _) = concat ["(", show tvId, ")"]
-- | Allocate a TVar outside any transaction, with a fresh unique id, an
-- unlocked lock, and version timestamp 1.
newTVarIO :: a -> IO (TVar a)
newTVarIO initial = do
  valueRef <- newIORef initial
  tvarId <- getUniqueId
  stampRef <- newIORef 1
  lock <- newMVar ()
  return $ TVar valueRef tvarId stampRef lock
-- | Handle to a TVar read in this transaction (wraps the TVar id); used by
-- 'rereadTVar' and returned by 'getInconsistencies'.
data RTwiVar a = RTwiVar Id
-- | Handle to a TVar written in this transaction (wraps the TVar id); used by
-- 'rewriteTVar'.
data WTwiVar a = WTwiVar Id
-- | One entry of a read or write set. Fields, in order:
--
-- 1. the TVar,
-- 2. the value read from it, or the value to be written to it,
-- 3. the TVar's id (also the entry's key in the 'Heap'),
-- 4. the region tag of a read ('Nothing' for entries created by 'writeTVar'),
-- 5. the TVar's timestamp when read ('writeTVar' stores the transaction's
--    start timestamp instead),
-- 6. the TVar's timestamp observed by the latest validation or reload, or
--    'Nothing' if none was observed.
--
-- The value type is existential; code that extracts the value relies on
-- 'unsafeCoerce', trusting that the id determines the TVar's type.
data HeapEntry r = forall a. Show a => MkHeapEntry (TVar a) a Id (Maybe (Region r a)) Timestamp (Maybe Timestamp)
-- | A read or write set, keyed by TVar id.
type Heap r = Map.IntMap (HeapEntry r)
-- | Tag grouping reads so inconsistencies can be queried per region
-- ('getInconsistencies', 'isInconsistent' match on the id). The 'Checked'
-- field is always 'Cons' here ('newRegion') and is not inspected.
data Region r a = Region Id Checked
-- | Allocate a region with the next id from the transaction-local counter.
-- Ids start at 0 for every (re)started transaction.
newRegion :: STM r Atomic Atomic (Region r a)
newRegion = STM $ \stmState ->
  let fresh = regCounter stmState
  in return $ Success (stmState { regCounter = fresh + 1 }) (Region fresh Cons)
-- | Global version clock sampled for transaction start and commit stamps.
-- NOTE(review): 'globalClock' and 'uniqueTVarId' are both bound to
-- 'getCounter'; if that is a single shared counter, clock stamps and TVar ids
-- come from the same sequence. Confirm against the Counter module.
globalClock :: Counter
globalClock = getCounter
-- | Take the next clock value (via 'getAndIncr').
sampleClock :: IO Int
sampleClock = getAndIncr globalClock
-- | Global semaphore: 'enterTwilight' uses its negative side, 'exposed' its
-- positive side.
semaphore :: Semaphore
semaphore = getSemaphore
-- | Source of TVar ids.
uniqueTVarId :: Counter
uniqueTVarId = getCounter
-- | Take the next TVar id (via 'getAndIncr').
getUniqueId :: IO Int
getUniqueId = getAndIncr uniqueTVarId
-- | Smoke test: runs a few small transactions and prints each result next to
-- the value it is expected to equal.
main :: IO ()
main = do
  _ <- atomically (gret 3)
  putStrLn "Empty transaction works\n"

  let val = 89
  y <- atomically (newTVar val `gbind` gret)
  putStrLn "Creating new TVars works\n"

  firstRead <- atomically (newRegion `gbind` \r -> readTVar y r `gbind` \(u, _) -> gret u)
  expect firstRead val

  done <- atomically (writeTVar y 25 `gbind` \_ -> gret "done")
  putStrLn done

  let second = 34
  readBack <- atomically (newRegion `gbind` \r -> writeTVar y second `gbind` \_ -> readTVar y r `gbind` \(u, _) -> gret u)
  expect readBack second

  let third = 78
  afterTwilight <- atomically (newRegion `gbind` \r -> writeTVar y third `gbind` \_ -> readTVar y r `gbind` \(u, _) -> twilight `gbind` \_ -> gret u)
  expect afterTwilight third
  where
    expect got want = putStrLn (show got ++ " should be " ++ show want)