{-# LANGUAGE CPP #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}

-- Must come after TypeFamilies
{-# LANGUAGE NoMonoLocalBinds #-}

-- |
-- Module      : Test.DejaFu.Conc.Internal.STM
-- Copyright   : (c) 2017 Michael Walker
-- License     : MIT
-- Maintainer  : Michael Walker
-- Stability   : experimental
-- Portability : CPP, ExistentialQuantification, MultiParamTypeClasses, NoMonoLocalBinds, TypeFamilies
--
-- 'MonadSTM' testing implementation, internal types and definitions.
-- This module is NOT considered to form part of the public interface
-- of this library.
module Test.DejaFu.Conc.Internal.STM where

import Control.Applicative (Alternative(..))
import Control.Exception (Exception, SomeException, fromException, toException)
import Control.Monad (MonadPlus(..))
import Control.Monad.Catch (MonadCatch(..), MonadThrow(..))
import Control.Monad.Ref (MonadRef, newRef, readRef, writeRef)
import Data.List (nub)

import qualified Control.Monad.STM.Class as C
import Test.DejaFu.Internal
import Test.DejaFu.Types

#if MIN_VERSION_base(4,9,0)
import qualified Control.Monad.Fail as Fail
#endif

--------------------------------------------------------------------------------
-- * The @S@ monad

-- | The underlying monad is based on continuations over primitive
-- actions.
--
-- This is not @Cont@ because we want to give it a custom @MonadFail@
-- instance.
newtype S n r a = S { runSTM :: (a -> STMAction n r) -> STMAction n r }

-- | Maps over the value handed to the continuation.
instance Functor (S n r) where
  fmap f m = S $ \c -> runSTM m (c . f)

-- | Runs the function computation, then the argument computation,
-- and passes the application to the continuation.
instance Applicative (S n r) where
  pure x = S $ \c -> c x
  f <*> v = S $ \c -> runSTM f (\g -> runSTM v (c . g))

-- | Standard continuation-passing bind.  The @fail@ method is
-- attached to whichever class owns it for the @base@ version in use:
-- the CPP below makes the final @fail e = ...@ equation belong to
-- @MonadFail@ when it exists, and to @Monad@ otherwise.
instance Monad (S n r) where
  return = pure
  m >>= k = S $ \c -> runSTM m (\x -> runSTM (k x) c)

#if MIN_VERSION_base(4,9,0)
  fail = Fail.fail

-- | Pattern-match failure aborts the transaction with a
-- 'MonadFailException'.
instance Fail.MonadFail (S n r) where
#endif
  fail e = S $ \_ -> SThrow (MonadFailException e)

-- | Throwing discards the continuation and yields an 'SThrow'
-- action.
instance MonadThrow (S n r) where
  throwM e = S $ \_ -> SThrow e

-- | Catching is represented by the 'SCatch' primitive.
instance MonadCatch (S n r) where
  catch stm handler = S $ SCatch handler stm

-- | @empty@ is the 'SRetry' primitive and @\<|\>@ is the 'SOrElse'
-- primitive.
instance Alternative (S n r) where
  a <|> b = S $ SOrElse a b
  empty = S $ const SRetry

instance MonadPlus (S n r)

-- | Each @TVar@ operation is a thin wrapper around the matching
-- primitive action.
instance C.MonadSTM (S n r) where
  type TVar (S n r) = TVar r

  newTVarN n = S . SNew n

  readTVar = S . SRead

  writeTVar tvar a = S $ \c -> SWrite tvar a (c ())

--------------------------------------------------------------------------------
-- * Primitive actions

-- | STM transactions are represented as a sequence of primitive
-- actions.
data STMAction n r
  = forall a e. Exception e => SCatch (e -> S n r a) (S n r a) (a -> STMAction n r)
  | forall a. SRead (TVar r a) (a -> STMAction n r)
  | forall a. SWrite (TVar r a) a (STMAction n r)
  | forall a. SOrElse (S n r a) (S n r a) (a -> STMAction n r)
  | forall a. SNew String a (TVar r a -> STMAction n r)
  | forall e. Exception e => SThrow e
  | SRetry
  | SStop (n ())

--------------------------------------------------------------------------------
-- * @TVar@s

-- | A 'TVar' is a tuple of a unique ID and the value contained. The
-- ID is so that blocked transactions can be re-run when a 'TVar' they
-- depend on has changed.
newtype TVar r a = TVar (TVarId, r a)

--------------------------------------------------------------------------------
-- * Output

-- | The result of an STM transaction, along with which 'TVar's it
-- touched whilst executing.
data Result a
  = Success [TVarId] [TVarId] a
  -- ^ The transaction completed successfully, reading the first list
  -- 'TVar's and writing to the second.
  | Retry [TVarId]
  -- ^ The transaction aborted by calling 'retry', and read the
  -- returned 'TVar's. It should be retried when at least one of the
  -- 'TVar's has been mutated.
  | Exception SomeException
  -- ^ The transaction aborted by throwing an exception.
  deriving Show

--------------------------------------------------------------------------------
-- * Execution

-- | Run a transaction, returning the result and new initial 'TVarId'.
-- If the transaction failed, any effects are undone.
--
-- This is 'doTransaction' with the undo action dropped.
runTransaction :: MonadRef r n
  => S n r a -> IdSource -> n (Result a, IdSource, [TAction])
runTransaction ma tvid = do
  (res, _, tvid', trace) <- doTransaction ma tvid
  pure (res, tvid', trace)

-- | Run a STM transaction, returning an action to undo its effects.
--
-- If the transaction fails, its effects will automatically be undone,
-- so the undo action returned will be @pure ()@.
--
-- On failure (retry or exception) the /original/ 'IdSource' is
-- returned rather than the advanced one.
doTransaction :: MonadRef r n
  => S n r a -> IdSource -> n (Result a, n (), IdSource, [TAction])
doTransaction ma idsource = do
  -- The ref receives @Just (Right val)@ when the final 'SStop' action
  -- runs; 'go' overwrites it on retry ('Nothing') or throw
  -- (@Just (Left exc)@).
  (c, ref) <- runRefCont SStop (Just . Right) (runSTM ma)
  (idsource', undo, readen, written, trace) <- go ref c (pure ()) idsource [] [] []
  res <- readRef ref

  case res of
    Just (Right val) ->
      pure (Success (nub readen) (nub written) val, undo, idsource', reverse trace)
    Just (Left exc) ->
      undo >> pure (Exception exc, pure (), idsource, reverse trace)
    Nothing ->
      undo >> pure (Retry $ nub readen, pure (), idsource, reverse trace)

  where
    -- Step the transaction until it stops, retries, or throws,
    -- accumulating the undo action, the TVars read and written, and
    -- the trace (newest entry first; the caller reverses it).
    go ref act undo nidsrc readen written sofar = do
      (act', undo', nidsrc', readen', written', tact) <- stepTrans act nidsrc

      let newIDSource = nidsrc'
          newAct = act'
          -- Later effects must be reverted before earlier ones.
          newUndo = undo' >> undo
          newReaden = readen' ++ readen
          newWritten = written' ++ written
          newSofar = tact : sofar

      -- NOTE(review): 'newSofar' already contains 'tact', so the three
      -- terminal cases below record the final action twice in the
      -- trace.  Kept as-is to preserve the existing output; confirm
      -- whether this is intended.
      case tact of
        TStop -> pure (newIDSource, newUndo, newReaden, newWritten, TStop:newSofar)
        TRetry -> do
          writeRef ref Nothing
          pure (newIDSource, newUndo, newReaden, newWritten, TRetry:newSofar)
        TThrow -> do
          -- 'stepTrans' only produces 'TThrow' for an 'SThrow' action,
          -- so the fallback branch is unreachable; fail loudly with a
          -- useful message if that invariant is ever broken.
          let exc = case act of
                SThrow e -> toException e
                _ -> error "Test.DejaFu.Conc.Internal.STM.doTransaction: TThrow produced by an action other than SThrow"
          writeRef ref (Just (Left exc))
          pure (newIDSource, newUndo, newReaden, newWritten, TThrow:newSofar)
        _ -> go ref newAct newUndo newIDSource newReaden newWritten newSofar

-- | Run a transaction for one step.
--
-- Returns, in order: the next action, an action undoing this step's
-- effect, the new 'IdSource', the 'TVar's read, the 'TVar's written
-- (a newly created 'TVar' is counted as written), and the trace
-- entry for the step.
stepTrans :: MonadRef r n
  => STMAction n r
  -> IdSource
  -> n (STMAction n r, n (), IdSource, [TVarId], [TVarId], TAction)
stepTrans act idsource = case act of
  SCatch h stm c -> stepCatch h stm c
  SRead ref c -> stepRead ref c
  SWrite ref a c -> stepWrite ref a c
  SNew n a c -> stepNew n a c
  SOrElse a b c -> stepOrElse a b c
  SStop na -> stepStop na

  SThrow e -> pure (SThrow e, nothing, idsource, [], [], TThrow)
  SRetry -> pure (SRetry, nothing, idsource, [], [], TRetry)

  where
    -- The undo action for steps with nothing to revert.
    nothing = pure ()

    -- Run the inner transaction.  A retry propagates; an exception is
    -- passed to the handler if it has the handler's type, and is
    -- re-thrown otherwise.
    stepCatch h stm c = cases TCatch stm c
      (\trace -> pure (SRetry, nothing, idsource, [], [], TCatch trace Nothing))
      (\trace exc -> case fromException exc of
        Just exc' -> transaction (TCatch trace . Just) (h exc') c
        Nothing -> pure (SThrow exc, nothing, idsource, [], [], TCatch trace Nothing))

    -- Read the current value and pass it to the continuation.
    stepRead (TVar (tvid, ref)) c = do
      val <- readRef ref
      pure (c val, nothing, idsource, [tvid], [], TRead tvid)

    -- Write the new value; the undo action restores the old one.
    stepWrite (TVar (tvid, ref)) a c = do
      old <- readRef ref
      writeRef ref a
      pure (c, writeRef ref old, idsource, [], [tvid], TWrite tvid)

    -- Allocate a fresh ID and reference for the new TVar.
    stepNew n a c = do
      let (idsource', tvid) = nextTVId n idsource
      ref <- newRef a
      let tvar = TVar (tvid, ref)
      pure (c tvar, nothing, idsource', [], [tvid], TNew tvid)

    -- Run the left transaction; if it retries, run the right one.
    -- Exceptions from the left transaction propagate.
    stepOrElse a b c = cases TOrElse a c
      (\trace -> transaction (TOrElse trace . Just) b c)
      (\trace exc -> pure (SThrow exc, nothing, idsource, [], [], TOrElse trace Nothing))

    -- Run the final action of the transaction.
    stepStop na = do
      na
      pure (SStop na, nothing, idsource, [], [], TStop)

    -- Run a nested transaction and dispatch on its result.  TVars read
    -- by a retried transaction are added to those of the retry handler.
    cases tact stm onSuccess onRetry onException = do
      (res, undo, idsource', trace) <- doTransaction stm idsource
      case res of
        Success readen written val ->
          pure (onSuccess val, undo, idsource', readen, written, tact trace Nothing)
        Retry readen -> do
          (res', undo', idsource'', readen', written', trace') <- onRetry trace
          pure (res', undo', idsource'', readen ++ readen', written', trace')
        Exception exc -> onException trace exc

    -- Like 'cases', but a retry or exception in the nested transaction
    -- simply propagates as 'SRetry' or 'SThrow'.
    transaction tact stm onSuccess = cases (\t _ -> tact t) stm onSuccess
      (\trace -> pure (SRetry, nothing, idsource, [], [], tact trace))
      (\trace exc -> pure (SThrow exc, nothing, idsource, [], [], tact trace))