module Test.DejaFu.Conc.Internal.Memory where
import Control.Monad.Ref (MonadRef, readRef,
                          writeRef)
import Data.Map.Strict (Map)
import Data.Maybe (maybeToList)
import Data.Monoid ((<>))
import Data.Sequence (Seq, ViewL(..), singleton,
                      viewl, (><))

import Test.DejaFu.Common
import Test.DejaFu.Conc.Internal.Common
import Test.DejaFu.Conc.Internal.Threading

import qualified Data.Map.Lazy as ML
import qualified Data.Map.Strict as M
-- | Writes that have been issued but not yet made visible in the
-- global state of their 'CRef'.
--
-- Each key pairs the writing thread with an optional 'CRefId'
-- (presumably 'Nothing' for a single per-thread queue and 'Just' for a
-- per-'CRef' queue — confirm against the callers). Within a queue,
-- 'bufferWrite' appends at the back and 'commitWrite' takes from the
-- front, so writes are kept oldest first.
newtype WriteBuffer r = WriteBuffer
  { buffer :: Map (ThreadId, Maybe CRefId) (Seq (BufferedWrite r))
    -- The pending writes, grouped by buffer key.
  }
-- | One pending write. The type of the stored value is hidden
-- (existential), so writes to 'CRef's of different value types can sit
-- in the same queue.
data BufferedWrite r where
  BufferedWrite
    :: ThreadId   -- thread that issued the write
    -> CRef r a   -- target of the write
    -> a          -- value to store when the write is committed
    -> BufferedWrite r
-- | A write buffer with no pending writes under any key.
emptyBuffer :: WriteBuffer r
emptyBuffer = WriteBuffer { buffer = M.empty }
-- | Buffer a write instead of committing it to the global state.
--
-- The write is appended to the back of the queue for key @k@. The
-- writing thread's local view of the 'CRef' is updated so that it
-- observes its own write immediately. Other threads keep reading
-- their existing local entry (or the global value) until the write
-- is committed.
bufferWrite :: MonadRef r n => WriteBuffer r -> (ThreadId, Maybe CRefId) -> CRef r a -> a -> n (WriteBuffer r)
bufferWrite (WriteBuffer wb) k@(tid, _) cref@(CRef _ ref) new = do
  (locals, count, def) <- readRef ref
  -- Use the lazy insert so that the written value is stored
  -- unevaluated, matching 'writeImmediate'.
  --
  -- Data.Map.Strict.insert forces the value as soon as the
  -- thread-local map itself is evaluated, which happens on any
  -- thread's forced read of this CRef. As a result, a lazy or bottom
  -- value written by one thread could crash another thread that
  -- would never observe it.
  writeRef ref (ML.insert tid new locals, count, def)
  pure (WriteBuffer (M.insertWith appendBack k write wb))
  where
    write = singleton (BufferedWrite tid cref new)
    -- insertWith passes (new, old); keep each queue oldest-first.
    appendBack fresh queued = queued >< fresh
-- | Commit the oldest pending write under key @k@ to its 'CRef''s
-- global state, and drop it from the queue.
--
-- If @k@ has no queue, or its queue is empty, the buffer is returned
-- unchanged. A queue emptied by this call is kept in the map as an
-- empty sequence.
commitWrite :: MonadRef r n => WriteBuffer r -> (ThreadId, Maybe CRefId) -> n (WriteBuffer r)
commitWrite w@(WriteBuffer wb) k =
  case M.lookup k wb of
    Nothing -> pure w
    Just queue -> case viewl queue of
      EmptyL -> pure w
      BufferedWrite _ cref a :< rest -> do
        writeImmediate cref a
        pure (WriteBuffer (M.insert k rest wb))
-- | Read a 'CRef' as seen by the given thread: its own local value if
-- it has one, otherwise the global value.
readCRef :: MonadRef r n => CRef r a -> ThreadId -> n a
readCRef cref tid = fst <$> readCRefPrim cref tid
-- | Read a 'CRef' as seen by the given thread and package the result,
-- together with the 'CRef''s id and current global write count, into
-- a 'Ticket'.
readForTicket :: MonadRef r n => CRef r a -> ThreadId -> n (Ticket a)
readForTicket cref@(CRef crid _) tid =
  toTicket <$> readCRefPrim cref tid
  where
    toTicket (val, count) = Ticket crid count val
-- | Compare-and-swap on a 'CRef'.
--
-- The swap succeeds when the ticket's write count matches the
-- 'CRef''s current count. On success the new value is written
-- immediately, and a fresh ticket taken after the write is returned
-- with 'True'. On failure nothing is written, and the ticket read
-- just now is returned with 'False'. The new value is forced (bang
-- pattern) before anything else happens.
casCRef :: MonadRef r n => CRef r a -> ThreadId -> Ticket a -> a -> n (Bool, Ticket a)
casCRef cref tid (Ticket _ expected _) !new = do
  current@(Ticket _ actual _) <- readForTicket cref tid
  if expected /= actual
    then pure (False, current)
    else do
      writeImmediate cref new
      fresh <- readForTicket cref tid
      pure (True, fresh)
-- | Low-level read of a 'CRef' for a given thread.
--
-- Returns the thread's local value (falling back to the global value
-- when the thread has none), paired with the global write count.
readCRefPrim :: MonadRef r n => CRef r a -> ThreadId -> n (a, Integer)
readCRefPrim (CRef _ ref) tid = do
  (perThread, writes, global) <- readRef ref
  let seen = M.findWithDefault global tid perThread
  pure (seen, writes)
-- | Write a value straight into a 'CRef''s global state.
--
-- This bumps the write count and discards every thread-local view. The
-- value itself is stored unevaluated.
--
-- NOTE(review): 'commitWrite' calls this for one write while other
-- writes may still be buffered. Those threads' local views of this
-- 'CRef' are discarded too — confirm that is intended.
writeImmediate :: MonadRef r n => CRef r a -> a -> n ()
writeImmediate (CRef _ ref) a =
  readRef ref >>= \(_, count, _) -> writeRef ref (M.empty, count + 1, a)
-- | Commit every pending write to global state.
--
-- Queues are processed in key order, and each queue oldest-first. The
-- buffer itself is not modified here.
writeBarrier :: MonadRef r n => WriteBuffer r -> n ()
writeBarrier (WriteBuffer wb) = mapM_ (mapM_ commitOne) (M.elems wb)
  where
    commitOne (BufferedWrite _ cref a) = writeImmediate cref a
-- | Add one phantom thread per non-empty buffer queue. Each phantom
-- thread's action commits the write at the head of its queue.
--
-- Phantom threads get 'ThreadId's with no name and negative ids
-- (-1, -2, ...), assigned by the queue's position in the map (empty
-- queues still consume an id). On an id collision the existing thread
-- wins, because '<>' on maps is a left-biased union.
-- NOTE(review): assumes real threads never use negative ids — confirm.
addCommitThreads :: WriteBuffer r -> Threads n r -> Threads n r
addCommitThreads (WriteBuffer wb) ts = ts <> M.fromList phantoms
  where
    phantoms =
      [ (ThreadId Nothing (negate i), mkthread act)
      | (i, queue) <- zip [1 ..] (M.elems wb)
      , Just act <- [headCommit (viewl queue)]
      ]
    headCommit (BufferedWrite tid (CRef crid _) _ :< _) = Just (ACommit tid crid)
    headCommit EmptyL = Nothing
-- | Remove every thread whose id orders below 'initialThread'.
--
-- This is intended to remove the phantom commit threads added by
-- 'addCommitThreads'.
-- NOTE(review): relies on the negative phantom ids comparing below
-- 'initialThread' — confirm against the 'ThreadId' 'Ord' instance.
delCommitThreads :: Threads n r -> Threads n r
delCommitThreads threads = M.filterWithKey isReal threads
  where
    isReal tid _ = tid >= initialThread
-- | Whether an 'MVar' operation blocks the calling thread ('Blocking')
-- or continues immediately with a failure result ('NonBlocking') when
-- it cannot proceed.
data Blocking
  = Blocking
  | NonBlocking
-- | Whether reading a full 'MVar' also empties it ('Emptying', as in a
-- take) or leaves its contents in place ('NonEmptying', as in a read).
data Emptying
  = Emptying
  | NonEmptying
-- | Put a value into an 'MVar', blocking the thread while it is full.
-- The continuation ignores the success flag.
putIntoMVar :: MonadRef r n => MVar r a -> a -> Action n r
            -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
putIntoMVar cvar a c = mutMVar Blocking cvar a (\_ -> c)
-- | Try to put a value into an 'MVar' without blocking. The
-- continuation receives whether the put succeeded.
tryPutIntoMVar :: MonadRef r n => MVar r a -> a -> (Bool -> Action n r)
               -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
tryPutIntoMVar cvar a c = mutMVar NonBlocking cvar a c
-- | Read an 'MVar' without emptying it, blocking the thread while it
-- is empty.
--
-- In blocking mode 'seeMVar' only runs the continuation when the
-- 'MVar' is full, so the value passed to it is always a 'Just'.
readFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r)
             -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
readFromMVar cvar c = seeMVar NonEmptying Blocking cvar k
  where
    k mval = c (efromJust "readFromMVar" mval)
-- | Try to read an 'MVar' without emptying it and without blocking.
-- The continuation receives the contents, or 'Nothing' if it was empty.
tryReadFromMVar :: MonadRef r n => MVar r a -> (Maybe a -> Action n r)
                -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
tryReadFromMVar cvar c = seeMVar NonEmptying NonBlocking cvar c
-- | Take the value out of an 'MVar', leaving it empty and blocking the
-- thread while it is empty.
--
-- In blocking mode 'seeMVar' only runs the continuation when the
-- 'MVar' is full, so the value passed to it is always a 'Just'.
takeFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r)
             -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
takeFromMVar cvar c = seeMVar Emptying Blocking cvar k
  where
    k mval = c (efromJust "takeFromMVar" mval)
-- | Try to take the value out of an 'MVar' without blocking. The
-- continuation receives the contents, or 'Nothing' if it was empty.
tryTakeFromMVar :: MonadRef r n => MVar r a -> (Maybe a -> Action n r)
                -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
tryTakeFromMVar cvar c = seeMVar Emptying NonBlocking cvar c
-- | Shared implementation of putting a value into an 'MVar'.
--
-- * Empty 'MVar': store @a@, wake every thread blocked on
--   'OnMVarFull', and continue with @c True@.
-- * Full 'MVar', 'Blocking': block the calling thread on
--   'OnMVarEmpty'; its continuation is not advanced.
-- * Full 'MVar', 'NonBlocking': continue with @c False@.
--
-- Returns whether the put happened, the updated threads, and the ids
-- of any threads that were woken.
mutMVar :: MonadRef r n
        => Blocking -> MVar r a -> a -> (Bool -> Action n r)
        -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
mutMVar blocking (MVar cvid ref) a c threadid threads =
  readRef ref >>= \contents -> case (contents, blocking) of
    (Nothing, _) -> do
      writeRef ref (Just a)
      let (threads', woken) = wake (OnMVarFull cvid) threads
      pure (True, goto (c True) threadid threads', woken)
    (Just _, Blocking) ->
      pure (False, block (OnMVarEmpty cvid) threadid threads, [])
    (Just _, NonBlocking) ->
      pure (False, goto (c False) threadid threads, [])
-- | Shared implementation of reading or taking from an 'MVar'.
--
-- * Full 'MVar': empty it if 'Emptying', wake every thread blocked
--   on 'OnMVarEmpty', and continue with the (Just) contents.
-- * Empty 'MVar', 'Blocking': block the calling thread on
--   'OnMVarFull'; its continuation is not advanced.
-- * Empty 'MVar', 'NonBlocking': continue with 'Nothing'.
--
-- Returns whether a value was seen, the updated threads, and the ids
-- of any threads that were woken.
--
-- NOTE(review): a 'NonEmptying' read leaves the 'MVar' full but still
-- wakes threads blocked on 'OnMVarEmpty'. Presumably they simply
-- re-block when retried — confirm this is intended.
seeMVar :: MonadRef r n
        => Emptying -> Blocking -> MVar r a -> (Maybe a -> Action n r)
        -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
seeMVar emptying blocking (MVar cvid ref) c threadid threads = do
  contents <- readRef ref
  case contents of
    Nothing -> case blocking of
      Blocking ->
        pure (False, block (OnMVarFull cvid) threadid threads, [])
      NonBlocking ->
        pure (False, goto (c Nothing) threadid threads, [])
    Just _ -> do
      clearIfTaking
      let (threads', woken) = wake (OnMVarEmpty cvid) threads
      pure (True, goto (c contents) threadid threads', woken)
  where
    clearIfTaking = case emptying of
      Emptying -> writeRef ref Nothing
      NonEmptying -> pure ()