module Test.DejaFu.Conc.Internal.STM where
import Control.Applicative (Alternative(..))
import Control.Exception (Exception, SomeException,
fromException, toException)
import Control.Monad (MonadPlus(..))
import Control.Monad.Catch (MonadCatch(..), MonadThrow(..))
import Control.Monad.Ref (MonadRef, newRef, readRef, writeRef)
import Data.List (nub)
import qualified Control.Monad.STM.Class as C
import Test.DejaFu.Internal
import Test.DejaFu.Types
#if MIN_VERSION_base(4,9,0)
import qualified Control.Monad.Fail as Fail
#endif
-- | The modelled STM monad: a transaction in continuation-passing style.
-- Given a continuation for the transaction's result, 'runSTM' produces the
-- primitive 'STMAction's which 'stepTrans' interprets.
newtype S n r a = S { runSTM :: (a -> STMAction n r) -> STMAction n r }
-- | Mapping over a transaction pre-composes the function onto whatever
-- continuation the transaction is eventually run with.
instance Functor (S n r) where
  fmap f m = S (runSTM m . (. f))
-- | 'pure' hands its value straight to the continuation; '<*>' runs the
-- function-producing transaction, then the argument-producing one, and
-- passes the applied result on.
instance Applicative (S n r) where
  pure x = S ($ x)
  mf <*> mx = S $ \k ->
    runSTM mf $ \g ->
      runSTM mx (k . g)
-- | Binding feeds the first transaction's result to @k@ and runs the
-- resulting transaction with the outer continuation.
--
-- 'fail' aborts the transaction by throwing a 'MonadFailException'
-- carrying the message. Where it lives depends on the @base@ version:
--
-- * base < 4.9: 'fail' is only a 'Monad' method, so it is defined here.
-- * 4.9 <= base < 4.13: both classes have 'fail'; the 'Monad' one defers
--   to 'Fail.fail'.
-- * base >= 4.13: 'fail' is no longer a 'Monad' method, so defining it
--   in this instance would be a compile error.
instance Monad (S n r) where
  return = pure
  m >>= k = S $ \c -> runSTM m (\x -> runSTM (k x) c)
#if !MIN_VERSION_base(4,13,0)
#if MIN_VERSION_base(4,9,0)
  fail = Fail.fail
#else
  fail e = S $ \_ -> SThrow (MonadFailException e)
#endif
#endif

#if MIN_VERSION_base(4,9,0)
-- | Throws a 'MonadFailException' carrying the message, aborting the
-- transaction.
instance Fail.MonadFail (S n r) where
  fail e = S $ \_ -> SThrow (MonadFailException e)
#endif
-- | Throwing discards the continuation entirely: the transaction ends with
-- an 'SThrow' of the given exception.
instance MonadThrow (S n r) where
  throwM err = S (const (SThrow err))
-- | Catching emits an 'SCatch' action. It carries the handler, the
-- guarded transaction, and the continuation for whichever of the two
-- produces the result.
instance MonadCatch (S n r) where
  catch body handler = S (\k -> SCatch handler body k)
-- | 'empty' retries the transaction. @a '<|>' b@ runs @a@ and, if it
-- retries, runs @b@ instead.
instance Alternative (S n r) where
  empty = S (const SRetry)
  (<|>) lhs rhs = S (SOrElse lhs rhs)
-- | 'mzero' and 'mplus' are the 'Alternative' operations: retry, and
-- run-the-second-if-the-first-retries.
instance MonadPlus (S n r) where
  mzero = empty
  mplus = (<|>)
-- | Transactional variables are this module's 'TVar's. Each operation
-- emits the matching primitive action: 'SNew', 'SRead', or 'SWrite'.
instance C.MonadSTM (S n r) where
  type TVar (S n r) = TVar r

  newTVarN name initial = S (SNew name initial)

  readTVar tvar = S (SRead tvar)

  writeTVar tvar new = S (\k -> SWrite tvar new (k ()))
-- | A primitive STM action in continuation-passing style: non-terminal
-- constructors carry the continuation producing the next action.
-- 'stepTrans' interprets one action per step.
data STMAction n r
  = forall a e. Exception e => SCatch (e -> S n r a) (S n r a) (a -> STMAction n r)
  -- ^ Handler, guarded transaction, continuation. The handler runs if the
  -- guarded transaction throws an exception of type @e@.
  | forall a. SRead (TVar r a) (a -> STMAction n r)
  -- ^ Read a 'TVar' and pass its value to the continuation.
  | forall a. SWrite (TVar r a) a (STMAction n r)
  -- ^ Write a value to a 'TVar', then continue with the given action.
  | forall a. SOrElse (S n r a) (S n r a) (a -> STMAction n r)
  -- ^ Run the first transaction; if it retries, run the second instead.
  | forall a. SNew String a (TVar r a -> STMAction n r)
  -- ^ Create a new 'TVar' from a name (passed to 'nextTVId') and an
  -- initial value, and pass it to the continuation.
  | forall e. Exception e => SThrow e
  -- ^ Abort the transaction with an exception.
  | SRetry
  -- ^ Abort the transaction so that it is retried.
  | SStop (n ())
  -- ^ Run the final action and end the transaction.
-- | A modelled transactional variable: its 'TVarId' paired with the
-- mutable reference (of the 'MonadRef' reference type @r@) holding its
-- current value.
newtype TVar r a = TVar (TVarId, r a)
-- | The outcome of running a transaction.
data Result a =
    Success [TVarId] [TVarId] a
    -- ^ Completed: the 'TVar's read, the 'TVar's written (including newly
    -- created ones), and the result.
  | Retry [TVarId]
    -- ^ Retried: the 'TVar's read by the transaction.
  | Exception SomeException
    -- ^ Aborted with an exception.
  deriving Show
-- | Run a transaction, returning its result, the next 'IdSource', and the
-- trace of primitive actions performed.
--
-- The undo action from 'doTransaction' is dropped. A retried or failed
-- transaction has already been rolled back, and a successful one keeps
-- its effects.
runTransaction :: MonadRef r n
  => S n r a
  -> IdSource
  -> n (Result a, IdSource, [TAction])
runTransaction ma tvid =
  doTransaction ma tvid >>= \(res, _undo, tvid', trace) ->
    pure (res, tvid', trace)
-- | Run an STM transaction. Returns its 'Result', an action which undoes
-- its effects, the next 'IdSource', and the trace of the primitive actions
-- performed (in execution order).
--
-- If the transaction retries or throws, its effects are undone before this
-- returns. In that case the returned undo action is @pure ()@ and the
-- returned 'IdSource' is the one passed in.
--
-- On success the undo action is handed back unrun. 'stepTrans' relies on
-- this so that a nested transaction's effects can be rolled back if the
-- enclosing transaction later fails.
--
-- The 'TVarId' lists in the 'Result' are de-duplicated.
doTransaction :: MonadRef r n
  => S n r a
  -> IdSource
  -> n (Result a, n (), IdSource, [TAction])
doTransaction ma idsource = do
  -- 'ref' receives the outcome. 'runRefCont' is defined elsewhere;
  -- presumably the final continuation stores @Just (Right val)@ in it when
  -- the 'SStop' action runs (TODO confirm). 'go' overwrites it on retry
  -- or throw.
  (c, ref) <- runRefCont SStop (Just . Right) (runSTM ma)
  (idsource', undo, readen, written, trace) <- go ref c (pure ()) idsource [] [] []
  res <- readRef ref

  case res of
    Just (Right val) -> pure (Success (nub readen) (nub written) val, undo, idsource', reverse trace)
    Just (Left exc) -> undo >> pure (Exception exc, pure (), idsource, reverse trace)
    Nothing -> undo >> pure (Retry $ nub readen, pure (), idsource, reverse trace)

  where
    -- Step the transaction until it stops, retries, or throws. The undo
    -- action, the read/write sets, and the trace are accumulated
    -- newest-first.
    go ref act undo nidsrc readen written sofar = do
      (act', undo', nidsrc', readen', written', tact) <- stepTrans act nidsrc

      let newIDSource = nidsrc'
          newAct = act'
          -- Undo the newest step first. Repeated writes to the same TVar
          -- then end up restoring the oldest saved value.
          newUndo = undo' >> undo
          newReaden = readen' ++ readen
          newWritten = written' ++ written
          newSofar = tact : sofar

      -- NOTE(review): newSofar already starts with tact, so each terminal
      -- action below is recorded twice in the trace. It is kept as-is
      -- because trace consumers outside this module may depend on it.
      -- Confirm before removing the extra cons.
      case tact of
        TStop -> pure (newIDSource, newUndo, newReaden, newWritten, TStop:newSofar)
        TRetry -> do
          writeRef ref Nothing
          pure (newIDSource, newUndo, newReaden, newWritten, TRetry:newSofar)
        TThrow -> do
          -- 'stepTrans' only reports 'TThrow' when stepping an 'SThrow'.
          let exc = case act of
                SThrow e -> toException e
                _ -> error "doTransaction: TThrow reported for a non-SThrow action"
          writeRef ref (Just (Left exc))
          pure (newIDSource, newUndo, newReaden, newWritten, TThrow:newSofar)
        _ -> go ref newAct newUndo newIDSource newReaden newWritten newSofar
-- | Run a single primitive STM action.
--
-- Returns the following, in order:
--
-- * the action to continue with;
-- * an action undoing this step's effects;
-- * the updated 'IdSource';
-- * the 'TVar's read;
-- * the 'TVar's written (newly created 'TVar's count as written);
-- * the 'TAction' describing the step in the trace.
--
-- 'SCatch' and 'SOrElse' run their sub-transactions to completion with
-- 'doTransaction' within this single step. The sub-traces are embedded in
-- the returned 'TCatch' / 'TOrElse'.
stepTrans :: MonadRef r n
  => STMAction n r
  -> IdSource
  -> n (STMAction n r, n (), IdSource, [TVarId], [TVarId], TAction)
stepTrans act idsource = case act of
  SCatch h stm c -> stepCatch h stm c
  SRead ref c -> stepRead ref c
  SWrite ref a c -> stepWrite ref a c
  SNew n a c -> stepNew n a c
  SOrElse a b c -> stepOrElse a b c
  SStop na -> stepStop na
  -- Terminal actions are returned unchanged; 'doTransaction' stops when it
  -- sees 'TThrow' or 'TRetry'.
  SThrow e -> pure (SThrow e, nothing, idsource, [], [], TThrow)
  SRetry -> pure (SRetry, nothing, idsource, [], [], TRetry)

  where
    -- No-op undo action for steps with nothing to roll back.
    nothing = pure ()

    -- Run the guarded transaction as a nested transaction.
    --
    -- * Retry: propagates, with the nested reads reported.
    -- * Exception of the handler's type: runs the handler as a further
    --   nested transaction.
    -- * Any other exception: rethrown as an 'SThrow'.
    stepCatch h stm c = cases TCatch stm c
      (\trace -> pure (SRetry, nothing, idsource, [], [], TCatch trace Nothing))
      (\trace exc -> case fromException exc of
        Just exc' -> transaction (TCatch trace . Just) (h exc') c
        Nothing -> pure (SThrow exc, nothing, idsource, [], [], TCatch trace Nothing))

    -- Read a TVar's current value; nothing to undo.
    stepRead (TVar (tvid, ref)) c = do
      val <- readRef ref
      pure (c val, nothing, idsource, [tvid], [], TRead tvid)

    -- Overwrite a TVar in place. The undo action restores the old value.
    stepWrite (TVar (tvid, ref)) a c = do
      old <- readRef ref
      writeRef ref a
      pure (c, writeRef ref old, idsource, [], [tvid], TWrite tvid)

    -- Allocate a fresh TVar id and reference. The new TVar is reported as
    -- written; no undo action is returned for the allocation itself.
    stepNew n a c = do
      let (idsource', tvid) = nextTVId n idsource
      ref <- newRef a
      let tvar = TVar (tvid, ref)
      pure (c tvar, nothing, idsource', [], [tvid], TNew tvid)

    -- Run the first alternative as a nested transaction.
    --
    -- * Retry: runs the second alternative; reads from both are reported.
    -- * Exception from the first alternative: rethrown without trying the
    --   second.
    stepOrElse a b c = cases TOrElse a c
      (\trace -> transaction (TOrElse trace . Just) b c)
      (\trace exc -> pure (SThrow exc, nothing, idsource, [], [], TOrElse trace Nothing))

    -- Run the final action, then report 'TStop'.
    stepStop na = do
      na
      pure (SStop na, nothing, idsource, [], [], TStop)

    -- Run 'stm' as a nested transaction starting from this step's
    -- 'IdSource', and dispatch on its result:
    --
    -- * Success: continue with 'onSuccess'; the nested undo action and
    --   read/write sets are kept.
    -- * Retry: call 'onRetry' with the nested trace, and prepend the nested
    --   reads to the read set it returns.
    -- * Exception: call 'onException' with the nested trace and the
    --   exception.
    cases tact stm onSuccess onRetry onException = do
      (res, undo, idsource', trace) <- doTransaction stm idsource
      case res of
        Success readen written val -> pure (onSuccess val, undo, idsource', readen, written, tact trace Nothing)
        Retry readen -> do
          (res', undo', idsource'', readen', written', trace') <- onRetry trace
          pure (res', undo', idsource'', readen ++ readen', written', trace')
        Exception exc -> onException trace exc

    -- Like 'cases', but a retry or exception from 'stm' is propagated
    -- outward as 'SRetry' / 'SThrow'. 'tact' builds the step's 'TAction'
    -- from the nested trace in every outcome.
    transaction tact stm onSuccess = cases (\t _ -> tact t) stm onSuccess
      (\trace -> pure (SRetry, nothing, idsource, [], [], tact trace))
      (\trace exc -> pure (SThrow exc, nothing, idsource, [], [], tact trace))