{-# LANGUAGE CPP                #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE TypeFamilies       #-}

#if  defined(__GLASGOW_HASKELL__) && \
    !defined(mingw32_HOST_OS) && \
    !defined(__GHCJS__) && \
    !defined(js_HOST_ARCH) && \
    !defined(wasm32_HOST_ARCH)
#define GHC_TIMERS_API
#endif

-- | A non-standard interface for the timer API.
--
-- This module also provides a polyfill which allows using the timer API also
-- on the non-threaded RTS, regardless of the architecture \/ OS.  Currently we
-- support @*nix@, @macOS@, @Windows@ (and, unofficially, @GHCJS@).
--
-- We use it to provide the @'Control.Monad.Class.MonadTimer.MonadTimer' 'IO'@
-- instance and to implement cancellable timers, see
-- 'Control.Monad.Class.MonadTimer.SI.registerDelayCancellable'.
--
-- You can expect that we will deprecate it at some point (e.g. once GHC gets
-- better support for timers, especially across different execution
-- environments).
--
module Control.Monad.Class.MonadTimer.NonStandard
  ( TimeoutState (..)
  , newTimeout
  , readTimeout
  , cancelTimeout
  , awaitTimeout
  , NewTimeout
  , ReadTimeout
  , CancelTimeout
  , AwaitTimeout
  ) where

import Control.Concurrent.STM qualified as STM
#ifndef GHC_TIMERS_API
import Control.Monad (when)
#endif
import Control.Monad.Class.MonadSTM

#ifdef GHC_TIMERS_API
import GHC.Event qualified as GHC (TimeoutKey, getSystemTimerManager,
           registerTimeout, unregisterTimeout)
#else
import GHC.Conc.IO qualified as GHC (registerDelay)
#endif


-- | State of a timeout: pending, fired or cancelled.
--
-- A timeout starts out as 'TimeoutPending' and moves to exactly one of
-- 'TimeoutFired' or 'TimeoutCancelled'; it never returns to the pending
-- state.
--
data TimeoutState = TimeoutPending | TimeoutFired | TimeoutCancelled
  deriving (Eq, Ord, Show)


-- | The type of the timeout handle, used with 'newTimeout', 'readTimeout',
-- 'cancelTimeout' and 'awaitTimeout'.
--
-- Its representation depends on whether the GHC timer manager API is
-- available (@GHC_TIMERS_API@):
--
-- * with it, the handle holds a 'STM.TVar' with the current 'TimeoutState'
--   and the 'GHC.TimeoutKey' of the registered callback, which
--   'cancelTimeout' uses to unregister it;
--
-- * without it, the handle holds a 'STM.TVar' containing the 'STM.TVar'
--   'Bool' returned by @GHC.registerDelay@ (read as "has fired"), plus a
--   separate 'STM.TVar' 'Bool' cancellation flag.
--
#ifdef GHC_TIMERS_API
data Timeout = TimeoutIO !(STM.TVar TimeoutState) !GHC.TimeoutKey
#else
data Timeout = TimeoutIO !(STM.TVar (STM.TVar Bool)) !(STM.TVar Bool)
#endif

-- | Create a new timeout which will fire once the given delay has elapsed.
--
-- The 'Int' delay is passed unchanged to @GHC.registerTimeout@ (native timer
-- manager) or to @GHC.registerDelay@ (fallback). Both functions interpret it
-- as a number of microseconds.
--
-- The timeout will start in the 'TimeoutPending' state and either
-- fire at or after the given time, leaving it in the 'TimeoutFired' state,
-- or it may be cancelled with 'cancelTimeout', leaving it in the
-- 'TimeoutCancelled' state.
--
-- Timeouts /cannot/ be reset to the pending state once fired or cancelled
-- (as this would be very racy). You should create a new timeout if you need
-- this functionality.
--
-- When the native timer manager is supported (on @*nix@ systems), the handle
-- only holds a 'TVar' with the 'TimeoutState' and a 'GHC.TimeoutKey'.
--
newTimeout :: NewTimeout IO Timeout
-- | Type of 'newTimeout', abstracted over the monad @m@ and the handle type
-- @timeout@: given a delay, create a new timeout handle.
type NewTimeout m timeout = Int -> m timeout


-- | Read the current state of a timeout. This does not block; it returns
-- the current state. It is your responsibility to use 'retry' to wait.
--
-- Alternatively, you may wish to use the convenience utility 'awaitTimeout'
-- to wait for just the fired or cancelled outcomes.
--
-- Be sure to handle the 'TimeoutCancelled' state if you plan to use
-- 'cancelTimeout'.
--
readTimeout :: ReadTimeout IO Timeout
-- | Type of 'readTimeout': observe the 'TimeoutState' of a timeout handle
-- inside the @STM m@ transaction monad.
type ReadTimeout m timeout = timeout -> STM m TimeoutState


-- | Cancel a timeout (unless it has already fired), putting it into the
-- 'TimeoutCancelled' state. Code reading and acting on the timeout state
-- needs to handle such cancellation appropriately.
--
-- It is safe to race this concurrently against the timer firing. It will
-- have no effect if the timer fires first. Cancelling an already cancelled
-- timeout leaves its state unchanged.
--
cancelTimeout :: CancelTimeout IO Timeout
-- | Type of 'cancelTimeout': cancel a timeout handle in the monad @m@.
type CancelTimeout m timeout = timeout -> m ()

-- | Returns @True@ when the timeout has fired, or @False@ if it was
-- cancelled. While the timeout is still pending, the transaction blocks
-- (via 'retry').
awaitTimeout :: AwaitTimeout IO Timeout
-- | Type of 'awaitTimeout', running in the @STM m@ transaction monad.
type AwaitTimeout m timeout = timeout -> STM m Bool


#ifdef GHC_TIMERS_API

-- With the native timer manager, the state is stored directly in the TVar;
-- the timer key is only needed for cancellation.
readTimeout (TimeoutIO var _key) = STM.readTVar var

newTimeout d = do
    var <- STM.newTVarIO TimeoutPending
    mgr <- GHC.getSystemTimerManager
    -- Register a callback with the system timer manager; it is invoked once
    -- the delay @d@ has elapsed. The returned key lets 'cancelTimeout'
    -- unregister it.
    key <- GHC.registerTimeout mgr d (STM.atomically (timeoutAction var))
    return (TimeoutIO var key)
  where
    -- Move a pending timeout to 'TimeoutFired'; leave a cancelled one alone.
    timeoutAction :: STM.TVar TimeoutState -> STM.STM ()
    timeoutAction var = do
      x <- STM.readTVar var
      case x of
        TimeoutPending   -> STM.writeTVar var TimeoutFired
        -- Only this callback ever writes 'TimeoutFired', and the timer
        -- manager runs it at most once, so this state is unreachable here.
        TimeoutFired     -> error "MonadTimer(IO): invariant violation"
        TimeoutCancelled -> return ()

cancelTimeout (TimeoutIO var key) = do
    -- Only a still-pending timeout becomes cancelled. If the callback has
    -- already fired, or the timeout was cancelled before, the state is kept.
    STM.atomically $ do
      x <- STM.readTVar var
      case x of
        TimeoutPending   -> STM.writeTVar var TimeoutCancelled
        TimeoutFired     -> return ()
        TimeoutCancelled -> return ()
    -- The callback is unregistered from the timer manager unconditionally.
    mgr <- GHC.getSystemTimerManager
    GHC.unregisterTimeout mgr key

#else

-- Fallback: cancellation takes precedence over the registerDelay flag. Both
-- TVars are always read, in the same order, so the transaction's read set
-- does not depend on the outcome.
readTimeout (TimeoutIO firedVarVar cancelVar) = do
  isCancelled <- STM.readTVar cancelVar
  firedVar    <- STM.readTVar firedVarVar
  hasFired    <- STM.readTVar firedVar
  return $
    if isCancelled
      then TimeoutCancelled
      else if hasFired then TimeoutFired else TimeoutPending

-- Fallback: wrap the TVar produced by registerDelay in another TVar and pair
-- it with a fresh, unset cancellation flag.
newTimeout delay = do
  firedVarVar <- GHC.registerDelay delay >>= STM.newTVarIO
  TimeoutIO firedVarVar <$> STM.newTVarIO False

-- Fallback: raise the cancellation flag, but only while the delay has not
-- yet fired, so a fired timeout is never reported as cancelled.
cancelTimeout (TimeoutIO firedVarVar cancelVar) = STM.atomically $ do
  firedVar <- STM.readTVar firedVarVar
  hasFired <- STM.readTVar firedVar
  when (not hasFired) $
    STM.writeTVar cancelVar True

#endif

awaitTimeout t = do
  s <- readTimeout t
  case s of
    -- Block the transaction until the timeout is resolved one way or another.
    TimeoutPending   -> retry
    TimeoutFired     -> return True
    TimeoutCancelled -> return False