-- | A module with time measuring primitives that might not work in all the
-- monads that their types accept.
--
-- Measures are collected only if the environment variable
-- @DEBUG_TIMESTATS_ENABLE@ is set to any value ahead of invoking any function
-- in this module.
--
module Debug.TimeStats.Unsafe
  ( -- * Measuring
    unsafeMeasureM
  ) where

import Debug.TimeStats
         ( TimeStats(..)
         , TimeStatsRef
         , enabled
         , lookupTimeStatsRef
         , updateTimeStatsRef
         )
import GHC.Clock (getMonotonicTimeNSec)
import System.IO.Unsafe (unsafePerformIO)

-- | Measure the time it takes to run the action.
--
-- Add the time to the stats of the given label and increase its count by one.
--
-- Like 'Debug.TimeStats.measureM', this function keeps the stats in a globally
-- available store in order to minimize the changes necessary when
-- instrumenting a program. Otherwise a reference to the store would need to be
-- passed to every function that might invoke functions that need this
-- reference.
--
-- A time measure isn't collected if the given action fails with an exception.
-- This is a deliberate choice to demand less of the monad in which measures are
-- taken.
--
-- Time measures aren't collected either if the environment variable
-- @DEBUG_TIMESTATS_ENABLE@ isn't set the first time this function is
-- evaluated.
--
-- This function relies on a hack to perform IO in any monad, which does not
-- always work. In particular, we can expect it to fail in monads where
--
-- > (m >>= \_ -> undefined) == undefined -- for some computation m
--
-- An example of such a monad is the list monad
--
-- > ([()] >>= \_ -> undefined) == undefined
--
-- Another example is @Control.Monad.Free.Free IO@.
--
-- > (Control.Monad.Free.Pure () >>= \_ -> undefined) == undefined
--
-- But it seems to work on @IO@ or @ReaderT IO@.
--
-- > seq (print () >>= \_ -> undefined) () == ()
--
-- Also, in monads that run the continuation of a bind multiple times, only
-- the time of the first run might be accounted for.
--
{-# INLINE unsafeMeasureM #-}
unsafeMeasureM :: Monad m => String -> m a -> m a
unsafeMeasureM label
    -- See the documentation of 'enabled'. When measuring is disabled the
    -- action is returned untouched, so no IO is interspersed at all.
  | enabled =
      -- @ref@ is the reference to the stats associated to the label.
      -- It is bound outside the lambda so that it is shared by every action
      -- passed to the function returned by @unsafeMeasureM label@.
      -- See note [Looking up stats with unsafePerformIO]
      let ref = unsafePerformIO $ lookupTimeStatsRef label
       in \action -> measureMWith ref action
  | otherwise = id

-- | Measure the time it takes to run the given action and update with it
-- the given reference to time stats.
--
-- The difference between two readings of 'getMonotonicTimeNSec' (taken
-- before and after the action) is added to 'timeStat', and 'countStat' is
-- incremented by one. If the action does not return normally, the second
-- reading and the update are never reached, so nothing is recorded.
measureMWith :: Monad m => TimeStatsRef -> m a -> m a
measureMWith tref m = do
    t0 <- intersperseIOinM getMonotonicTimeNSec
    a <- m
    -- Read the clock again and record the measure in one interspersed IO
    -- action, so the update only happens after @m@ has produced its result.
    intersperseIOinM $ do
      tf <- getMonotonicTimeNSec
      updateTimeStatsRef tref $ \st ->
        st
          { timeStat = (tf - t0) + timeStat st
          , countStat = 1 + countStat st
          }
    pure a

---------------------
-- intersperseIOinM
---------------------

-- | Hack to intersperse IO actions into any monad.
--
-- The IO action is run through 'unsafePerformIO' when the continuation of a
-- bind in the target monad is evaluated. It therefore only runs if that monad
-- actually evaluates the continuation (see the caveats on 'unsafeMeasureM').
intersperseIOinM :: Monad m => IO a -> m a
intersperseIOinM io = do
    -- The fictitious state is only used to force @unsafePerformIO@
    -- to run @io@ every time @intersperseIOinM io@ is evaluated.
    s <- getStateM
    -- Matching on the pair forces the @unsafePerformIO@ call, so @io@ has
    -- run by the time its result is returned in the monad.
    case unsafePerformIO $ (,) s <$> io of
      (_, r) -> pure r
  where
    -- We mark this function as NOINLINE to ensure the compiler cannot reason
    -- by unfolding that two calls of @getStateM@ yield the same value.
    {-# NOINLINE getStateM #-}
    getStateM = pure True