-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Concurrent.AdvSTM
-- Copyright   :  (c) Chris Kuklewicz 2006, Peter Robinson 2009
-- License     :  BSD3
-- 
-- Maintainer  :  Peter Robinson <robinson@ecs.tuwien.ac.at>
-- Stability   :  experimental
-- Portability :  non-portable (requires STM)
--
-- Extends Control.Concurrent.STM with IO hooks
-- 
-----------------------------------------------------------------------------
 
module Control.Concurrent.AdvSTM( -- * Class MonadAdvSTM
                                  MonadAdvSTM(..)
                                  -- * Monad AdvSTM
                                , AdvSTM
                                , unsafeRetryWith
                                , atomically
                                , unsafeIOToAdvSTM
                                , debugAdvSTM
                                , debugMode
                                ) where
 
import Prelude hiding (catch)
import Control.Monad.AdvSTM.Def(AdvSTM(..),Env(..),TVarValue(..))
import Control.Monad.AdvSTM.Class(MonadAdvSTM(..))

import Control.Exception(throw,throwIO,catch,finally,SomeException,fromException,try) 
import Control.Monad(mplus,when)
import Control.Monad.Reader(MonadReader(..),ReaderT,runReaderT,lift,asks)
import Control.Concurrent(threadDelay,forkIO,ThreadId,myThreadId)
import Control.Concurrent.Chan(Chan,newChan,readChan,writeChan)
import Control.Concurrent.STM.TMVar(TMVar,putTMVar,takeTMVar)
-- import Control.Concurrent.STM.TChan(TChan,writeTChan)
import Control.Concurrent.MVar(MVar,newEmptyMVar,takeMVar,tryTakeMVar,putMVar)
import qualified Control.Concurrent.STM as S (STM,orElse,retry,catchSTM,atomically,check,always,alwaysSucceeds)
import Control.Concurrent.STM.TVar(TVar,newTVarIO,readTVar,writeTVar)
import GHC.Conc(unsafeIOToSTM)
import Data.IORef(newIORef,readIORef,writeIORef)
 
--------------------------------------------------------------------------------


instance MonadAdvSTM AdvSTM where
    -- Commit hooks live in a TVar as a function 'IO () -> IO ()' that runs
    -- the hooks registered so far before its argument. Each new hook is
    -- spliced in after the existing ones, so hooks run in registration order.
    -- The TVar is written inside the transaction, so a rolled-back attempt
    -- leaves no registration behind.
    onCommit ioaction = do
        hooksVar <- AdvSTM (asks commitTVar)
        hooks    <- liftAdv (readTVar hooksVar)
        liftAdv (writeTVar hooksVar (\rest -> hooks (ioaction >> rest)))

    -- Retry hooks use the same representation but live in an MVar that is
    -- updated through unsafeIOToSTM, i.e. outside the STM log: the update is
    -- not undone when the current attempt is abandoned.
    unsafeOnRetry ioaction = do
        hooksVar <- AdvSTM (asks retryMVar)
        liftAdv $ unsafeIOToSTM $ do
            previous <- tryTakeMVar hooksVar
            let hooks = case previous of
                    Nothing    -> (ioaction >>)
                    Just older -> older . (ioaction >>)
            putMVar hooksVar $! hooks

    orElse = mplus

    retry = liftAdv S.retry

    check b = liftAdv (S.check b)

    alwaysSucceeds inv = do
        invariant <- unlift inv
        liftAdv (S.alwaysSucceeds invariant)

    always inv = do
        invariant <- unlift inv
        liftAdv (S.always invariant)

    runAtomic = atomically

    -- The user handler only receives exceptions of the type it expects;
    -- any other exception is rethrown unchanged.
    catchSTM action handler = do
        rawAction  <- unlift action
        rawHandler <- unlift1 handler
        let dispatch e = maybe (throw e) rawHandler (fromException e)
        liftAdv (S.catchSTM rawAction dispatch)

    liftAdv stm = AdvSTM (lift stm)

-- | Registers the IO action as a retry hook (see 'unsafeOnRetry') and then
-- retries the transaction.
unsafeRetryWith :: (Monad m, MonadAdvSTM m) => IO () -> m b
unsafeRetryWith io = do
    unsafeOnRetry io
    retry
 

 
--------------------------------------------------------------------------------



-- | Runs an 'AdvSTM' transaction atomically in 'IO'.
--
-- * Actions registered with 'onCommit' are run in registration order after
--   the underlying STM transaction has committed.
-- * Actions registered with 'unsafeOnRetry' are handed to a helper thread
--   each time the transaction retries. Before returning, this function
--   signals the helper to stop and waits for it to finish.
-- * If an 'onCommit' action throws, the old values recorded by the commit
--   listeners are written back, the listeners are unblocked, and the
--   exception is rethrown.
atomically :: AdvSTM a -> IO a
atomically (AdvSTM action) = do
    -- Switching this to True occasionally causes deadlocks (the debug output
    -- is produced through unsafeIOToSTM)!
    let debugging = False
    debugTVar <- newTVarIO debugging
    tid       <- myThreadId       -- ThreadId for controlling TVar access
    debug "***************************************************************" 0 debugging
    debug (show (tid,"Starting transaction...")) 1000000 debugging
    -- Building the Reader monad environment
    commitVar <- newTVarIO id     -- IO actions to be run if the transaction commits
    retryVar  <- newEmptyMVar     -- full if there's something to do before retrying
    commitListeners <- newTVarIO []
    let env = Env { commitTVar    = commitVar       :: TVar (IO () -> IO ())
                  , retryMVar     = retryVar        :: MVar (IO () -> IO ())
                  , transThreadId = tid             :: ThreadId
                  , listeners     = commitListeners :: TVar [(TMVar (),TVarValue)]
                  , debugModeVar  = debugTVar
                  }
    -- Setting up communication for the retry-helper thread. The channel and
    -- the thread are created on the first retry that has pending retry actions.
    retryChanVar <- newIORef (Nothing :: Maybe (Chan (Maybe (IO ()))))
    retryEndVar  <- newEmptyMVar   -- Termination signal for the retry-helper thread

    -- Runs when the user action retries: hand any pending retry actions to the
    -- helper thread, then retry the whole transaction. The do-bodies below are
    -- indented past the let-binding so the code also parses without
    -- NondecreasingIndentation.
    let check'retry = do
            unsafeIOToSTM $ do
                may'todo <- tryTakeMVar retryVar
                case may'todo of
                    Nothing -> return ()
                    Just retryFun -> do
                        may'chan <- readIORef retryChanVar
                        chan <- case may'chan of
                            Nothing -> do
                                newCh <- newChan
                                writeIORef retryChanVar (Just newCh)
                                _ <- spawn'retry'thread (readChan newCh) (putMVar retryEndVar ())
                                return newCh
                            Just existing -> return existing
                        writeChan chan $ Just (retryFun (return ()))
            debugSTM (show (tid,"Calling retry now...")) 0 debugging
            debugSTM (show (tid,"*********************"))  1000000 debugging
            S.retry

    let wait'retry'finished = do
            may'chan <- readIORef retryChanVar
            case may'chan of
                Nothing   -> return ()
                Just chan -> do
                    -- Write an "EOF" on the channel and block until the helper thread is done:
                    writeChan chan Nothing
                    takeMVar retryEndVar

    let wrappedAction = runReaderT action env `S.orElse` check'retry

    result <- S.atomically $ do
        debugSTM (show (tid,"wrappedAction: Running S.STM action...")) 0 debugging
        result <- wrappedAction
        debugSTM (show (tid,"wrappedAction: Finished S.STM action...")) 0 debugging
        ls     <- readTVar commitListeners
        -- Notify the TPVars that we're entering onCommit mode:
        debugSTM (show (tid,"wrappedAction: Notifying TVars that we're about to run onCommit...")) 0 debugging
        mapM_ (\(l,_) ->
                takeTMVar l       -- tell the TVar that we're going into onCommit mode
              ) ls
        return result

    let rollbackOnCommit = do
            debug "rollbackOnCommit: rolling back modified TVar values!" 0 debugging
            wait'retry'finished
            S.atomically $ do
                ls <- readTVar commitListeners
                mapM_ (\(l,TVarValue (oldValTVar,oldVal)) -> do
                        writeTVar oldValTVar oldVal  -- something went wrong, restore the old value
                        putTMVar l ()                -- ...and unblock the TVar
                      ) ls

    -- Now try to run the onCommit IO actions:
    commitFun <- S.atomically $ readTVar commitVar
    commitFun (return ()) `catch` (\(e::SomeException) -> rollbackOnCommit >> throwIO e)

    -- Everything's ok, so notify and unblock the TVars:
    debug (show (tid,"*********************")) 1000000 debugging
    debug (show (tid,"Notifying TPVars that we're done:")) 0 debugging
    S.atomically $ do
        ls <- readTVar commitListeners
        mapM_ (\(l,_) -> putTMVar l ()) ls
    wait'retry'finished
    debug (show (tid,"Transaction done; retry thread finished")) 1000000 debugging
    debug "************************************************************" 0 debugging
    return result

  where
    -- Helper thread for the retry IO actions. It runs jobs from 'nextJob'
    -- until it reads Nothing. 'atEndAction' is attached with 'finally' so the
    -- end signal is always delivered, even if a retry action throws. Without
    -- it, wait'retry'finished would block forever on the termination MVar.
    -- An exception from a job still terminates the helper (and is reported by
    -- forkIO's default handler); jobs queued after that are not run.
    spawn'retry'thread :: IO (Maybe (IO ())) -> IO () -> IO ThreadId
    spawn'retry'thread nextJob atEndAction = forkIO (loop `finally` atEndAction)
      where
        loop = do
            may'job <- nextJob
            case may'job of
                Nothing  -> return ()
                Just job -> job >> loop
     
--------------------------------------------------------------------------------

-- | Runs an IO action inside an 'AdvSTM' transaction via 'unsafeIOToSTM'.
-- The action is outside the transaction's log, so it is not undone on
-- rollback and may run again if the transaction is re-executed.
unsafeIOToAdvSTM :: IO a -> AdvSTM a
unsafeIOToAdvSTM io = liftAdv (unsafeIOToSTM io)

--------------------------------------------------------------------------------

-- | Turns the per-transaction debug flag consulted by 'debugAdvSTM' on or off.
-- /WARNING:/ Can lead to deadlocks! 
debugMode :: Bool -> AdvSTM ()
debugMode switch =
    AdvSTM (asks debugModeVar) >>= \flagVar -> liftAdv (writeTVar flagVar switch)

-- | When debug mode is on (see 'debugMode'), prints the transaction's
-- 'ThreadId' together with the message and then sleeps for the given number
-- of /microseconds/ (the unit used by 'threadDelay').
-- The output goes through 'unsafeIOToSTM', so it is repeated if the
-- transaction is re-executed.
-- /WARNING:/ Can lead to deadlocks! 
debugAdvSTM :: String -> Int -> AdvSTM ()
debugAdvSTM msg delay = do
    debVar    <- AdvSTM $ asks debugModeVar
    debugging <- liftAdv $ readTVar debVar
    tid       <- AdvSTM $ asks transThreadId
    liftAdv $ debugSTM (show (tid,msg)) delay debugging

--------------------------------------------------------------------------------
-- The following functions are internal helpers; they are not exported
-- because end users do not need them.

-- Runs an 'AdvSTM' computation in plain STM against the given environment.
runWith :: Env -> AdvSTM t -> S.STM t
runWith env (AdvSTM reader) = flip runReaderT env reader
                            
 
-- Captures the current environment and returns a function that runs any
-- 'AdvSTM' computation in plain STM against it.
unlifter :: AdvSTM (AdvSTM a -> S.STM a)
unlifter = AdvSTM ask >>= \currentEnv -> return (runWith currentEnv)
 
-- Converts an 'AdvSTM' computation into the equivalent plain STM action
-- bound to the current environment.
unlift :: AdvSTM a -> AdvSTM (S.STM a)
unlift f = unlifter >>= \run -> return (run f)
 
-- Like 'unlift', but for a one-argument function returning 'AdvSTM'.
unlift1 :: (t -> AdvSTM a) -> AdvSTM (t -> S.STM a)
unlift1 f = do
    run <- unlifter
    return (run . f)
 
-- Prints the message and then sleeps for @delay@ microseconds, but only when
-- @debugging@ is True. The IO is performed via unsafeIOToSTM.
-- WARNING: Can lead to deadlocks! 
debugSTM :: String -> Int -> Bool -> S.STM ()
debugSTM msg delay debugging
    | debugging = unsafeIOToSTM (putStrLn msg >> threadDelay delay)
    | otherwise = return ()

-- IO counterpart of 'debugSTM': prints the message and then sleeps for
-- @delay@ microseconds, but only when @debugging@ is True.
-- WARNING: Can lead to deadlocks! 
debug :: String -> Int -> Bool -> IO ()
debug msg delay debugging
    | debugging = putStrLn msg >> threadDelay delay
    | otherwise = return ()