-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Monad.AdvSTM.Class
-- Copyright   :  Peter Robinson 2008, HaskellWiki 2006-2007
-- License     :  BSD3
-- 
-- Maintainer  :  Peter Robinson <robinson@ecs.tuwien.ac.at>
-- Stability   :  experimental
-- Portability :  non-portable (requires STM)
--
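-- Provides the type class 'MonadAdvSTM', the 'TVar' type its operations work
-- on, and instances lifting 'MonadAdvSTM' through 'StateT', 'WriterT' and
-- 'ReaderT'.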
-- Parts of this implementation were taken from the HaskellWiki Page of
-- MonadAdvSTM (see package description).
-----------------------------------------------------------------------------

module Control.Monad.AdvSTM.Class( MonadAdvSTM(..), handleSTM, TVar(TVar), valueTVar, onCommitLock, currentTid)
where

import Control.Exception(Exception,throw)
import qualified Control.Concurrent.STM as S
import qualified Control.Concurrent.STM.TVar as OldTVar
import qualified Control.Concurrent.STM.TMVar as OldTMVar
import Control.Monad(Monad,liftM,ap)
import Control.Monad.Trans(lift)
import Control.Monad.State(StateT(StateT),mapStateT,runStateT,evalStateT)
import Control.Monad.Writer(WriterT(WriterT),mapWriterT,runWriterT,execWriterT)
import Control.Monad.Reader(ReaderT(ReaderT),mapReaderT,runReaderT)
-- import Control.Monad.AdvSTM.Def(AdvSTM)
import Control.Concurrent( ThreadId )
--import GHC.Conc( unsafeIOToSTM )
-- import {-# SOURCE #-} Control.Concurrent.AdvSTM.TVar
import Data.Monoid

-- | A transactional variable operated on by 'MonadAdvSTM' instances.
--
-- It bundles a plain STM 'OldTVar.TVar' holding the value with two extra
-- fields. Those fields appear to support the blocking behaviour of
-- 'readTVar' / 'writeTVar' while onCommit actions are pending; see those
-- methods. NOTE(review): confirm against the concrete 'AdvSTM' implementation.
data TVar a = TVar 
    { valueTVar    :: OldTVar.TVar a     
      -- ^ The underlying STM variable that stores the value.
    , onCommitLock :: OldTMVar.TMVar ()  
      -- ^ Presumably a lock that is held while the onCommit actions of a
      -- committing transaction run (TODO confirm).
    , currentTid   :: OldTVar.TVar (Maybe ThreadId)
      -- ^ Presumably the thread currently holding 'onCommitLock', or
      -- 'Nothing' if none (TODO confirm).
    }

-- | A type class for extended-STM monads. For a concrete instantiation see
-- 'AdvSTM'.
--
-- Besides the usual STM operations ('retry', 'orElse', 'catchSTM', ...),
-- an instance lets a transaction do two extra things:
--
-- * register IO actions that run only if the transaction commits
--   ('onCommit', 'onCommitWith');
--
-- * operate on the 'TVar' type defined in this module.
class Monad m => MonadAdvSTM m where

    -- | Installs a closure that is applied to the list of commit actions
    -- registered (via 'onCommit') in the current transaction.
    --
    -- The closure is useful for encapsulating the commit actions, e.g.,
    -- within a database transaction:
    --
    -- > onCommit (putStr "hello")
    -- > onCommit (putStr " world")
    -- > onCommitWith (\acts -> do { putStrLn "start"; sequence_ acts; putStrLn "\nend" })
    --
    -- Only the closure of the /last/ call of 'onCommitWith' in a transaction
    -- is applied to the commit actions.
    onCommitWith  :: ([IO ()] -> IO ()) -- ^ closure applied to the commit actions
                  -> m ()

    -- | Registers an IO action that is executed /iff/ the transaction
    -- commits. Commit actions are sequenced (within the same transaction),
    -- i.e.,
    --
    -- > onCommit (putStr "hello")
    -- > onCommit (putStr " world")
    --
    -- will print \"hello world\".
    --
    -- * When a TVar was modified in a transaction and the transaction tries
    -- to commit, this update remains invisible to other threads until the
    -- corresponding onCommit action is dispatched.
    --
    -- * If the onCommit action throws an exception, the original value of
    -- the TVars will be restored.
    --
    -- * Accessing a modified TVar within the onCommit action will cause a
    -- Deadlock exception to be thrown.
    --
    -- As a general rule, 'onCommit' should only be used for \"real\" (i.e.
    -- without atomic blocks) IO actions. It is certainly not the right place
    -- to fiddle with TVars. For example, if you wanted to write a TVar value
    -- to a file on commit, you could write:
    --
    -- > tvar <- newTVarIO "bla"
    -- > atomically $ do
    -- >    x <- readTVar tvar
    -- >    onCommit (writeFile "myfile" x)
    --
    -- Note: If you /really/ need to access the 'TVar' within an onCommit
    -- action (e.g. to recover from an IO exception), you can use
    -- 'writeTVarAsync'.
    onCommit :: IO () -> m ()

    -- | Retries the transaction and uses 'unsafeIOToSTM' to fork off a
    -- thread that runs the given IO action. The runtime system may rerun a
    -- transaction several times, so it is your responsibility to ensure
    -- that the IO action is idempotent and releases all acquired locks.
    unsafeRetryWith :: IO () -> m b

    -- | See 'S.orElse'.
    orElse :: m a -> m a -> m a

    -- | See 'S.retry'.
    retry  :: m a

    -- | See 'S.check'. It succeeds when the condition holds and calls
    -- 'retry' otherwise. The default definition mirrors 'S.check'.
    check :: Bool -> m ()
    check b = if b then return () else retry

    -- | See 'S.alwaysSucceeds'.
    --
    -- NOTE(review): STM invariants ('S.alwaysSucceeds', 'S.always') were
    -- removed from the @stm@ package in version 2.5 (GHC 8.4). Confirm that
    -- the concrete instances still have something to delegate to.
    alwaysSucceeds :: m a -> m ()

    -- | See 'S.always'. The NOTE on 'alwaysSucceeds' applies here as well.
    always :: m Bool -> m ()

    -- | See 'S.catchSTM'.
    catchSTM  :: Exception e => m a -> (e -> m a) -> m a

    -- | Lifts STM actions to 'MonadAdvSTM'.
    liftAdv   :: S.STM a -> m a

    -- | Reads a value from a TVar. Blocks until the onCommit IO action(s) of
    -- the corresponding transaction are complete.
    -- See 'onCommit' for a more detailed description of this behaviour.
    readTVar :: TVar a -> m a

    -- | Writes a value to a TVar. Blocks until the onCommit IO action(s) are
    -- complete. See 'onCommit' for details.
    writeTVar :: TVar a -> a -> m ()

    -- | Reads a value directly from the TVar. Does not block while the
    -- onCommit actions are still running. NOTE: Only use this function when
    -- you know what you're doing.
    readTVarAsync :: TVar a -> m a

    -- | Writes a value directly to the TVar. Does not block while the
    -- onCommit actions are still running. This function comes in handy for
    -- recovering from exceptions that occur in onCommit.
    writeTVarAsync :: TVar a -> a -> m ()

    -- | Creates a new 'TVar'. See 'OldTVar.newTVar'.
    newTVar :: a -> m (TVar a)

    -- NOTE: 'runAtomic' (cf. 'S.atomically') and 'newTVarIO' are not class
    -- methods. Their intended signatures were:
    --
    --   runAtomic :: m a -> IO a
    --   newTVarIO :: a -> IO (TVar a)

    -- | Embeds an arbitrary IO action in the transaction.
    -- NOTE(review): presumably mirrors @GHC.Conc.unsafeIOToSTM@. If so, the
    -- action is not rolled back and may run more than once when the
    -- transaction is re-executed. Confirm against the concrete instance.
    unsafeIOToSTM :: IO a -> m a
    
-- | 'catchSTM' with its two arguments flipped: the exception handler comes
-- first and the protected action second, so it reads well as
-- @handleSTM handler $ do ...@.
handleSTM :: (MonadAdvSTM m, Exception e) => (e -> m a) -> m a -> m a
handleSTM handler action = catchSTM action handler



--------------------------------------------------------------------------------


-- | Combines two 'StateT' computations with a binary function on their
-- underlying monadic actions. Both computations start from the same
-- incoming state. The function @f@ decides how their results and final
-- states are combined.
mapStateT2 :: (m (a, s) -> n (b, s) -> o (c,s))
           -> StateT s m a -> StateT s n b -> StateT s o c
mapStateT2 f m1 m2 = StateT runBoth
  where
    runBoth s0 = f (runStateT m1 s0) (runStateT m2 s0)

-- | Applies @f@ to the value-producing part of a 'StateT' computation while
-- leaving the state untouched.
--
-- The computation @m@ is evaluated from the current state and its final
-- state is discarded. The original state is passed on unchanged. This is
-- used for the invariant combinators, whose state modifications must not
-- leak into the enclosing transaction.
liftAndSkipStateT :: Monad m
                  => (m a -> m b)
                  -> StateT s m a -> StateT s m b
liftAndSkipStateT f m = StateT $ \s -> do
  r <- f (evalStateT m s)
  -- Re-emit the *incoming* state, not the one produced by @m@.
  return (r, s)

-- | Lifts every 'MonadAdvSTM' operation through 'StateT'.
--
-- Combinators taking sub-transactions ('orElse', 'catchSTM') hand the
-- incoming state to each branch.
instance MonadAdvSTM m => MonadAdvSTM (StateT s m) where

  -- Plain lifts of the underlying operations.
  onCommit        = lift . onCommit
  onCommitWith    = lift . onCommitWith
  unsafeRetryWith = lift . unsafeRetryWith
  unsafeIOToSTM   = lift . unsafeIOToSTM
  liftAdv         = lift . liftAdv

  retry = lift retry
  check = lift . check

  -- Both alternatives start from the same state.
  orElse = mapStateT2 orElse

  -- The handler resumes from the state the protected action started with.
  catchSTM action handler = StateT $ \s ->
    runStateT action s `catchSTM` \e -> runStateT (handler e) s

  -- Note: state modifications made by an invariant action are discarded.
  alwaysSucceeds = liftAndSkipStateT alwaysSucceeds
  always         = liftAndSkipStateT always

  -- TVar operations.
  newTVar           = lift . newTVar
  readTVar          = lift . readTVar
  readTVarAsync     = lift . readTVarAsync
  writeTVar tv      = lift . writeTVar tv
  writeTVarAsync tv = lift . writeTVarAsync tv

--------------------------------------------------------------------------------

-- | Combines two 'WriterT' computations with a binary function on their
-- underlying monadic actions. The function @f@ decides how the results and
-- the accumulated outputs are combined.
mapWriterT2 :: (m (a, w) -> n (b, w) -> o (c,w))
            -> WriterT w m a -> WriterT w n b -> WriterT w o c
mapWriterT2 f left right = WriterT (f (runWriterT left) (runWriterT right))

-- | Runs a 'WriterT' computation and keeps only its result, dropping the
-- accumulated output.
evalWriterT :: Monad m => WriterT w m a -> m a
evalWriterT m = runWriterT m >>= \(result, _) -> return result

-- | Applies @f@ to the value-producing part of a 'WriterT' computation and
-- throws away its output. The resulting computation always logs 'mempty'.
-- This is used for the invariant combinators, whose log must not leak into
-- the enclosing transaction.
liftAndSkipWriterT :: (Monad m,Monoid w)
            => (m a -> m b)
            -> WriterT w m a -> WriterT w m b
liftAndSkipWriterT f m = WriterT (liftM withEmptyLog (f (evalWriterT m)))
  where
    withEmptyLog r = (r, mempty)

-- | Lifts every 'MonadAdvSTM' operation through 'WriterT'.
--
-- Combinators taking sub-transactions ('orElse', 'catchSTM') run each
-- branch as a whole (value and log) through the underlying monad's
-- combinator.
instance (MonadAdvSTM m, Monoid w) => MonadAdvSTM (WriterT w m) where
  -- Plain lifts of the underlying operations.
  onCommit        = lift . onCommit
  onCommitWith    = lift . onCommitWith
  unsafeRetryWith = lift . unsafeRetryWith
  unsafeIOToSTM   = lift . unsafeIOToSTM
  liftAdv         = lift . liftAdv

  retry = lift retry
  check = lift . check

  orElse = mapWriterT2 orElse

  catchSTM action handler =
    WriterT (runWriterT action `catchSTM` (runWriterT . handler))

  -- Note: the log written by an invariant action is discarded.
  alwaysSucceeds = liftAndSkipWriterT alwaysSucceeds
  always         = liftAndSkipWriterT always

  -- TVar operations.
  newTVar           = lift . newTVar
  readTVar          = lift . readTVar
  readTVarAsync     = lift . readTVarAsync
  writeTVar tv      = lift . writeTVar tv
  writeTVarAsync tv = lift . writeTVarAsync tv
--------------------------------------------------------------------------------

-- | Combines two 'ReaderT' computations with a binary function on their
-- underlying monadic actions. Both computations see the same environment.
mapReaderT2 :: (m a -> n b -> o c) -> ReaderT r m a -> ReaderT r n b -> ReaderT r o c
mapReaderT2 f m1 m2 = ReaderT runBoth
  where
    runBoth env = f (runReaderT m1 env) (runReaderT m2 env)

-- | Lifts every 'MonadAdvSTM' operation through 'ReaderT'.
--
-- Combinators taking sub-transactions ('orElse', 'catchSTM', the
-- invariant combinators) hand the same environment to every branch.
instance MonadAdvSTM m => MonadAdvSTM (ReaderT r m) where
  -- Plain lifts of the underlying operations.
  onCommit        = lift . onCommit
  onCommitWith    = lift . onCommitWith
  unsafeRetryWith = lift . unsafeRetryWith
  unsafeIOToSTM   = lift . unsafeIOToSTM
  liftAdv         = lift . liftAdv

  retry = lift retry
  check = lift . check

  orElse = mapReaderT2 orElse

  catchSTM action handler = ReaderT $ \env ->
    runReaderT action env `catchSTM` \e -> runReaderT (handler e) env

  -- A reader has no output to leak, so the invariants can be mapped directly.
  alwaysSucceeds = mapReaderT alwaysSucceeds
  always         = mapReaderT always

  -- TVar operations.
  newTVar           = lift . newTVar
  readTVar          = lift . readTVar
  readTVarAsync     = lift . readTVarAsync
  writeTVar tv      = lift . writeTVar tv
  writeTVarAsync tv = lift . writeTVarAsync tv

--------------------------------------------------------------------------------