-- | Utilities for tracking scope: nested 'Invocation's
--
-- Intended for unqualified import.
module Debug.Provenance.Scope (
    -- * Thread-local scope
    Scope
  , scoped
  , getScope
    -- * Scope across threads
  , forkInheritScope
  , inheritScope
    -- *** Convenience re-exports
  , HasCallStack
  ) where

import Control.Concurrent
import Control.Monad
import Control.Monad.Catch
import Control.Monad.IO.Class
import Data.Bifunctor
import Data.IORef
import Data.Map.Strict (Map)
import Data.Map.Strict qualified as Map
import Data.Maybe (fromMaybe)
import Data.Tuple (swap)
import GHC.Stack
import System.IO.Unsafe (unsafePerformIO)

import Debug.Provenance.Internal

{-------------------------------------------------------------------------------
  Scope
-------------------------------------------------------------------------------}

-- | Thread-local scope
--
-- Most recent invocations are first in the list. Each thread's scope is kept
-- internally in a process-wide map keyed by 'ThreadId'; a thread that has not
-- entered any scope is treated as having the empty scope.
type Scope = [Invocation]

-- | Extend current scope
--
-- Runs the given action with a fresh 'Invocation' (recording the call to
-- 'scoped' itself) pushed onto the calling thread's scope. The invocation is
-- popped again when the action exits, whether it returns normally or is
-- aborted by an exception ('generalBracket' runs the release step on every
-- exit case).
scoped :: (HasCallStack, MonadMask m, MonadIO m) => m a -> m a
scoped k = do
    inv <- newInvocationFrom callSite -- the call to 'scoped'
    dropUnit <$> generalBracket (pushInvocation inv) release (\() -> k)
  where
    -- Undo the push, regardless of how the body exited.
    release () _exitCase = popInvocation

    -- The release step returns (); keep only the body's result.
    dropUnit (result, ()) = result

-- | Get current scope
--
-- Returns the calling thread's scope, most recent invocation first. A thread
-- that has never entered a scope gets the empty list. The scope itself is left
-- unchanged.
getScope :: MonadIO m => m Scope
getScope = modifyThreadLocalScope duplicate
  where
    -- Write the scope back untouched and also return it.
    duplicate current = (current, current)

{-------------------------------------------------------------------------------
  Scope across threads
-------------------------------------------------------------------------------}

-- | Inherit scope from a parent thread
--
-- This sets the scope of the current thread to that of the parent, as it is
-- at the moment 'inheritScope' runs; a parent without a recorded scope
-- contributes the empty scope. This should be done prior to growing the scope
-- of the child thread; 'inheritScope' will fail with an exception (via
-- 'fail' in 'IO') if the scope in the child thread is not empty, and in that
-- case leaves the child's scope untouched.
--
-- See also 'forkInheritScope'.
inheritScope :: MonadIO m => ThreadId -> m ()
inheritScope parent = liftIO $ do
    scopes <- readIORef globalScope
    let parentScope = fromMaybe [] (Map.lookup parent scopes)
    adopted <- modifyThreadLocalScope $ \case
      []  -> (parentScope, True)
      own -> (own, False)
    unless adopted $
      fail "inheritScope: child scope non-empty"

-- | Convenience combination of 'forkIO' and scope inheritance
--
-- The forked thread starts out with the scope that the /calling/ thread has
-- at the time of the call. The scope is captured before forking. Reading the
-- parent's scope from inside the child (as @'inheritScope' parent@ would)
-- races with the parent, which may already have left (or entered) 'scoped'
-- blocks by the time the child gets to run.
--
-- When the child finishes, normally or by exception, its scope is cleared.
-- Otherwise the inherited invocations would never be popped, and the child's
-- entry (including its 'ThreadId') would stay in the global scope map after
-- the thread has died.
forkInheritScope :: IO () -> IO ThreadId
forkInheritScope child = do
    parentScope <- getScope
    forkIO $
      (adopt parentScope >> child) `finally` modifyThreadLocalScope_ (const [])
  where
    -- A freshly forked thread has a new 'ThreadId' and so no scope of its own
    -- yet; there is nothing to check before adopting the parent's scope.
    adopt :: Scope -> IO ()
    adopt s = modifyThreadLocalScope_ (const s)

{-------------------------------------------------------------------------------
  Internal: scope manipulation
-------------------------------------------------------------------------------}

-- | Atomically update the calling thread's scope in 'globalScope'
--
-- @f@ receives the current scope (empty if the thread has no entry) and returns
-- the new scope together with an auxiliary result, which is what this function
-- returns. An empty new scope removes the thread's entry from the map.
--
-- NOTE(review): @f@ should not throw. It is evaluated by 'atomicModifyIORef''
-- after the new value has been installed, so a failing @f@ can leave a
-- failing thunk in 'globalScope'.
modifyThreadLocalScope :: forall m a. MonadIO m => (Scope -> (Scope, a)) -> m a
modifyThreadLocalScope f = liftIO $ do
    me <- myThreadId
    atomicModifyIORef' globalScope $ \scopes ->
      swap (Map.alterF step me scopes)
  where
    -- Run @f@ on this thread's entry and carry its result as the payload of
    -- the @(,) a@ functor used by 'Map.alterF'.
    step :: Maybe Scope -> (a, Maybe Scope)
    step entry = second dropIfEmpty (swap (f (fromMaybe [] entry)))

    -- Remove the entry from the map altogether if the scope is empty.
    dropIfEmpty :: Scope -> Maybe Scope
    dropIfEmpty scope
      | null scope = Nothing
      | otherwise  = Just scope

-- | Variant of 'modifyThreadLocalScope' for updates that produce no extra result
modifyThreadLocalScope_ :: MonadIO m => (Scope -> Scope) -> m ()
modifyThreadLocalScope_ f = modifyThreadLocalScope $ \scope -> (f scope, ())

-- | Make the given invocation the most recent entry of the calling thread's scope
pushInvocation :: MonadIO m => Invocation -> m ()
pushInvocation i = modifyThreadLocalScope_ $ \scope -> i : scope

-- | Drop the most recent invocation from the calling thread's scope
--
-- Raises an 'error' ("popInvocation: empty stack") if the scope is already
-- empty. The emptiness check happens /outside/ the atomic update, and an
-- empty scope is written back unchanged. Throwing from inside the function
-- given to 'atomicModifyIORef'' would leave a failing thunk in the shared
-- scope map and break scope tracking for every thread, not just this one.
popInvocation :: MonadIO m => m ()
popInvocation = do
    popped <- modifyThreadLocalScope $ \case
      []  -> ([], False)
      _:s -> (s, True)
    unless popped $ error "popInvocation: empty stack"

{-------------------------------------------------------------------------------
  Internal: globals
-------------------------------------------------------------------------------}

-- | Scopes of all threads, keyed by 'ThreadId'
--
-- Threads whose scope is empty have no entry (see the garbage collection of
-- empty scopes in 'modifyThreadLocalScope').
--
-- 'unsafePerformIO' is used to create a single process-wide 'IORef'. The
-- @NOINLINE@ pragma is essential: without it GHC could inline the definition
-- and allocate more than one 'IORef'.
--
-- NOTE(review): a 'ThreadId' key keeps its thread from being garbage
-- collected for as long as the entry exists.
globalScope :: IORef (Map ThreadId Scope)
{-# NOINLINE globalScope #-}
globalScope = unsafePerformIO (newIORef Map.empty)