{-# LANGUAGE NamedFieldPuns #-}

module Freckle.App.OpenTelemetry.ThreadContext
  ( withTraceContext
  ) where

import Freckle.App.Prelude

import Blammo.Logging (MonadMask, withThreadContext)
import Data.Aeson ((.=))
import qualified Data.Aeson.Key as Key
import Data.Aeson.Types (Pair)
import Freckle.App.OpenTelemetry (getCurrentSpanContext)
import OpenTelemetry.Trace.Core (SpanContext (..))
import OpenTelemetry.Trace.Id
  ( Base (..)
  , spanIdBaseEncodedText
  , traceIdBaseEncodedText
  )
import OpenTelemetry.Trace.TraceState (Key (..), TraceState (..), Value (..))

-- | Add tracing context values to the logging context
--
-- Looks up the current span context. When there is one, its trace id, span id
-- and trace state (rendered by 'toThreadContext') are installed with
-- 'withThreadContext' for the duration of the given action. When there is no
-- current span, the action runs unchanged.
--
-- Trace and span ids are encoded to 'Base16' (i.e. hex).
withTraceContext :: (MonadIO m, MonadMask m) => m a -> m a
withTraceContext action =
  getCurrentSpanContext >>= maybe action runWithSpan
 where
  -- Run the wrapped action with this span's identifiers in the log context
  runWithSpan spanContext =
    withThreadContext (toThreadContext spanContext) action

-- | Render a 'SpanContext' as logging thread-context pairs
--
-- Produces exactly three entries:
--
-- * @trace_id@: the trace id, 'Base16'-encoded
-- * @span_id@: the span id, 'Base16'-encoded
-- * @trace_state@: the trace-state entries, each turned into a 'Pair' whose
--   key is the trace-state key text and whose value is the trace-state value
--   text
--
-- NOTE(review): @trace_state@ is serialized through the 'ToJSON' instance of
-- @['Pair']@, which presumably yields a JSON array of @[key, value]@ arrays
-- rather than a JSON object. Confirm this is the intended log shape.
toThreadContext :: SpanContext -> [Pair]
toThreadContext SpanContext {traceId, spanId, traceState} =
  [ "trace_id" .= traceIdBaseEncodedText Base16 traceId
  , "span_id" .= spanIdBaseEncodedText Base16 spanId
  , "trace_state" .= stateEntries
  ]
 where
  -- One aeson 'Pair' per trace-state entry, preserving the original order
  stateEntries :: [Pair]
  stateEntries =
    [ Key.fromText (unKey k) .= unValue v
    | (k, v) <- unTraceState traceState
    ]

-- | Unwrap a 'TraceState' into its list of key/value entries
unTraceState :: TraceState -> [(Key, Value)]
unTraceState ts = case ts of
  TraceState entries -> entries

-- | Extract the underlying text of a trace-state 'Key'
unKey :: Key -> Text
unKey = \(Key name) -> name

-- | Extract the underlying text of a trace-state 'Value'
unValue :: Value -> Text
unValue v = case v of
  Value payload -> payload