-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Monad.AdvSTM.Class
-- Copyright   :  Peter Robinson 2008, HaskellWiki 2006-2007
-- License     :  BSD3
-- 
-- Maintainer  :  Peter Robinson <robinson@ecs.tuwien.ac.at>
-- Stability   :  experimental
-- Portability :  non-portable (requires STM)
--
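-- Provides the type class 'MonadAdvSTM', the 'TVar' record used by its
-- operations, and 'MonadAdvSTM' instances for the 'StateT', 'WriterT' and
-- 'ReaderT' monad transformers.
-- Parts of this implementation were taken from the HaskellWiki page on
-- MonadAdvSTM (see the package description).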
-----------------------------------------------------------------------------

module Control.Monad.AdvSTM.Class( MonadAdvSTM(..), handleSTM, TVar(TVar), valueTVar, onCommitLock, currentTid)
where

import Control.Exception(Exception,throw)
import qualified Control.Concurrent.STM as S
import qualified Control.Concurrent.STM.TVar as OldTVar
import qualified Control.Concurrent.STM.TMVar as OldTMVar
import Control.Monad(Monad,liftM,ap)
import Control.Monad.Trans(lift)
import Control.Monad.State(StateT(StateT),mapStateT,runStateT,evalStateT)
import Control.Monad.Writer(WriterT(WriterT),mapWriterT,runWriterT,execWriterT)
import Control.Monad.Reader(ReaderT(ReaderT),mapReaderT,runReaderT)
-- import Control.Monad.AdvSTM.Def(AdvSTM)
import Control.Concurrent( ThreadId )
--import GHC.Conc( unsafeIOToSTM )
-- import {-# SOURCE #-} Control.Concurrent.AdvSTM.TVar
import Data.Monoid

-- | A transactional variable for 'MonadAdvSTM' computations (see 'readTVar',
-- 'writeTVar' and 'newTVar'). It wraps a plain STM variable holding the value
-- together with extra fields that the class documentation ties to the
-- blocking behaviour around 'onCommit' actions.
data TVar a = TVar 
    { valueTVar    :: OldTVar.TVar a     -- ^ The underlying STM variable that holds the current value.
    , onCommitLock :: OldTMVar.TMVar ()  -- ^ NOTE(review): presumably a lock held while 'onCommit' actions run; the locking logic lives outside this module — confirm.
    , currentTid   :: OldTVar.TVar (Maybe ThreadId) -- ^ NOTE(review): presumably the thread currently owning 'onCommitLock' ('Nothing' when unowned) — confirm against the implementation.
    }

-- | A type class for monads that extend STM with commit and retry hooks
-- ('onCommit', 'onRetry') and with their own 'TVar' operations. For a
-- concrete instantiation see 'AdvSTM'.
class Monad m => MonadAdvSTM m where

    -- | Takes an IO action that will be executed /iff/ the transaction commits.
    --
    -- * When a TVar was modified in a transaction and this transaction commits,
    -- the update remains invisible to other threads until the corresponding
    -- onCommit action has run.
    --
    -- * If the onCommit action throws an exception, the original values of
    -- the TVars will be restored.
    --
    -- * Accessing a modified TVar within the onCommit action will cause a
    -- Deadlock exception to be thrown.
    --
    -- As a general rule, 'onCommit' should only be used for \"real\" IO
    -- actions (i.e. actions that do not run atomic blocks themselves); it is
    -- certainly not the right place to fiddle with TVars. For example, if you
    -- wanted to write a TVar value to a file on commit, you could write:
    --
    -- > tvar <- newTVarIO "bla"
    -- > atomically $ do
    -- >    x <- readTVar tvar
    -- >    onCommit (writeFile "myfile" x)
    --
    onCommit  :: IO () -> m ()

    -- | Adds an IO action to the retry job queue. If the transaction retries,
    -- a new helper thread is forked that runs the retry actions, and, once the
    -- helper thread is done, the transaction is retried.
    --
    -- /Note:/ When the transaction is retried, 'unsafeIOToSTM' is used to fork
    -- a helper thread that runs the retry actions (if any). It is your
    -- responsibility to ensure that your retry IO actions are \"safe\". Any
    -- exceptions occurring in the retry thread will be thrown to the
    -- thread where the transaction is running and immediately cause the
    -- transaction to be aborted, since 'catchSTM' does not catch asynchronous
    -- exceptions.
    onRetry :: IO () -- ^ IO action that will be run if the transaction is (explicitly) retried.
            -> m ()

    -- | Alternative composition of two transactions. See 'S.orElse'.
    orElse :: m a -> m a -> m a

    -- | Runs any IO actions added by 'onRetry' and then retries the
    -- transaction.
    retry  :: m a

    -- | 'MonadAdvSTM' counterpart of 'S.check'.
    check :: Bool -> m ()

    -- | 'MonadAdvSTM' counterpart of 'S.alwaysSucceeds'.
    --
    -- NOTE(review): newer releases of the @stm@ package dropped the STM
    -- invariant primitives; confirm the targeted @stm@ version still
    -- provides 'S.alwaysSucceeds'.
    alwaysSucceeds :: m a -> m ()

    -- | 'MonadAdvSTM' counterpart of 'S.always'.
    --
    -- NOTE(review): see the remark on 'alwaysSucceeds' about @stm@ versions.
    always :: m Bool -> m ()

    -- | Catches exceptions thrown inside the transaction. See 'S.catchSTM'.
    catchSTM  :: Exception e => m a -> (e -> m a) -> m a


    -- | Lifts plain STM actions into 'MonadAdvSTM'.
    liftAdv   :: S.STM a -> m a
 
    -- | Reads a value from a TVar. Blocks until the IO onCommit action(s) of
    -- the corresponding transaction are complete.
    -- See 'onCommit' for a more detailed description of this behaviour.
    readTVar :: TVar a -> m a

    -- | Writes a value to a TVar. Blocks until the onCommit IO action(s) are
    -- complete. See 'onCommit' for details.
    writeTVar :: TVar a -> a -> m ()

    -- | Creates a new 'TVar' holding the given initial value.
    -- See 'OldTVar.newTVar'.
    newTVar :: a -> m (TVar a)

-- Candidate members that are currently NOT part of the class (kept for
-- reference only):
--    --  See 'S.atomically'
--    runAtomic :: m a -> IO a
--    newTVarIO :: a -> IO (TVar a)
    
-- | 'catchSTM' with its arguments swapped: the handler comes first. This
-- reads better when the protected action is a long @do@ block:
--
-- > handleSTM (\e -> recover e) $ do
-- >    ...
handleSTM :: (MonadAdvSTM m, Exception e) => (e -> m a) -> m a -> m a
handleSTM handler action = catchSTM action handler



--------------------------------------------------------------------------------


-- | Combines two 'StateT' computations with a binary function on their
-- underlying actions. Both computations are started from the same incoming
-- state; @combine@ decides which result and which final state survive.
mapStateT2 :: (m (a, s) -> n (b, s) -> o (c, s))
           -> StateT s m a -> StateT s n b -> StateT s o c
mapStateT2 combine left right = StateT run
  where
    run st = combine (runStateT left st) (runStateT right st)

-- | Runs a 'StateT' computation from the current state, hands the resulting
-- inner action to @f@, and lifts @f@'s result back into 'StateT' while
-- leaving the state unchanged. Any state updates made by the argument
-- computation are therefore discarded. This is used for the invariant
-- combinators, whose state effects must not leak.
liftAndSkipStateT :: (Monad m, Monad n)
                  => (m a -> n b)
                  -> StateT s m a -> StateT s n b
liftAndSkipStateT f m = StateT $ \s -> do
  r <- f (evalStateT m s)
  -- Return the incoming state, not the one produced inside @f@'s argument.
  return (r, s)

-- | Lifts every 'MonadAdvSTM' operation through 'StateT'. Operations that
-- only need the inner monad are lifted with 'lift'. The combinators that take
-- monadic arguments ('orElse', 'catchSTM') run each branch from the same
-- incoming state.
instance MonadAdvSTM m => MonadAdvSTM (StateT s m) where
  onCommit action = lift (onCommit action)
  onRetry action  = lift (onRetry action)

  orElse lhs rhs = mapStateT2 orElse lhs rhs
  retry          = lift retry
  check cond     = lift (check cond)

  -- The invariant sees the current state, but any state changes it makes
  -- are thrown away; the outer state is left untouched.
  alwaysSucceeds inv = liftAndSkipStateT alwaysSucceeds inv
  always inv         = liftAndSkipStateT always inv

  -- On an exception, the handler restarts from the state the protected
  -- action started with.
  catchSTM body handler = StateT $ \st ->
    runStateT body st `catchSTM` \e -> runStateT (handler e) st

  liftAdv stm      = lift (liftAdv stm)
  readTVar tv      = lift (readTVar tv)
  writeTVar tv val = lift (writeTVar tv val)
  newTVar val      = lift (newTVar val)

--------------------------------------------------------------------------------

-- | Combines two 'WriterT' computations with a binary function on their
-- underlying actions (each yielding a result paired with its log).
-- @combine@ decides which result and which log survive.
mapWriterT2 :: (m (a, w) -> n (b, w) -> o (c, w))
            -> WriterT w m a -> WriterT w n b -> WriterT w o c
mapWriterT2 combine left right = WriterT combined
  where
    combined = combine (runWriterT left) (runWriterT right)

-- | Runs a 'WriterT' computation and returns only its result, discarding
-- the accumulated log.
evalWriterT :: Monad m => WriterT w m a -> m a
evalWriterT m = runWriterT m >>= \(result, _log) -> return result

-- | Hands the inner action of a 'WriterT' computation (with its log
-- dropped) to @f@, and lifts @f@'s result back into 'WriterT' with an
-- empty log. Any output written by the argument computation is lost.
liftAndSkipWriterT :: (Monad m, Monoid w)
                   => (m a -> m b)
                   -> WriterT w m a -> WriterT w m b
liftAndSkipWriterT f m =
  WriterT (f (evalWriterT m) >>= \r -> return (r, mempty))

-- | Lifts every 'MonadAdvSTM' operation through 'WriterT'. Operations that
-- only need the inner monad are lifted with 'lift'. The combinators that take
-- monadic arguments ('orElse', 'catchSTM') operate on the underlying
-- @(result, log)@ actions.
instance (MonadAdvSTM m, Monoid w) => MonadAdvSTM (WriterT w m) where
  onCommit action = lift (onCommit action)
  onRetry action  = lift (onRetry action)

  orElse lhs rhs = mapWriterT2 orElse lhs rhs
  retry          = lift retry
  check cond     = lift (check cond)

  -- Anything the invariant writes to the log is thrown away; the lifted
  -- computation contributes an empty log.
  alwaysSucceeds inv = liftAndSkipWriterT alwaysSucceeds inv
  always inv         = liftAndSkipWriterT always inv

  catchSTM body handler =
    WriterT (runWriterT body `catchSTM` \e -> runWriterT (handler e))

  liftAdv stm      = lift (liftAdv stm)
  readTVar tv      = lift (readTVar tv)
  writeTVar tv val = lift (writeTVar tv val)
  newTVar val      = lift (newTVar val)

--------------------------------------------------------------------------------

-- | Combines two 'ReaderT' computations with a binary function on their
-- underlying actions. Both computations receive the same environment.
mapReaderT2 :: (m a -> n b -> o c) -> ReaderT r m a -> ReaderT r n b -> ReaderT r o c
mapReaderT2 combine left right = ReaderT run
  where
    run env = combine (runReaderT left env) (runReaderT right env)

-- | Lifts every 'MonadAdvSTM' operation through 'ReaderT'. Operations that
-- only need the inner monad are lifted with 'lift'. The combinators that take
-- monadic arguments pass the same environment to every branch.
instance MonadAdvSTM m => MonadAdvSTM (ReaderT r m) where
  onCommit action = lift (onCommit action)
  onRetry action  = lift (onRetry action)

  orElse lhs rhs = mapReaderT2 orElse lhs rhs
  retry          = lift retry
  check cond     = lift (check cond)

  -- The environment is read-only, so nothing needs to be discarded here.
  alwaysSucceeds inv = mapReaderT alwaysSucceeds inv
  always inv         = mapReaderT always inv

  catchSTM body handler = ReaderT $ \env ->
    runReaderT body env `catchSTM` \e -> runReaderT (handler e) env

  liftAdv stm      = lift (liftAdv stm)
  readTVar tv      = lift (readTVar tv)
  writeTVar tv val = lift (writeTVar tv val)
  newTVar val      = lift (newTVar val)


--------------------------------------------------------------------------------