{-# OPTIONS_GHC -Wno-orphans #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
module Test.DejaFu.Conc.Internal.Program where
import Control.Applicative (Applicative(..))
import Control.Exception (MaskingState(..))
import qualified Control.Monad.Catch as Ca
import qualified Control.Monad.IO.Class as IO
import Control.Monad.Trans.Class (MonadTrans(..))
import qualified Data.Foldable as F
import Data.List (partition)
import qualified Data.Map.Strict as M
import Data.Maybe (isNothing)
import GHC.Stack (HasCallStack)
import qualified Control.Monad.Conc.Class as C
import Test.DejaFu.Conc.Internal
import Test.DejaFu.Conc.Internal.Common
import Test.DejaFu.Conc.Internal.STM (ModelSTM)
import Test.DejaFu.Conc.Internal.Threading (Threads, _blocking)
import Test.DejaFu.Internal
import Test.DejaFu.Schedule
import Test.DejaFu.Types
-- | Lift an 'IO' action by running it in the underlying monad @n@
-- inside an 'ALift' primitive, feeding its result to the continuation.
instance (pty ~ Basic, IO.MonadIO n) => IO.MonadIO (Program pty n) where
  liftIO io = ModelConc (\k -> ALift (k <$> IO.liftIO io))
-- | Lift an action of the underlying monad via 'ALift'; its result is
-- passed on to the continuation.
instance (pty ~ Basic) => MonadTrans (Program pty) where
  lift action = ModelConc (\k -> ALift (k <$> action))
-- | Exception handling is delegated to the 'ACatching' primitive, which
-- receives the handler first and the protected computation second.
instance (pty ~ Basic) => Ca.MonadCatch (Program pty n) where
  catch protected handler = ModelConc (ACatching handler protected)
-- | Throwing ignores the continuation entirely: control never returns
-- to the code after 'throwM'.
instance (pty ~ Basic) => Ca.MonadThrow (Program pty n) where
  throwM exc = ModelConc (const (AThrow exc))
-- | Masking is modelled by the 'AMasking' primitive. 'generalBracket'
-- is built on top of 'Ca.mask' and 'Ca.catch'.
instance (pty ~ Basic) => Ca.MonadMask (Program pty n) where
  -- The two variants differ only in the requested masking state. The
  -- argument lambdas stay eta-expanded, presumably because the value
  -- passed to the body has a polymorphic (rank-2) type.
  mask body = ModelConc (AMasking MaskedInterruptible (\restoreMask -> body restoreMask))
  uninterruptibleMask body = ModelConc (AMasking MaskedUninterruptible (\restoreMask -> body restoreMask))
#if MIN_VERSION_exceptions(0,10,0)
  -- Acquire with exceptions masked and run 'use' with the outer masking
  -- state restored. If 'use' throws, release with 'Ca.ExitCaseException'
  -- and rethrow; otherwise release with 'Ca.ExitCaseSuccess'.
  generalBracket acquire release use = Ca.mask $ \restoreMask -> do
    resource <- acquire
    let onError e = release resource (Ca.ExitCaseException e) >> Ca.throwM e
    result <- restoreMask (use resource) `Ca.catch` onError
    finalised <- release resource (Ca.ExitCaseSuccess result)
    pure (result, finalised)
#elif MIN_VERSION_exceptions(0,9,0)
  -- Older API: 'cleanup' runs only when 'use' throws (and the exception
  -- is rethrown). 'release' runs only on the success path.
  generalBracket acquire release cleanup use = Ca.mask $ \restoreMask -> do
    resource <- acquire
    let onError e = cleanup resource e >> Ca.throwM e
    result <- restoreMask (use resource) `Ca.catch` onError
    _ <- release resource
    pure result
#endif
-- | The model implementation of 'C.MonadConc'. Each operation becomes
-- the matching primitive 'Action'. Operations returning @()@ resume the
-- continuation @k@ with @()@; the others hand the continuation straight
-- to their primitive.
instance (pty ~ Basic, Monad n) => C.MonadConc (Program pty n) where
  type MVar     (Program pty n) = ModelMVar n
  type IORef    (Program pty n) = ModelIORef n
  type Ticket   (Program pty n) = ModelTicket
  type STM      (Program pty n) = ModelSTM n
  type ThreadId (Program pty n) = ThreadId

  -- Threads: a forked body is run with a final continuation that stops
  -- the thread.
  forkWithUnmaskN tname body =
    ModelConc (AFork tname (\umask -> runModelConc (body umask) (const (AStop (pure ())))))
  forkOnWithUnmaskN tname _ = C.forkWithUnmaskN tname
  forkOSWithUnmaskN tname body =
    ModelConc (AForkOS tname (\umask -> runModelConc (body umask) (const (AStop (pure ())))))

  supportsBoundThreads = ModelConc ASupportsBoundThreads
  isCurrentThreadBound = ModelConc AIsBound

  getNumCapabilities = ModelConc AGetNumCapabilities
  setNumCapabilities count = ModelConc (\k -> ASetNumCapabilities count (k ()))

  myThreadId = ModelConc AMyTId

  yield = ModelConc (\k -> AYield (k ()))
  threadDelay usecs = ModelConc (\k -> ADelay usecs (k ()))

  -- IORefs
  newIORefN tname = ModelConc . ANewIORef tname
  readIORef = ModelConc . AReadIORef
  readForCAS = ModelConc . AReadIORefCas
  peekTicket' _ = ticketVal
  writeIORef ref val = ModelConc (\k -> AWriteIORef ref val (k ()))
  casIORef ref tick = ModelConc . ACasIORef ref tick
  atomicModifyIORef ref = ModelConc . AModIORef ref
  modifyIORefCAS ref = ModelConc . AModIORefCas ref

  -- MVars
  newEmptyMVarN tname = ModelConc (ANewMVar tname)
  putMVar var val = ModelConc (\k -> APutMVar var val (k ()))
  readMVar = ModelConc . AReadMVar
  takeMVar = ModelConc . ATakeMVar
  tryPutMVar var = ModelConc . ATryPutMVar var
  tryReadMVar = ModelConc . ATryReadMVar
  tryTakeMVar = ModelConc . ATryTakeMVar

  -- Exceptions and masking
  throwTo target exc = ModelConc (\k -> AThrowTo target exc (k ()))
  getMaskingState = ModelConc AGetMasking

  -- STM
  atomically tx = ModelConc (AAtom tx)
-- | Run a concurrent program with the given 'Scheduler', memory model
-- and initial scheduler state.
--
-- Returns the program's result (or the 'Condition' that stopped it),
-- the final scheduler state and the execution trace.
--
-- A plain 'ModelConc' program is run directly from 'initialIdSource'.
-- A program with a setup action is snapshotted first. If the setup
-- itself fails, its condition is returned along with the /initial/
-- scheduler state and the setup's trace.
runConcurrent :: MonadDejaFu n
  => Scheduler s
  -> MemType
  -> s
  -> Program pty n a
  -> n (Either Condition a, s, Trace)
runConcurrent sched memtype s ma@(ModelConc _) = do
  CResult{..} <- runConcurrency [] False sched memtype s initialIdSource 2 ma
  out <- efromJust <$> readRef finalRef
  pure (out, cSchedState finalContext, F.toList finalTrace)
runConcurrent sched memtype s ma = do
  recorded <- recordSnapshot ma
  case recorded of
    Just (Left cond, trc) -> pure (Left cond, s, trc)
    Just (Right snap, _) -> runSnapshot sched memtype s snap
    Nothing -> fatal "failed to record snapshot!"
-- | Run a program's setup action (if it has one) and capture the state
-- at its end as a 'Snapshot'.
--
-- A plain 'ModelConc' program has no setup action, so there is nothing
-- to snapshot and 'Nothing' is returned. Otherwise the result pairs
-- either the snapshot or the 'Condition' that stopped the setup with
-- the setup run's trace.
recordSnapshot
  :: MonadDejaFu n
  => Program pty n a
  -> n (Maybe (Either Condition (Snapshot pty n a), Trace))
-- Match with @{}@ rather than @{..}@: no field is needed here, and a
-- wildcard would bind (and shadow) 'runModelConc' for nothing.
recordSnapshot ModelConc{} = pure Nothing
recordSnapshot WithSetup{..} =
  let mkSnapshot snap _ = WS snap
  in defaultRecordSnapshot mkSnapshot wsSetup wsProgram
recordSnapshot WithSetupAndTeardown{..} =
  -- The teardown is chosen from the setup's result.
  let mkSnapshot snap = WSAT snap . wstTeardown
  in defaultRecordSnapshot mkSnapshot wstSetup wstProgram
-- | Run the rest of a program from a previously recorded 'Snapshot'.
--
-- The snapshot's context is rebuilt with the supplied scheduler state
-- (see 'fromSnapContext') before the main program runs.
--
-- For a snapshot with a teardown, the teardown is then run with
-- 'simpleRunConcurrency', starting from the id source left by the main
-- run. The teardown's result is returned, but the scheduler state and
-- trace come from the main run only.
runSnapshot
  :: MonadDejaFu n
  => Scheduler s
  -> MemType
  -> s
  -> Snapshot pty n a
  -> n (Either Condition a, s, Trace)
runSnapshot sched memtype s (WS SimpleSnapshot{..}) = do
  res <- runConcurrencyWithSnapshot sched memtype (fromSnapContext s snapContext) snapRestore snapNext
  out <- efromJust <$> readRef (finalRef res)
  pure (out, cSchedState (finalContext res), F.toList (finalTrace res))
runSnapshot sched memtype s (WSAT SimpleSnapshot{..} teardown) = do
  mainRes <- runConcurrencyWithSnapshot sched memtype (fromSnapContext s snapContext) snapRestore snapNext
  mainOut <- efromJust <$> readRef (finalRef mainRes)
  let mainCtx = finalContext mainRes
  teardownRes <- simpleRunConcurrency False (cIdSource mainCtx) (teardown mainOut)
  teardownOut <- efromJust <$> readRef (finalRef teardownRes)
  pure (teardownOut, cSchedState mainCtx, F.toList (finalTrace mainRes))
-- | The state of a concurrent program immediately after its setup
-- action has completed, indexed by the kind of program it came from.
data Snapshot pty n a where
  -- | Snapshot of a program with a setup action only.
  WS
    :: SimpleSnapshot n a
    -> Snapshot (WithSetup x) n a
  -- | Snapshot of a program with setup and teardown. The function builds
  -- the teardown action from the main program's outcome.
  WSAT
    :: SimpleSnapshot n a
    -> (Either Condition a -> ModelConc n y)
    -> Snapshot (WithSetupAndTeardown x a) n y
-- | What is captured after running a setup action, plus the program
-- still to run. Field order matters: the constructor is applied
-- positionally elsewhere.
data SimpleSnapshot n a = SimpleSnapshot
  { snapContext :: Context n ()
    -- ^ Execution context at the end of the setup run.
  , snapRestore :: Threads n -> n ()
    -- ^ Restore action taken from the setup run; it is handed to
    -- 'runConcurrencyWithSnapshot' (presumably to re-establish thread
    -- state — confirm against that function).
  , snapNext :: ModelConc n a
    -- ^ The main program, already applied to the setup's result.
  }
-- | Extract the captured execution 'Context' from either kind of
-- 'Snapshot'.
contextFromSnapshot :: Snapshot p n a -> Context n ()
contextFromSnapshot snapshot = case snapshot of
  WS simple -> snapContext simple
  WSAT simple _ -> snapContext simple
-- | List the threads recorded in a snapshot, split into runnable and
-- blocked. A thread is runnable when it has no '_blocking' entry.
-- 'initialThread' is always placed first in the runnable list.
threadsFromSnapshot :: Snapshot p n a -> ([ThreadId], [ThreadId])
threadsFromSnapshot snap =
  let threads = cThreads (contextFromSnapshot snap)
      isRunnable tid = isNothing (M.lookup tid threads >>= _blocking)
      (runnable, blocked) = partition isRunnable (M.keys threads)
  in (initialThread : runnable, blocked)
-- | Generic implementation behind 'recordSnapshot'.
--
-- Runs @setup@ to completion in snapshotting mode from
-- 'initialIdSource'. If it produces a value @a@, a snapshot is built
-- from the final context and restore action, with @program a@ as the
-- next computation. If the setup hit a 'Condition', that condition is
-- returned instead. Either way the setup's trace is included.
--
-- The result is always 'Just'. If the setup left no result at all, the
-- pair inside is a 'fatal' error, raised only when the pair is forced.
defaultRecordSnapshot :: MonadDejaFu n
  => (SimpleSnapshot n a -> x -> snap)
  -> ModelConc n x
  -> (x -> ModelConc n a)
  -> n (Maybe (Either Condition snap, Trace))
defaultRecordSnapshot mkSnapshot setup program = do
  res <- simpleRunConcurrency True initialIdSource setup
  out <- readRef (finalRef res)
  let trc = F.toList (finalTrace res)
      toSnapshot a = mkSnapshot (SimpleSnapshot (finalContext res) (finalRestore res) (program a)) a
      outcome = case out of
        Just (Right a) -> (Right (toSnapshot a), trc)
        Just (Left f) -> (Left f, trc)
        -- The setup should always produce a result, so treat its absence
        -- as an internal error rather than returning 'Nothing'.
        Nothing -> fatal "failed to produce snapshot"
  pure (Just outcome)
-- | Run a model program with a fixed, deterministic configuration: the
-- non-preemptive round-robin scheduler (state @()@) under
-- 'SequentialConsistency'.
--
-- The flag selects snapshotting mode, and the 'IdSource' is where
-- identifier allocation starts. The empty list and the literal @2@ are
-- passed through unchanged to 'runConcurrency' (presumably no
-- invariants and the capability count — confirm against its signature).
simpleRunConcurrency :: (MonadDejaFu n, HasCallStack)
  => Bool
  -> IdSource
  -> ModelConc n a
  -> n (CResult n () a)
simpleRunConcurrency forSnapshot idsrc program =
  runConcurrency [] forSnapshot scheduler SequentialConsistency () idsrc 2 program
  where
    scheduler = roundRobinSchedNP
-- | Build a fresh context from a snapshot context.
--
-- Installs the given scheduler state and promotes the invariants that
-- were newly registered into the active set, with nothing blocked. The
-- pending-invariant list is then emptied.
fromSnapContext :: g -> Context n s -> Context n g
fromSnapContext g ctx = ctx
  { cSchedState = g
  , cInvariants = InvariantContext { icActive = cNewInvariants ctx, icBlocked = [] }
  , cNewInvariants = []
  }
-- | Transform a model computation by rewriting its underlying
-- continuation-taking function.
wrap :: (((a -> Action n) -> Action n) -> ((a -> Action n) -> Action n)) -> ModelConc n a -> ModelConc n a
wrap f ma = ModelConc (f (runModelConc ma))