-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Concurrent.AdvSTM
-- Copyright   :  (c) Chris Kuklewicz 2006, Peter Robinson 2009
-- License     :  BSD3
-- 
-- Maintainer  :  Peter Robinson <robinson@ecs.tuwien.ac.at>
-- Stability   :  experimental
-- Portability :  non-portable (requires STM)
--
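-- Extends "Control.Concurrent.STM" with IO hooks: IO actions that run when a
-- transaction commits ('onCommit') or when it retries ('unsafeRetryWith').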
-- 
-----------------------------------------------------------------------------
 
module Control.Concurrent.AdvSTM( -- * Class MonadAdvSTM
                                  MonadAdvSTM( onCommit
                                            , unsafeRetryWith
                                            , orElse
                                            , retry
                                            , check
                                            , alwaysSucceeds
                                            , always
                                            , catchSTM
                                            , liftAdv
                                            , newTVar
                                            , readTVarAsync
                                            , writeTVarAsync
                                            , readTVar
                                            , writeTVar
                                            )
                                
                                  -- * Monad AdvSTM
                                , AdvSTM
                                , atomically
                                , unsafeIOToAdvSTM
                                , handleSTM
                                , debugAdvSTM
                                , debugMode
                                ) where
 
import Prelude hiding (catch)
import Control.Monad.AdvSTM.Def(AdvSTM(..),Env(..),TVarValue(..))
import Control.Monad.AdvSTM.Class( 
    MonadAdvSTM(..), handleSTM, TVar( TVar ), onCommitLock, currentTid, valueTVar )
-- import Control.Monad.Reader(ReaderT(ReaderT),mapReaderT,runReaderT)

import Control.Exception(Exception,throw,catch,SomeException,fromException,try,block,Deadlock(..),finally,bracket) 
import Control.Monad(mplus,when,liftM,ap,unless)
import Control.Monad.Error(MonadError(..))
import Control.Concurrent(threadDelay,forkIO,ThreadId,myThreadId,throwTo)
import Control.Concurrent.Chan(Chan,newChan,readChan,writeChan)
import Control.Concurrent.STM.TMVar(TMVar,putTMVar,takeTMVar,newTMVar,tryTakeTMVar)
-- import Control.Concurrent.STM.TChan(TChan,writeTChan)
import Control.Concurrent.MVar(MVar,newMVar,takeMVar,tryTakeMVar,putMVar,tryPutMVar,swapMVar)
import qualified Control.Concurrent.STM as S (STM,orElse,retry,catchSTM,atomically,check,always,alwaysSucceeds)
import qualified Control.Concurrent.STM.TVar as STVar -- (TVar,newTVarIO,readTVar,writeTVar)
import GHC.Conc(unsafeIOToSTM)
import Data.IORef(newIORef,readIORef,writeIORef)
import Data.Maybe(isJust,Maybe,fromJust)
import Control.Monad.Reader(MonadReader(..),ReaderT(ReaderT),runReaderT,lift,asks)
 
--------------------------------------------------------------------------------


instance MonadAdvSTM AdvSTM where
    -- Prepends the IO action to the list held in the environment's
    -- 'commitTVar'.
    onCommit ioaction = do
        commitVar     <- AdvSTM $ asks commitTVar
        commitActions <- liftAdv $ STVar.readTVar commitVar
        liftAdv $ STVar.writeTVar commitVar $ ioaction : commitActions

--    unsafeOnRetry ioaction = do
--        retryVar <- AdvSTM $ asks retryMVar
--        liftAdv . unsafeIOToSTM $ do
--          may'retryFun <- tryTakeMVar retryVar
--          let retryFun = maybe (ioaction >>) (. (ioaction >>)) may'retryFun
--          putMVar retryVar $! retryFun

    orElse = mplus

    retry = liftAdv S.retry 

    check = liftAdv . S.check

    alwaysSucceeds inv = unlift inv >>= liftAdv . S.alwaysSucceeds 

    always inv = unlift inv >>= liftAdv . S.always 

    -- Runs the handler only for exceptions of the type it accepts; any other
    -- exception is rethrown unchanged.
    catchSTM action handler = do
        action'  <- unlift action
        handler' <- unlift1 handler
        let handler'' e = case fromException e of
                            Nothing -> throw e
                            Just e' -> handler' e'
        liftAdv $ S.catchSTM action' handler''

    liftAdv = AdvSTM . lift 


    -- | See 'STVar.newTVar'. Builds the TVar from a fresh STM TVar holding
    -- the value, an initially full @TMVar ()@ and a TVar initialised to
    -- 'Nothing'.
    newTVar a = TVar `liftM` liftAdv (STVar.newTVar a) 
                     `ap`    liftAdv (newTMVar ()) 
                     `ap`    liftAdv (STVar.newTVar Nothing)


    -- | Writes a value to a TVar. Blocks until the onCommit IO-action(s) are
    -- complete. See 'onCommit' for details.
    --
    -- If the commit lock is taken and the ThreadId recorded in 'currentTid'
    -- is our own, 'Deadlock' is thrown, because the TVar cannot be written
    -- during its own onCommit phase. Otherwise the transaction retries.
    --
    -- On the first write within a transaction, the TVar's lock and its
    -- current value are added to the 'listeners' list. 'atomically' uses
    -- that list to lock the TVar during the onCommit phase and to restore
    -- the old value if an onCommit action fails.
    --
    -- "First write" is decided by looking the lock up in that list, not by
    -- comparing against 'currentTid'. 'currentTid' still holds the ThreadId
    -- of the last transaction that wrote this TVar, even after it committed.
    -- Comparing against it would therefore skip the registration for a later
    -- transaction running on the same thread.
    writeTVar tvar a = do 
        let lock = onCommitLock tvar
        commitLock <- liftAdv $ tryTakeTMVar lock
        -- Get ThreadID of current transaction:
        curTid     <- AdvSTM $ asks transThreadId   
        case commitLock of
            Nothing -> do
                -- The lock is held by a transaction in its onCommit phase.
                storedTid <- liftAdv $ STVar.readTVar (currentTid tvar) 
                if storedTid == Just curTid
                  then throw Deadlock       -- Can't write the TVar in the onCommit phase
                  else retry
            Just _  -> do
                lsTVar <- AdvSTM $ asks listeners
                ls     <- liftAdv $ STVar.readTVar lsTVar
                unless (any ((== lock) . fst) ls) $ do
                    -- First write access in this transaction: record the
                    -- owning thread (used for Deadlock detection)...
                    liftAdv $ STVar.writeTVar (currentTid tvar) (Just curTid)
                    -- ...and remember the lock and the old value for rollback:
                    oldval <- liftAdv $ STVar.readTVar (valueTVar tvar)
                    liftAdv $ STVar.writeTVar lsTVar $
                            (lock, TVarValue (valueTVar tvar, oldval)) : ls
                
                liftAdv $ STVar.writeTVar (valueTVar tvar) a 
                liftAdv $ putTMVar lock ()


    -- Writes the underlying value directly, ignoring the commit lock.
    writeTVarAsync tvar = liftAdv . STVar.writeTVar (valueTVar tvar) 

    --------------------------------------------------------------------------------

    -- | Reads a value from a TVar. Blocks until the IO onCommit action(s) of 
    -- the corresponding transaction are complete.
    -- See 'onCommit' for a more detailed description of this behaviour.
    -- Throws 'Deadlock' if the lock is held and the recorded owner is the
    -- current thread.
    readTVar tvar = do 
        commitLock <- liftAdv $ tryTakeTMVar (onCommitLock tvar)
        case commitLock of
            Nothing -> do
                storedTid <- liftAdv $ STVar.readTVar (currentTid tvar) 
                curTid    <- AdvSTM $ asks transThreadId   
                if storedTid == Just curTid
                  then throw Deadlock
                  else retry
            Just _ -> do
                result <- liftAdv $ STVar.readTVar $ valueTVar tvar
                liftAdv $ putTMVar (onCommitLock tvar) ()
                return result


    -- Reads the underlying value directly, ignoring the commit lock.
    readTVarAsync = liftAdv . STVar.readTVar . valueTVar 


    -- | Forks a separate thread to run the IO action and then retries the
    -- transaction.
    --
    -- The forked thread holds the environment's 'retryDoneMVar' while the
    -- action runs. It runs the action only if that MVar held @Just ()@.
    -- 'bracket' always puts back exactly the value that was taken. A
    -- 'Nothing' placed there by 'atomically' therefore stays in place, and
    -- pending retry threads never run their action after the transaction
    -- has committed.
    unsafeRetryWith io = do
        doneMVar <- AdvSTM $ asks retryDoneMVar
        _ <- unsafeIOToAdvSTM $ forkIO $
            bracket (takeMVar doneMVar)
                    (putMVar doneMVar)
                    (\val -> when (isJust val) io)
        retry
 
--------------------------------------------------------------------------------


-- | Runs an 'AdvSTM' transaction. See 'S.atomically'.
--
-- The STM part runs first. Before it commits, the commit lock of every TVar
-- recorded in the 'listeners' list is taken. Then, with asynchronous
-- exceptions blocked, this function:
--
-- * swaps 'Nothing' into the retry MVar (waiting while a retry thread holds
--   it),
-- * runs the onCommit IO actions,
-- * releases the locks again.
--
-- If an onCommit action throws, every recorded TVar is reset to the value
-- stored alongside its lock and unlocked, and the exception is rethrown.
atomically :: AdvSTM a -> IO a
atomically (AdvSTM action) = do
    debugVar <- STVar.newTVarIO debugging
    me       <- myThreadId         -- ThreadId for controlling TVar access
    say "***************************************************************" 0
    say (show (me,"Starting transaction...")) 1000000
    -- Reader environment of the transaction:
    hookVar    <- STVar.newTVarIO []    -- IO actions to be run if the transaction commits
    doneVar    <- newMVar (Just ())     -- shared with unsafeRetryWith threads
    touchedVar <- STVar.newTVarIO []    -- (commit lock, old value) of written TVars
    let env = Env { commitTVar    = hookVar
                  , retryDoneMVar = doneVar
                  , transThreadId = me
                  , listeners     = touchedVar
                  , debugModeVar  = debugVar
                  }
        note msg = debugSTM (show (me,msg)) 0 debugging
        -- Something went wrong in an onCommit action: restore and unlock.
        rollback = do
            say "rollbackOnCommit: rolling back modified TVar values!" 0
            S.atomically $ STVar.readTVar touchedVar >>= mapM_ restore
    -- Block exceptions from other threads for the rest of 'atomically'
    block $ do
        value <- S.atomically $ do
            note "wrappedAction: Running S.STM action..."
            v <- runReaderT action env
            note "wrappedAction: Finished S.STM action..."
            touched <- STVar.readTVar touchedVar
            -- Take each lock: the TVars stay locked during the onCommit phase.
            note "wrappedAction: Notifying TVars that we're about to run onCommit..."
            mapM_ (takeTMVar . fst) touched
            return v
        -- Mark the retry MVar as finished before running the onCommit actions.
        _ <- swapMVar doneVar Nothing
        hooks <- S.atomically (STVar.readTVar hookVar)
        sequence_ hooks `catch` (\(e::SomeException) -> rollback >> throw e)
        -- Everything's ok: unblock the TVars.
        say (show (me,"*********************")) 1000000
        say (show (me,"Notifying TPVars that we're done:")) 0
        S.atomically $ STVar.readTVar touchedVar >>= mapM_ (\(lock,_) -> putTMVar lock ())
        say (show (me,"Transaction done; retry thread finished")) 1000000
        say "************************************************************" 0
        return value
  where
    -- Switching this to True occasionally causes deadlocks (b/c unsafeIO)!
    debugging = False
    say msg delay = debug msg delay debugging
    restore (lock, TVarValue (var, old)) = do
        STVar.writeTVar var old   -- put the remembered value back
        putTMVar lock ()          -- ...and unblock the TVar

--------------------------------------------------------------------------------

-- | Lifts an arbitrary IO action into 'AdvSTM' through 'unsafeIOToSTM'.
-- See 'unsafeIOToSTM' for the caveats that apply.
unsafeIOToAdvSTM :: IO a -> AdvSTM a
unsafeIOToAdvSTM io = liftAdv (unsafeIOToSTM io)

--------------------------------------------------------------------------------

-- | Switches the debug mode on or off by writing the flag stored in the
-- transaction environment ('debugModeVar'), which 'debugAdvSTM' reads.
-- /WARNING:/ Can lead to deadlocks! 
debugMode :: Bool -> AdvSTM ()
debugMode switch =
    AdvSTM (asks debugModeVar) >>= \flagVar ->
        liftAdv (STVar.writeTVar flagVar switch)

-- | When debug mode is on, uses 'unsafeIOToSTM' (via 'debugSTM') to print
-- the transaction's ThreadId together with a message, and then waits for
-- the given number of microseconds.
-- /WARNING:/ Can lead to deadlocks! 
debugAdvSTM :: String -> Int -> AdvSTM ()
debugAdvSTM msg delay = do
    enabled <- liftAdv . STVar.readTVar =<< AdvSTM (asks debugModeVar)
    owner   <- AdvSTM (asks transThreadId)
    liftAdv (debugSTM (show (owner, msg)) delay enabled)

--------------------------------------------------------------------------------
-- The following helpers are internal to this module and are not exported,
-- since end users do not need them.

-- | Runs an 'AdvSTM' computation as a plain STM action in the given environment.
runWith :: Env -> AdvSTM t -> S.STM t
runWith env adv = case adv of
    AdvSTM reader -> runReaderT reader env
                            
 
-- | Captures the current environment and returns a function that runs
-- 'AdvSTM' computations in it as plain STM actions.
unlifter :: AdvSTM (AdvSTM a -> S.STM a)
unlifter = AdvSTM (asks runWith)
 
-- | Converts an 'AdvSTM' computation into an STM action bound to the
-- current environment.
unlift :: AdvSTM a -> AdvSTM (S.STM a)
unlift f = liftM ($ f) unlifter
 
-- | Like 'unlift', but for a one-argument function returning 'AdvSTM'.
unlift1 :: (t -> AdvSTM a) -> AdvSTM (t -> S.STM a)
unlift1 f = liftM (. f) unlifter
 
-- WARNING: Can lead to deadlocks! 
-- When the flag is set, prints the message and then waits for the given
-- number of microseconds, using 'unsafeIOToSTM'; otherwise does nothing.
debugSTM :: String -> Int -> Bool -> S.STM ()
debugSTM msg delay debugging
    | debugging = unsafeIOToSTM (putStrLn msg >> threadDelay delay)
    | otherwise = return ()

-- WARNING: Can lead to deadlocks! 
-- When the flag is set, prints the message and then waits for the given
-- number of microseconds; otherwise does nothing.
debug :: String -> Int -> Bool -> IO ()
debug msg delay debugging =
    if debugging
      then putStrLn msg >> threadDelay delay
      else return ()

-- 'throwError' raises the error with 'throw'; 'catchError' delegates to the
-- class method 'catchSTM'.
instance (Exception e) => MonadError e AdvSTM where
  throwError err = throw err
  catchError body handler = catchSTM body handler