{-# LANGUAGE ConstraintKinds     #-}
{-# LANGUAGE InstanceSigs        #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE NumericUnderscores  #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Control.Monad.Class.MonadTimer.SI
  ( -- * Type classes
    MonadDelay (..)
  , MonadTimer (..)
    -- * Auxiliary functions
  , diffTimeToMicrosecondsAsInt
  , microsecondsAsIntToDiffTime
    -- * Re-exports
  , DiffTime
  , MonadFork
  , MonadMonotonicTime
  , MonadTime
  , TimeoutState (..)
    -- * Default implementations
  , defaultRegisterDelay
  , defaultRegisterDelayCancellable  
  ) where

import           Control.Concurrent.Class.MonadSTM
import           Control.Exception (assert)
import           Control.Monad.Class.MonadFork
import           Control.Monad.Class.MonadTime.SI
import qualified Control.Monad.Class.MonadTimer as MonadTimer
import           Control.Monad.Class.MonadTimer.NonStandard (TimeoutState (..))
import qualified Control.Monad.Class.MonadTimer.NonStandard as NonStandard

import           Control.Monad.Reader

import           Data.Bifunctor (bimap)
import           Data.Functor (($>))
import           Data.Time.Clock (diffTimeToPicoseconds)



-- | Convert a 'DiffTime' (measured in seconds) to a whole number of
-- microseconds represented by an 'Int'.
--
-- Sub-microsecond precision is rounded with 'div', i.e. towards negative
-- infinity.
--
-- Note that on 32bit systems an 'Int' can only represent @2^31-1@
-- /microseconds/, which is only ~35 minutes.  Results which do not fit into an
-- 'Int' trip the 'assert'.  When assertions are disabled (@-O@ implies
-- @-fignore-asserts@), they saturate at 'maxBound' / 'minBound' instead of
-- silently wrapping around.
diffTimeToMicrosecondsAsInt :: DiffTime -> Int
diffTimeToMicrosecondsAsInt d =
    let usec :: Integer
        usec = diffTimeToPicoseconds d `div` 1_000_000

        maxUsec, minUsec :: Integer
        maxUsec = fromIntegral (maxBound :: Int)
        minUsec = fromIntegral (minBound :: Int)
    in
    assert (usec <= maxUsec) $
    -- clamp so that an out-of-range value can never wrap into a small or
    -- negative delay when the assertion above is compiled out
    fromIntegral (max minUsec (min maxUsec usec))


-- | Convert a number of microseconds, given as an 'Int', to a 'DiffTime'
-- (measured in seconds).
--
-- For whole microseconds this is the inverse of
-- 'diffTimeToMicrosecondsAsInt'.
microsecondsAsIntToDiffTime :: Int -> DiffTime
microsecondsAsIntToDiffTime usec = fromIntegral usec / 1_000_000

-- | Like 'MonadTimer.MonadDelay', but 'threadDelay' takes a 'DiffTime'
-- (measured in seconds) instead of an 'Int' number of microseconds.
class (MonadTimer.MonadDelay m, MonadMonotonicTime m) => MonadDelay m where
  -- | Delay the current thread by the given 'DiffTime'.
  threadDelay :: DiffTime -> m ()

-- | Thread delay which is safe on 32-bit systems.
--
-- A delay which fits into an 'Int' number of microseconds is passed straight
-- to `Control.Monad.Class.MonadTimer.threadDelay` (for the `IO` monad that is
-- `Control.Concurrent.threadDelay`).  A longer delay is split into chunks.
-- The thread repeatedly sleeps for 'maxBound' microseconds and re-reads the
-- monotonic clock after each sleep.  Once the remaining time fits into an
-- 'Int', it sleeps for the remainder.
--
-- The method signature is polymorphic in @m@ (allowed by @InstanceSigs@), so
-- the body only relies on the 'MonadDelay' superclasses.
--
instance MonadDelay IO where
  threadDelay :: forall m.
                 MonadDelay m
              => DiffTime -> m ()
  threadDelay d
      | d <= maxDelay = MonadTimer.threadDelay (diffTimeToMicrosecondsAsInt d)
      | otherwise     = do
          start <- getMonotonicTime
          sleepUntil (d `addTime` start) start
    where
      -- the longest delay expressible as an 'Int' number of microseconds
      maxDelay :: DiffTime
      maxDelay = microsecondsAsIntToDiffTime maxBound

      -- @sleepUntil deadline now@ sleeps until @deadline@, given that the
      -- monotonic clock currently reads @now@.
      sleepUntil :: Time -> Time -> m ()
      sleepUntil deadline now
        | remaining >= maxDelay = do
            MonadTimer.threadDelay maxBound
            getMonotonicTime >>= sleepUntil deadline
        | otherwise =
            MonadTimer.threadDelay (diffTimeToMicrosecondsAsInt remaining)
        where
          remaining :: DiffTime
          remaining = deadline `diffTime` now

-- | Delays in 'ReaderT' run the underlying monad's 'threadDelay'; the
-- environment is not used.
instance MonadDelay m => MonadDelay (ReaderT r m) where
  threadDelay :: DiffTime -> ReaderT r m ()
  threadDelay d = lift (threadDelay d)

-- | Like 'MonadTimer.MonadTimer', but delays and timeouts are given as a
-- 'DiffTime' (measured in seconds) instead of an 'Int' number of
-- microseconds.
class (MonadTimer.MonadTimer m, MonadMonotonicTime m) => MonadTimer m where

  -- | A register delay function which is safe on 32-bit systems.
  registerDelay :: DiffTime -> m (TVar m Bool)

  -- | A cancellable register delay which is safe on 32-bit systems and
  -- efficient for delays smaller than what `Int` can represent (especially on
  -- systems which support a native timer manager).
  --
  -- Returns an 'STM' action reading the 'TimeoutState' of the delay, together
  -- with an action which cancels it.
  registerDelayCancellable :: DiffTime -> m (STM m TimeoutState, m ())

  -- | Run an action with a time limit; 'Nothing' if it timed out.
  --
  -- TODO: the 'IO' instance is not safe on 32-bit systems.
  timeout :: DiffTime -> m a -> m (Maybe a)


-- | A default implementation of `registerDelay` which supports delays longer
-- than what an `Int` number of microseconds can represent.  This is especially
-- important on 32-bit systems, where the maximum delay expressed in
-- microseconds is around 35 minutes.
--
-- It forks a thread labelled @"delay-thread"@.  That thread waits in chunks of
-- at most 'maxBound' microseconds, re-reading the monotonic clock after each
-- chunk.  Once the deadline is reached it writes 'True' to the returned
-- 'TVar'.
--
defaultRegisterDelay :: forall m timeout.
                        ( MonadFork m
                        , MonadMonotonicTime m
                        , MonadSTM m
                        )
                     => NonStandard.NewTimeout m timeout
                     -> NonStandard.AwaitTimeout m timeout
                     -> DiffTime
                     -> m (TVar m Bool)
defaultRegisterDelay newTimeout awaitTimeout d = do
    start <- getMonotonicTime
    done  <- atomically (newTVar False)
    tid   <- forkIO (waitUntil done (d `addTime` start) start)
    labelThread tid "delay-thread"
    pure done
  where
    -- the longest delay expressible as an 'Int' number of microseconds
    maxDelay :: DiffTime
    maxDelay = microsecondsAsIntToDiffTime maxBound

    -- @waitUntil done deadline now@ blocks until @deadline@ (the monotonic
    -- clock currently reading @now@), then sets @done@ to 'True'.
    waitUntil :: TVar m Bool -> Time -> Time -> m ()
    waitUntil done deadline now
      | remaining >= maxDelay = do
          t <- newTimeout maxBound
          _ <- atomically (awaitTimeout t)
          getMonotonicTime >>= waitUntil done deadline
      | otherwise = do
          t <- newTimeout (diffTimeToMicrosecondsAsInt remaining)
          atomically (awaitTimeout t >> writeTVar done True)
      where
        remaining :: DiffTime
        remaining = deadline `diffTime` now


-- | A cancellable register delay which is safe on 32-bit systems and efficient
-- for delays smaller than what `Int` can represent (especially on systems which
-- support a native timer manager).
--
-- Returns an 'STM' action which reads the 'TimeoutState' of the delay, and an
-- action which cancels it.
--
-- * When the delay fits into an 'Int' number of microseconds, a single timeout
--   is created with @newTimeout@.  Its @readTimeout@ / @cancelTimeout@ are
--   returned directly.
--
-- * Otherwise a thread labelled @"delay-thread"@ is forked.  It waits in
--   chunks of at most 'maxBound' microseconds, re-reading the monotonic clock
--   after each chunk, until the deadline is reached or the delay is cancelled.
--   Cancelling only moves a 'TimeoutPending' state to 'TimeoutCancelled'.
--   When the thread observes the cancellation, it cancels its current chunk
--   timeout, so that the timeout does not stay registered, and exits.
--
defaultRegisterDelayCancellable :: forall m timeout.
                                   ( MonadFork m
                                   , MonadMonotonicTime m
                                   , MonadSTM m
                                   )
                                => NonStandard.NewTimeout    m timeout
                                -> NonStandard.ReadTimeout   m timeout
                                -> NonStandard.CancelTimeout m timeout
                                -> NonStandard.AwaitTimeout  m timeout
                                -> DiffTime
                                -> m (STM m TimeoutState, m ())
defaultRegisterDelayCancellable newTimeout readTimeout cancelTimeout awaitTimeout d
    | d <= maxDelay = do
        t <- newTimeout (diffTimeToMicrosecondsAsInt d)
        return (readTimeout t, cancelTimeout t)

    | otherwise = do
        -- current time
        c <- getMonotonicTime
        -- timeout state
        v <- newTVarIO TimeoutPending
        tid <- forkIO $ go v c (d `addTime` c)
        labelThread tid "delay-thread"
        let cancel :: m ()
            cancel = atomically $ readTVar v >>= \case
              TimeoutCancelled -> return ()
              TimeoutFired     -> return ()
              TimeoutPending   -> writeTVar v TimeoutCancelled
        return (readTVar v, cancel)
  where
    -- the longest delay expressible as an 'Int' number of microseconds
    maxDelay :: DiffTime
    maxDelay = microsecondsAsIntToDiffTime maxBound

    -- Block until the delay is cancelled ('retry' while it is still pending).
    -- Only the delay thread itself writes 'TimeoutFired', and it exits right
    -- after doing so, so seeing it here means an invariant was broken.
    awaitCancelled :: TVar m TimeoutState -> STM m TimeoutState
    awaitCancelled v = readTVar v >>= \case
      a@TimeoutCancelled -> return a
      TimeoutFired       -> error "registerDelayCancellable: invariant violation!"
      TimeoutPending     -> retry

    -- @go v c u@: @c@ is the current monotonic time, @u@ the deadline.
    go :: TVar m TimeoutState
       -> Time
       -> Time
       -> m ()
    go v c u
      | u `diffTime` c >= maxDelay = do
          t <- newTimeout maxBound
          outcome <- atomically $
                  awaitCancelled v
                  `orElse`
                  -- the overall timeout is still pending when 't' fires
                  (awaitTimeout t $> TimeoutPending)
          case outcome of
            TimeoutPending   -> do
              c' <- getMonotonicTime
              go v c' u
            -- release the chunk timeout instead of leaving it registered
            TimeoutCancelled -> cancelTimeout t
            -- unreachable: 'awaitCancelled' errors on 'TimeoutFired'
            TimeoutFired     -> return ()

      | otherwise = do
          t <- newTimeout (diffTimeToMicrosecondsAsInt $ u `diffTime` c)
          outcome <- atomically $ do
            st <- awaitCancelled v
                  `orElse`
                  -- the overall timeout fires when 't' fires
                  (awaitTimeout t $> TimeoutFired)
            case st of
              TimeoutFired -> writeTVar v TimeoutFired
              _            -> return ()
            return st
          case outcome of
            -- release the last timeout instead of leaving it registered
            TimeoutCancelled -> cancelTimeout t
            _                -> return ()


-- | 'IO' timers which are safe on 32-bit systems (except 'timeout', see below).
--
-- 'registerDelay' is like 'GHC.Conc.registerDelay'.
--
-- * If the delay fits into an 'Int' number of microseconds, it uses the
--   `MonadTimer`'s `registerDelay` (for `IO` that is `GHC.Conc.registerDelay`).
-- * Otherwise it uses 'defaultRegisterDelay'.  That forks a thread which writes
--   to the returned 'TVar' once the delay has passed.
--
-- 'registerDelayCancellable' is 'defaultRegisterDelayCancellable' applied to
-- the "Control.Monad.Class.MonadTimer.NonStandard" timeout operations.
--
-- TODO: 'timeout' is not safe on 32-bit systems.  It converts the delay with
-- 'diffTimeToMicrosecondsAsInt' directly, so delays which do not fit into an
-- 'Int' number of microseconds are not handled.
instance MonadTimer IO where
  registerDelay :: DiffTime -> IO (TVar IO Bool)
  registerDelay d =
      if d > maxDelay
        then defaultRegisterDelay NonStandard.newTimeout
                                  NonStandard.awaitTimeout
                                  d
        else MonadTimer.registerDelay (diffTimeToMicrosecondsAsInt d)
    where
      -- the longest delay expressible as an 'Int' number of microseconds
      maxDelay :: DiffTime
      maxDelay = microsecondsAsIntToDiffTime maxBound

  registerDelayCancellable :: DiffTime -> IO (STM IO TimeoutState, IO ())
  registerDelayCancellable d =
      defaultRegisterDelayCancellable
        NonStandard.newTimeout
        NonStandard.readTimeout
        NonStandard.cancelTimeout
        NonStandard.awaitTimeout
        d

  timeout :: forall a. DiffTime -> IO a -> IO (Maybe a)
  timeout d = MonadTimer.timeout (diffTimeToMicrosecondsAsInt d)

-- | Timers in 'ReaderT' run the underlying monad's timers.
--
-- 'registerDelayCancellable' lifts both the 'STM' read action and the cancel
-- action.  'timeout' runs the reader computation under the underlying
-- 'timeout' with the same environment.
instance MonadTimer m => MonadTimer (ReaderT r m) where
  registerDelay :: DiffTime -> ReaderT r m (TVar (ReaderT r m) Bool)
  registerDelay d = lift (registerDelay d)

  registerDelayCancellable :: DiffTime
                           -> ReaderT r m (STM (ReaderT r m) TimeoutState, ReaderT r m ())
  registerDelayCancellable d =
      bimap lift lift <$> lift (registerDelayCancellable d)

  timeout :: forall a. DiffTime -> ReaderT r m a -> ReaderT r m (Maybe a)
  timeout d = mapReaderT (timeout d)