{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE StrictData #-}
module Context.Internal
(
Store(Store, ref, key)
, State(State, stacks, def)
, NotFoundException(NotFoundException, threadId)
, withStore
, newStore
, use
, push
, pop
, mineMay
, mineMayOnDefault
, setDefault
, throwContextNotFound
, View(MkView)
, view
, viewMay
, toView
, PropagationStrategy(NoPropagation, LatestPropagation)
, Registry(Registry, ref)
, AnyStore(MkAnyStore)
, registry
, emptyRegistry
, withPropagator
, withRegisteredPropagator
, register
, unregister
, bug
) where
import Control.Concurrent (ThreadId)
import Control.Exception (Exception)
import Control.Monad ((<=<))
import Data.IORef (IORef)
import Data.Map.Strict (Map)
import Data.Unique (Unique)
import GHC.Generics (Generic)
import GHC.Stack (HasCallStack)
import Prelude
import System.IO.Unsafe (unsafePerformIO)
import qualified Control.Concurrent as Concurrent
import qualified Control.Exception as Exception
import qualified Data.IORef as IORef
import qualified Data.Map.Strict as Map
import qualified Data.Traversable as Traversable
import qualified Data.Unique as Unique
-- | A mutable context store.
--
-- The store's 'State' lives behind an 'IORef'; the 'Unique' key identifies
-- the store when it is inserted into (or removed from) a 'Registry'.
data Store ctx = Store
  { ref :: IORef (State ctx)
    -- ^ Per-thread context stacks plus the optional default context.
  , key :: Unique
    -- ^ Identity of this store; used as the map key in a 'Registry'.
  }
-- | The contents of a 'Store'.
data State ctx = State
  { stacks :: Map ThreadId [ctx]
    -- ^ One stack per thread that currently has contexts pushed; the head of
    -- each list is that thread's most recently pushed context.
  , def :: Maybe ctx
    -- ^ Fallback context consulted when the calling thread has no stack.
  }
-- | Exception carrying the thread that looked for a context and found none
-- (thrown, for example, by 'throwContextNotFound').
data NotFoundException = NotFoundException
  { threadId :: ThreadId
    -- ^ The thread whose lookup came up empty.
  }
  deriving stock (Eq, Show, Generic)
  deriving anyclass (Exception)
-- | Whether a newly created store takes part in context propagation.
--
-- 'LatestPropagation' stores are registered in the global 'registry' (and so
-- are visible to 'withPropagator'); 'NoPropagation' stores are not.
data PropagationStrategy = NoPropagation | LatestPropagation
-- | Atomically replace the store's default context. Per-thread stacks are
-- left untouched.
setDefault :: Store ctx -> ctx -> IO ()
setDefault Store { ref = stateRef } newDefault =
  IORef.atomicModifyIORef' stateRef $ \oldState ->
    (oldState { def = Just newDefault }, ())
-- | Throw a 'NotFoundException' tagged with the calling thread's id.
throwContextNotFound :: IO a
throwContextNotFound =
  Concurrent.myThreadId >>= \currentThread ->
    Exception.throwIO NotFoundException { threadId = currentThread }
-- | The calling thread's current context, falling back to the store's
-- default (unchanged) when the thread has none pushed.
mineMay :: Store ctx -> IO (Maybe ctx)
mineMay store = mineMayOnDefault id store
-- | Look up the calling thread's current context in a store.
--
-- If the thread has a stack in the store, its most recently pushed context is
-- returned. Otherwise the store's default is passed through @onDefault@, so
-- the caller decides how (or whether) the default counts.
mineMayOnDefault :: (Maybe ctx -> Maybe ctx) -> Store ctx -> IO (Maybe ctx)
mineMayOnDefault onDefault Store { ref = stateRef } = do
  me <- Concurrent.myThreadId
  State { stacks = perThread, def = fallback } <- IORef.readIORef stateRef
  pure $ maybe (onDefault fallback) newestOf (Map.lookup me perThread)
  where
    -- push/pop delete a thread's entry instead of leaving an empty list, so
    -- the [] case is unreachable.
    newestOf = \case
      [] -> bug "mineMayOnDefault"
      context : _ -> Just context
-- | Run an action with @context@ pushed onto the calling thread's stack in
-- the store, popping it again afterwards even if the action throws.
use :: Store ctx -> ctx -> IO a -> IO a
use store context action =
  Exception.bracket_ acquire release action
  where
    acquire = push store context
    release = pop store
-- | Create a fresh 'Store', run the given action with it, then clean up.
--
-- With 'LatestPropagation' the store is registered in the global 'registry'
-- on creation and unregistered when the action finishes, whether normally or
-- via an exception. With 'NoPropagation' there is nothing to clean up.
--
-- Creation happens in 'Exception.bracket''s masked acquire phase. As a
-- result, an asynchronous exception cannot arrive after the store has been
-- registered but before the cleanup handler is installed. Such an exception
-- would otherwise leave a stale entry in the global registry.
withStore
  :: PropagationStrategy
  -> Maybe ctx
  -> (Store ctx -> IO a)
  -> IO a
withStore propagationStrategy mContext =
  Exception.bracket (newStore propagationStrategy mContext) release
  where
    release store =
      case propagationStrategy of
        NoPropagation -> pure ()
        LatestPropagation -> unregister registry store
-- | Allocate a new 'Store' with no per-thread contexts and the given optional
-- default context.
--
-- With 'LatestPropagation' the store is also registered in the global
-- 'registry'. Nothing in this function unregisters it.
newStore
  :: PropagationStrategy
  -> Maybe ctx
  -> IO (Store ctx)
newStore propagationStrategy mDefault = do
  storeKey <- Unique.newUnique
  stateRef <- IORef.newIORef State { stacks = Map.empty, def = mDefault }
  let store = Store { ref = stateRef, key = storeKey }
  registerIfPropagating store
  pure store
  where
    registerIfPropagating store =
      case propagationStrategy of
        NoPropagation -> pure ()
        LatestPropagation -> register registry store
-- | Atomically push a context onto the calling thread's stack in the store,
-- making it the thread's current context until the matching 'pop'.
push :: Store ctx -> ctx -> IO ()
push Store { ref = stateRef } context = do
  me <- Concurrent.myThreadId
  IORef.atomicModifyIORef' stateRef $ \state@State { stacks } ->
    (state { stacks = Map.alter (Just . consOnto) me stacks }, ())
  where
    -- Start a single-element stack, or cons onto the thread's existing one.
    consOnto = maybe [context] (context :)
-- | Atomically drop the calling thread's most recently pushed context.
--
-- When the last context is removed, the thread's entry is deleted from the
-- map entirely, so no empty stacks are ever stored. Popping a thread with no
-- stack is an invariant violation and reported via 'bug'.
pop :: Store ctx -> IO ()
pop Store { ref = stateRef } = do
  me <- Concurrent.myThreadId
  IORef.atomicModifyIORef' stateRef (dropTop me)
  where
    dropTop me state@State { stacks } =
      case Map.lookup me stacks of
        Just (_ : rest@(_ : _)) -> (state { stacks = Map.insert me rest stacks }, ())
        Just [_] -> (state { stacks = Map.delete me stacks }, ())
        Just [] -> bug "pop-2"
        Nothing -> bug "pop-1"
-- | A read-only projection of a 'Store': a store of some hidden context type
-- together with a function mapping that context to @ctx@.
data View ctx where
  MkView :: (ctx' -> ctx) -> Store ctx' -> View ctx

-- | Mapping over a view post-composes the projection; the store is shared.
instance Functor View where
  fmap g = \case
    MkView project store -> MkView (\x -> g (project x)) store
-- | Read the current context through a view, throwing 'NotFoundException'
-- when there is none.
view :: View ctx -> IO ctx
view v = viewMay v >>= maybe throwContextNotFound pure
-- | Read the current context through a view, if any, applying the view's
-- projection to the underlying store's context.
viewMay :: View ctx -> IO (Maybe ctx)
viewMay (MkView project store) = fmap project <$> mineMay store
-- | View a store's context as-is (identity projection).
toView :: Store ctx -> View ctx
toView store = MkView (\context -> context) store
-- | A 'Store' with its context type hidden, so that stores of different
-- context types can sit together in one 'Registry'.
data AnyStore = forall ctx. MkAnyStore (Store ctx)
-- | A mutable collection of stores, keyed by each store's 'Unique' key.
newtype Registry = Registry
  { ref :: IORef (Map Unique AnyStore)
    -- ^ Currently registered stores.
  }
-- | The process-wide registry used by 'withPropagator'; 'newStore' adds
-- 'LatestPropagation' stores to it.
--
-- It is created once via 'unsafePerformIO' so it can be a top-level value.
-- The NOINLINE pragma is essential for soundness. Without it, GHC may inline
-- the 'unsafePerformIO' call at use sites and create several distinct
-- registries.
registry :: Registry
registry = unsafePerformIO emptyRegistry
{-# NOINLINE registry #-}
-- | Allocate a registry with no stores in it.
emptyRegistry :: IO Registry
emptyRegistry = Registry <$> IORef.newIORef Map.empty
-- | 'withRegisteredPropagator' specialised to the global 'registry'.
withPropagator :: ((IO a -> IO a) -> IO b) -> IO b
withPropagator continuation = withRegisteredPropagator registry continuation
-- | Capture the calling thread's current contexts from every store in the
-- registry, and pass the continuation a "propagator".
--
-- The propagator wraps an action so that each captured context is pushed
-- (via 'use') for the duration of that action. This happens on whichever
-- thread runs it; presumably it is meant for actions run on other threads.
-- Stores in which the calling thread has no pushed context contribute
-- nothing, and store defaults are ignored (@const Nothing@). The registry is
-- read once, when this function is called.
withRegisteredPropagator :: Registry -> ((IO a -> IO a) -> IO b) -> IO b
withRegisteredPropagator Registry { ref = registryRef } f = do
  registered <- IORef.readIORef registryRef
  wrappers <- Traversable.for (Map.elems registered) wrapperFor
  f (foldr (.) id wrappers)
  where
    wrapperFor (MkAnyStore store) =
      maybe id (use store) <$> mineMayOnDefault (const Nothing) store
-- | Atomically add a store to the registry under its 'Unique' key, replacing
-- any entry already stored under that key.
register :: Registry -> Store ctx -> IO ()
register Registry { ref = registryRef } store =
  IORef.atomicModifyIORef' registryRef $ \registered ->
    (Map.insert (key store) (MkAnyStore store) registered, ())
-- | Atomically remove a store (by its 'Unique' key) from the registry. A
-- store that is not registered is ignored.
unregister :: Registry -> Store ctx -> IO ()
unregister Registry { ref = registryRef } Store { key = storeKey } =
  IORef.atomicModifyIORef' registryRef $ \registered ->
    (Map.delete storeKey registered, ())
-- | Abort via 'error' on a state the library's invariants rule out.
--
-- @prefix@ names the failing site (e.g. @"pop-1"@) and is embedded in the
-- message, after @"Context."@.
bug :: HasCallStack => String -> a
bug prefix =
  error $ concat
    [ "Context."
    , prefix
    , ": Impossible! (if you see this message, please "
    , "report it as a bug at https://github.com/jship/context)"
    ]