{-# LANGUAGE ScopedTypeVariables, FlexibleInstances, MultiParamTypeClasses, UndecidableInstances #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Concurrent.AdvSTM
-- Copyright   :  (c) Chris Kuklewicz 2006, Peter Robinson 2009
-- License     :  BSD3
--
-- Maintainer  :  Peter Robinson <robinson@ecs.tuwien.ac.at>
-- Stability   :  experimental
-- Portability :  non-portable (requires STM)
--
-- Extends Control.Concurrent.STM with IO hooks
--
-----------------------------------------------------------------------------

module Control.Concurrent.AdvSTM( -- * Class MonadAdvSTM
                                  MonadAdvSTM( onCommitWith
                                            , onCommit
                                            , unsafeRetryWith
                                            , orElse
                                            , retry
                                            , check
                                            , catchSTM
                                            , liftAdv
                                            , newTVar
                                            , readTVarAsync
                                            , writeTVarAsync
                                            , readTVar
                                            , writeTVar
                                            )

                                  -- * Monad AdvSTM
                                , AdvSTM
                                , atomically
                                , unsafeIOToSTM
                                , handleSTM
                                , debugAdvSTM
                                , debugMode
                                ) where

import Prelude hiding (catch)
import Control.Monad.AdvSTM.Def(AdvSTM(..),Env(..),TVarValue(..))
import Control.Monad.AdvSTM.Class(
    MonadAdvSTM(..), handleSTM, TVar( TVar ), onCommitLock, currentTid, valueTVar )
-- import Control.Monad.Reader(ReaderT(ReaderT),mapReaderT,runReaderT)

import Control.Exception(Exception,throw,catch,SomeException,fromException,try,mask_,Deadlock(..),finally)
import Control.Monad(mplus,when,liftM,ap,unless)
import Control.Monad.Except
import Control.Concurrent(threadDelay,forkIO,ThreadId,myThreadId,throwTo)
import Control.Concurrent.Chan(Chan,newChan,readChan,writeChan)
import Control.Concurrent.STM.TMVar(TMVar,putTMVar,takeTMVar,newTMVar,tryTakeTMVar)
-- import Control.Concurrent.STM.TChan(TChan,writeTChan)
import Control.Concurrent.MVar(MVar,newMVar,takeMVar,tryTakeMVar,putMVar,tryPutMVar,swapMVar)
import qualified Control.Concurrent.STM as S (STM,orElse,retry,catchSTM,atomically,check)
import qualified Control.Concurrent.STM.TVar as STVar -- (TVar,newTVarIO,readTVar,writeTVar)
import qualified GHC.Conc as G -- unsafeIOToSTM
import Data.IORef(newIORef,readIORef,writeIORef)
import Data.Maybe(isJust,Maybe,fromJust)
import Control.Monad.Reader(MonadReader(..),ReaderT(ReaderT),runReaderT,lift,asks)

--------------------------------------------------------------------------------


instance MonadAdvSTM (AdvSTM) where

    -- | Installs @ioclosure@ as the function that 'atomically' applies to the
    -- list of registered 'onCommit' actions after the transaction has
    -- committed. A closure installed earlier in the same transaction is
    -- overwritten.
    onCommitWith ioclosure = do
        commitCl      <- AdvSTM $ asks commitClosure
        liftAdv $ STVar.writeTVar commitCl ioclosure

    -- | Registers an IO action to be run after the transaction has committed.
    -- Actions are consed onto the list kept in the environment; 'atomically'
    -- reverses that list, so the commit closure receives them in
    -- registration order.
    onCommit ioaction = do
        commitVar     <- AdvSTM $ asks commitTVar
        commitActions <- liftAdv $ STVar.readTVar commitVar
        liftAdv $ STVar.writeTVar commitVar $ ioaction : commitActions

--    unsafeOnRetry ioaction = do
--        retryVar <- AdvSTM $ asks retryMVar
--        liftAdv . unsafeIOToSTM $ do
--          may'retryFun <- tryTakeMVar retryVar
--          let retryFun = maybe (ioaction >>) (. (ioaction >>)) may'retryFun
--          putMVar retryVar $! retryFun

    -- | Alternative composition; delegates to 'mplus' of 'AdvSTM'.
    orElse = mplus

    -- | See 'S.retry'
    retry = liftAdv S.retry

    -- | See 'S.check'
    check = liftAdv . S.check

    -- | Runs @action@ and, if it throws, @handler@; both are executed in the
    -- underlying 'S.STM' monad with the current environment. Exceptions that
    -- cannot be converted to the handler's exception type are re-thrown.
    catchSTM action handler = do
        action'  <- unlift action
        handler' <- unlift1 handler
        let handler'' e = case fromException e of
                            Nothing -> throw e
                            Just e' -> handler' e'
        liftAdv $ S.catchSTM action' handler''

    -- | Lifts a plain 'S.STM' action into 'AdvSTM'.
    liftAdv = AdvSTM . lift


    -- | See 'STVar.newTVar'
    --
    -- Besides the value itself, the new TVar gets a (full) commit lock and a
    -- \"current writer\" slot that is initially 'Nothing'.
    newTVar a = TVar `liftM` liftAdv (STVar.newTVar a)
                     `ap`    liftAdv (newTMVar ())
                     `ap`    liftAdv (STVar.newTVar Nothing)


    -- | Writes a value to a TVar. Blocks until the onCommit IO-action(s) are
    -- complete. See 'onCommit' for details.
    --
    -- Throws 'Deadlock' when the commit lock is held and the TVar's stored
    -- thread id equals the one of the current transaction, i.e. when the
    -- write is attempted from within the onCommit phase of the writer itself.
    writeTVar tvar a = do
        let lock = onCommitLock tvar
        commitLock <- liftAdv $ tryTakeTMVar lock
        -- Get ThreadID of current transaction:
        curTid     <- AdvSTM $ asks transThreadId
        storedTid  <- liftAdv $ STVar.readTVar (currentTid tvar)
        let sameThread = storedTid == Just curTid
        case commitLock of
            Nothing
              | sameThread -> throw Deadlock -- Can't write the TVar in the onCommit phase
              | otherwise  -> retry
            Just _  -> do
                lsTVar <- AdvSTM $ asks listeners
                ls     <- liftAdv $ STVar.readTVar lsTVar
                -- The stored thread id alone cannot tell whether this is the
                -- first write of *this* transaction: it is not reset here, so
                -- it may stem from an earlier transaction of the same thread
                -- (the listener list, in contrast, is fresh for every call of
                -- 'atomically'). Hence additionally check that the TVar's lock
                -- is already registered in the current listener list; the
                -- scan is only performed when the thread ids match.
                let alreadyRegistered = sameThread && any ((== lock) . fst) ls
                unless alreadyRegistered $ do
                    -- First write access, update the ThreadId:
                    liftAdv $ STVar.writeTVar (currentTid tvar) $ Just curTid
                    -- Remember the old value for rollback:
                    oldval <- liftAdv $ STVar.readTVar (valueTVar tvar)
                    -- Add this TVar to the onCommit-listener list:
                    liftAdv $ STVar.writeTVar lsTVar $
                            (lock,TVarValue (valueTVar tvar,oldval)) : ls

                liftAdv $ STVar.writeTVar (valueTVar tvar) a
                -- Hand the lock back; 'atomically' takes it again when the
                -- transaction is about to commit.
                liftAdv $ putTMVar lock ()


    -- | Writes the value directly, without consulting the commit lock and
    -- without registering the TVar for the onCommit phase.
    writeTVarAsync tvar = liftAdv . STVar.writeTVar (valueTVar tvar)

    --------------------------------------------------------------------------------

    -- | Reads a value from a TVar. Blocks until the IO onCommit action(s) of
    -- the corresponding transaction are complete.
    -- See 'onCommit' for a more detailed description of this behaviour.
    --
    -- Throws 'Deadlock' when the commit lock is held and the TVar's stored
    -- thread id equals the one of the current transaction.
    readTVar tvar = do
        commitLock <- liftAdv $ tryTakeTMVar (onCommitLock tvar)
        case commitLock of
            Nothing -> do
                storedTid <- liftAdv $ STVar.readTVar (currentTid tvar)
                curTid    <- AdvSTM $ asks transThreadId
                if storedTid == Just curTid
                  then throw Deadlock
                  else retry
            Just _ -> do
                result <- liftAdv $ STVar.readTVar $ valueTVar tvar
                -- The lock was only taken to test it; hand it back.
                liftAdv $ putTMVar (onCommitLock tvar) ()
                return result


    -- | Reads the value directly, without consulting the commit lock.
    readTVarAsync = liftAdv . STVar.readTVar . valueTVar


    -- | Forks a separate thread to run the IO action and then retries the transaction.
    --
    -- The forked thread first takes the \"retry done\" 'MVar' of the
    -- environment: 'Nothing' (stored by 'atomically' once the transaction has
    -- committed) makes it skip @io@, @Just ()@ makes it run @io@. In either
    -- case it finally tries to put @Just ()@ back.
    unsafeRetryWith io = do -- unsafeOnRetry io >> retry
      doneMVar <- AdvSTM $ asks retryDoneMVar
      let worker = do
              val <- takeMVar doneMVar
              case val of
                Nothing -> return ()
                Just _  -> io
      _ <- unsafeIOToSTM $ forkIO $ worker `finally` tryPutMVar doneMVar (Just ())
      retry

    -- | See 'G.unsafeIOToSTM'
    unsafeIOToSTM = liftAdv . G.unsafeIOToSTM

--------------------------------------------------------------------------------


-- | See 'S.atomically'
--
-- Runs the transaction in three phases:
--
-- 1. The wrapped 'S.STM' action is executed; within the same 'S.STM'
--    transaction the commit lock of every TVar in the listener list is taken.
--
-- 2. After waiting for pending 'unsafeRetryWith' threads, the registered
--    'onCommit' actions are handed, in registration order, to the commit
--    closure (by default 'sequence_').
--
-- 3. The commit locks are released again.
--
-- If phase 2 raises an exception, every TVar in the listener list is reset to
-- the value recorded there, its lock is released and the exception is
-- re-thrown.
atomically :: AdvSTM a -> IO a
atomically (AdvSTM action) = do
    let debugging = False -- Switching this to True occasionally causes deadlocks (b/c unsafeIO)!
    debugTVar <- STVar.newTVarIO debugging
    tid       <- myThreadId       -- ThreadId for controlling TVar access
    debug "***************************************************************" 0 debugging
    debug (show (tid,"Starting transaction...")) 1000000 debugging
    -- Building the Reader monad environment
    commitVar <- STVar.newTVarIO []     -- IO actions to be run if the transaction commits
    commitCl  <- STVar.newTVarIO (sequence_)
      -- IO action that \"contains\" the commit IO actions.
    retryDone       <- newMVar $ Just ()
    commitListeners <- STVar.newTVarIO []
    let env = Env { commitTVar    = commitVar       :: STVar.TVar [IO ()] -- -> IO ())
                  , commitClosure = commitCl
                  , retryDoneMVar = retryDone       :: MVar (Maybe ()) -- (IO () -> IO ())
                  , transThreadId = tid             :: ThreadId
                  , listeners     = commitListeners :: STVar.TVar [(TMVar (),TVarValue)]
                  , debugModeVar  = debugTVar
                  }

    -- Blocks until the MVar is full (i.e. no retryWith thread is holding it)
    -- and stores Nothing, which tells late retryWith threads to skip their
    -- IO action.
    let stopRetryWith = swapMVar retryDone Nothing

    let wrappedAction = runReaderT action env -- `S.orElse` check'retry

    -- Mask asynchronous exceptions for the rest of 'atomically'. Blocking
    -- operations (e.g. the 'swapMVar' in stopRetryWith) remain interruptible,
    -- which is why the whole commit phase below is guarded by the rollback
    -- handler.
    mask_ $ do
        result <- S.atomically $ do
            debugSTM (show (tid,"wrappedAction: Running S.STM action...")) 0 debugging
            res    <- wrappedAction
            debugSTM (show (tid,"wrappedAction: Finished S.STM action...")) 0 debugging
            ls     <- STVar.readTVar commitListeners
            -- Notify the TPVars that we're entering onCommit mode:
            debugSTM (show (tid,"wrappedAction: Notifying TVars that we're about to run onCommit...")) 0 debugging
            mapM_ (\(l,_) ->
                    takeTMVar l    -- tell the TVar that we're going into onCommit mode
                  ) ls
            return res

        let rollbackOnCommit = do
                debug "rollbackOnCommit: rolling back modified TVar values!" 0 debugging
                S.atomically $ do
                    ls <- STVar.readTVar commitListeners
                    mapM_ (\(l,TVarValue (oldValTVar,oldVal)) -> do
                            STVar.writeTVar oldValTVar oldVal  -- smthg went wrong, restore the old value
                            putTMVar l ()                -- ...and unblock the TVar
                          ) ls

        let commitPhase = do
                -- Wait for the retryWith thread(s) to be done before running
                -- the onCommit actions.
                _ <- stopRetryWith
                -- Now try to run the onCommit IO actions:
                commitAcs  <- liftM reverse $ S.atomically $ STVar.readTVar commitVar
                commitClAc <- S.atomically $ STVar.readTVar commitCl
                commitClAc commitAcs

        -- The commit locks are already taken at this point, so *any*
        -- exception in the commit phase (including one delivered while
        -- waiting for the retryWith threads) must release them again;
        -- otherwise the affected TVars would stay blocked forever.
        commitPhase
          `catch` (\(e::SomeException) -> rollbackOnCommit >> throw e)

        -- Everything's ok, so notify and unblock the TVars:
        debug (show (tid,"*********************")) 1000000 debugging
        debug (show (tid,"Notifying TPVars that we're done:")) 0 debugging
        S.atomically $ do
            ls <- STVar.readTVar commitListeners
            mapM_ (\(l,_) ->
                    putTMVar l ()
                  ) ls
        debug (show (tid,"Transaction done; retry thread finished")) 1000000 debugging
        debug "************************************************************" 0 debugging

        return result

--------------------------------------------------------------------------------

-- | Stores the given switch in the debug flag of the transaction's
-- environment; 'debugAdvSTM' consults that flag before printing.
-- /WARNING:/ Can lead to deadlocks!
debugMode :: Bool -> AdvSTM ()
debugMode switch =
    AdvSTM (asks debugModeVar) >>= \flagVar ->
        liftAdv (STVar.writeTVar flagVar switch)

-- | If the debug flag of the environment is set, prints the transaction's
-- thread id together with the message (via unsafeIOToSTM) and then sleeps
-- for the given number of microseconds.
-- /WARNING:/ Can lead to deadlocks!
debugAdvSTM :: String -> Int -> AdvSTM ()
debugAdvSTM msg delay = do
    flagVar <- AdvSTM $ asks debugModeVar
    enabled <- liftAdv (STVar.readTVar flagVar)
    owner   <- AdvSTM $ asks transThreadId
    liftAdv (debugSTM (show (owner,msg)) delay enabled)

--------------------------------------------------------------------------------
-- The following functions are not exported, as they aren't needed by
-- the end user.

-- Runs an 'AdvSTM' computation as a plain 'S.STM' action by supplying the
-- given environment to the underlying reader.
runWith :: Env -> AdvSTM t -> S.STM t
runWith env wrapped = case wrapped of
    AdvSTM inner -> runReaderT inner env


-- Captures the current environment and returns a function that runs
-- 'AdvSTM' computations in plain 'S.STM' under that environment.
unlifter :: AdvSTM (AdvSTM a -> S.STM a)
unlifter = runWith `liftM` AdvSTM ask

-- Converts an 'AdvSTM' computation into the equivalent 'S.STM' action, bound
-- to the current environment.
unlift :: AdvSTM a -> AdvSTM (S.STM a)
unlift f = unlifter >>= \run -> return (run f)

-- Like 'unlift', but for a function that produces an 'AdvSTM' computation.
unlift1 :: (t -> AdvSTM a) -> AdvSTM (t -> S.STM a)
unlift1 f = unlifter >>= \run -> return (run . f)

-- Prints the message and sleeps for the given number of microseconds from
-- within an STM transaction, but only when the flag is set.
-- WARNING: Can lead to deadlocks!
debugSTM :: String -> Int -> Bool -> S.STM ()
debugSTM msg delay debugging
    | debugging = G.unsafeIOToSTM (putStrLn msg >> threadDelay delay)
    | otherwise = return ()

-- Prints the message and sleeps for the given number of microseconds, but
-- only when the flag is set.
-- WARNING: Can lead to deadlocks!
debug :: String -> Int -> Bool -> IO ()
debug msg delay debugging
    | debugging = putStrLn msg >> threadDelay delay
    | otherwise = return ()

{-
instance forall a. (Exception e, a) => MonadError e (AdvSTM a) where
  throwError = throw
  catchError = catchSTM
-}