{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes                 #-}
{-# LANGUAGE TypeFamilies               #-}

-- | A 'MonadSTM' implementation, which can be run on top of 'IO' or
-- 'ST'.
module Test.DejaFu.STM
  ( -- * The @STMLike@ Monad
    STMLike
  , STMST
  , STMIO

  -- * Executing Transactions
  , Result(..)
  , TTrace
  , TAction(..)
  , TVarId
  , runTransactionST
  , runTransactionIO
  ) where

import Control.Monad (liftM)
import Control.Monad.Catch (MonadCatch(..), MonadThrow(..))
import Control.Monad.Cont (cont)
import Control.Monad.ST (ST)
import Data.IORef (IORef)
import Data.STRef (STRef)
import Test.DejaFu.Deterministic.Internal.Common (TVarId, IdSource, TAction(..), TTrace)
import Test.DejaFu.Internal
import Test.DejaFu.STM.Internal

import qualified Control.Monad.STM.Class as C

{-# ANN module ("HLint: ignore Use record patterns" :: String) #-}

-- | The monad shared by 'STMST' and 'STMIO'. It is a thin newtype
-- over the continuation-based 'M' computation (see 'toSTM'), and its
-- 'Functor', 'Applicative' and 'Monad' instances are inherited from
-- that type. @n@ is the monad transactions are executed in and @r@
-- the reference type used for 'TVar's (e.g. @ST t@ / @STRef t@, or
-- @IO@ / @IORef@).
newtype STMLike n r a = S
  { runSTM :: M n r a
  }
  deriving (Functor, Applicative, Monad)

-- | Lift a continuation-passing function into 'STMLike'. The function
-- is handed the remainder of the transaction (as a continuation from
-- the result to the next 'STMAction') and must produce an
-- 'STMAction' itself.
toSTM :: ((a -> STMAction n r) -> STMAction n r) -> STMLike n r a
toSTM k = S (cont k)

-- | A 'MonadSTM' implementation running in @ST@, with 'TVar's backed
-- by 'STRef's. A value of this type describes one atomic
-- transaction. No explicit collection of 'TVar's is kept: each one
-- is just a reference. That keeps the types simple, at the cost of
-- having no way to enumerate every 'TVar' that has been created.
type STMST t = STMLike (ST t) (STRef t)

-- | A 'MonadSTM' implementation using @IO@, with 'TVar's backed by
-- 'IORef's. It encapsulates a single atomic transaction. The
-- environment, that is, the collection of defined 'TVar's, is
-- implicit: there is no list of them, they exist purely as
-- references. This makes the types simpler, but means you can't
-- really get an aggregate of them (if you ever wanted to for some
-- reason).
type STMIO = STMLike IO IORef

-- Throwing discards the rest of the transaction: the continuation is
-- ignored and an 'SThrow' action is produced in its place.
instance MonadThrow (STMLike n r) where
  throwM e = toSTM (\_ -> SThrow e)

-- Catching is represented by an 'SCatch' action holding both the
-- protected computation and the handler, unwrapped to 'M'.
instance MonadCatch (STMLike n r) where
  catch action handler =
    toSTM (SCatch (\e -> runSTM (handler e)) (runSTM action))

-- Every operation is built with 'toSTM', which hands the primitive
-- 'STMAction' constructor the continuation for the rest of the
-- transaction.
instance Monad n => C.MonadSTM (STMLike n r) where
  type TVar (STMLike n r) = TVar r

  -- The continuation is discarded; only 'SRetry' remains.
  retry = toSTM (\_ -> SRetry)

  orElse lhs rhs = toSTM (SOrElse (runSTM lhs) (runSTM rhs))

  newTVarN name initial = toSTM (SNew name initial)

  readTVar tvar = toSTM (SRead tvar)

  -- A write produces no interesting value, so the rest of the
  -- transaction is resumed with '()'.
  writeTVar tvar new = toSTM (\k -> SWrite tvar new (k ()))

-- | Execute a transaction in 'ST'. The result is a triple of the
-- transaction's 'Result', the 'IdSource' to use for subsequent
-- 'TVar' creation, and the transaction's 'TTrace'. If the result is
-- anything other than 'Success' (for example, the transaction called
-- 'retry'), its 'TVar' modifications are undone and the 'IdSource'
-- that was passed in is returned unchanged.
runTransactionST :: STMST t a -> IdSource -> ST t (Result a, IdSource, TTrace)
runTransactionST = runTransactionM stFixed
  where
    -- Embed an 'ST' action by wrapping it in 'SLift', feeding its
    -- result to the continuation.
    stFixed = refST (\action -> cont (\k -> SLift (liftM k action)))

-- | Execute a transaction in 'IO'. The result is a triple of the
-- transaction's 'Result', the 'IdSource' to use for subsequent
-- 'TVar' creation, and the transaction's 'TTrace'. If the result is
-- anything other than 'Success' (for example, the transaction called
-- 'retry'), its 'TVar' modifications are undone and the 'IdSource'
-- that was passed in is returned unchanged.
runTransactionIO :: STMIO a -> IdSource -> IO (Result a, IdSource, TTrace)
runTransactionIO = runTransactionM ioFixed
  where
    -- Embed an 'IO' action by wrapping it in 'SLift', feeding its
    -- result to the continuation.
    ioFixed = refIO (\action -> cont (\k -> SLift (liftM k action)))

-- | Execute a transaction in an arbitrary monad @n@, parameterised by
-- a 'Fixed' value (such as those built with 'refST' / 'refIO').
--
-- On 'Success' the updated 'IdSource' from 'doTransaction' is
-- returned. For any other result, the undo action produced by
-- 'doTransaction' is run first, and the original 'IdSource' is
-- returned instead, so the failed attempt leaves no trace.
runTransactionM :: Monad n
  => Fixed n r -> STMLike n r a -> IdSource -> n (Result a, IdSource, TTrace)
runTransactionM fixed stm idsource = do
  (result, rollback, idsource', ttrace) <- doTransaction fixed (runSTM stm) idsource
  let finish ids = return (result, ids, ttrace)
  case result of
    Success _ _ _ -> finish idsource'
    _ -> rollback >> finish idsource