{-# LANGUAGE NamedFieldPuns #-}

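-- | Reading and writing the AWS X-Ray tracing header, @X-Amzn-TraceId@.
--
-- <https://docs.aws.amazon.com/xray/latest/devguide/xray-concepts.html#xray-concepts-tracingheader>
--
-- Example:
--
-- @Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1@
--
-- General form (square brackets mark optional parts):
--
-- @Root=1-{epoch}-{unique}[;Parent={spanId}][;Sampled={1|0}][;meta=attr...]@
--
-- Note that although the format allows @Parent@ to be omitted, parsing in
-- this module requires it, because a span id is needed to build the span
-- context.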
--
module OpenTelemetry.AWSXRay.TraceInfo
  ( TraceInfo(..)
  , fromXRayHeader
  , toXRayHeader
  ) where

import Prelude

import Control.Error.Util (note)
import Data.Bifunctor (first)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BS8
import qualified OpenTelemetry.AWSXRay.Baggage as Baggage
import OpenTelemetry.Baggage (Baggage)
import OpenTelemetry.Trace (SpanContext(..))
import OpenTelemetry.Trace.Core
  (defaultTraceFlags, isSampled, setSampled, unsetSampled)
import OpenTelemetry.Trace.Id
  ( Base(..)
  , baseEncodedToSpanId
  , baseEncodedToTraceId
  , spanIdBaseEncodedByteString
  , traceIdBaseEncodedByteString
  )
import qualified OpenTelemetry.Trace.TraceState as TS

-- | Everything this module reads from, or writes to, an @X-Amzn-TraceId@
-- header.
data TraceInfo = TraceInfo
  { spanContext :: SpanContext
  -- ^ Trace id (from @Root@), span id (from @Parent@) and sampling flag
  -- (from @Sampled@).
  , baggage :: Maybe Baggage
  -- ^ Additional header key-value pairs, decoded and encoded by
  -- "OpenTelemetry.AWSXRay.Baggage".
  }
  deriving stock (Show)

-- | Parse the value of an @X-Amzn-TraceId@ header.
--
-- Required keys:
--
-- * @Root@ of the form @1-{epoch}-{unique}@. The epoch is first passed
--   through @bsLeftPad 8 '0'@, because AWS may trim its leading zeros.
--   The result is then decoded as a base-16 trace id.
--
-- * @Parent@, decoded as a base-16 span id.
--
-- The span is marked sampled only when @Sampled=1@. Any other value, or a
-- missing @Sampled@ key, leaves it unsampled. The resulting 'SpanContext'
-- is flagged as remote and has an empty trace state. Every key-value pair
-- is also handed to 'Baggage.decode'.
--
-- Failures are reported as human-readable 'String's, checked in this
-- order: malformed key-value list, missing @Root@, invalid @Root@, missing
-- @Parent@, invalid @Parent@.
fromXRayHeader :: ByteString -> Either String TraceInfo
fromXRayHeader bs = do
  pairs <- bsToKeyValues bs
  rootValue <- note "Root not present" (lookup "Root" pairs)
  tid <- decodeRoot rootValue
  parentValue <- note "Parent not present" (lookup "Parent" pairs)
  sid <-
    prefix "Parent is not a valid SpanId" (baseEncodedToSpanId Base16 parentValue)
  pure TraceInfo
    { spanContext = SpanContext
        { traceFlags = flagsFor (lookup "Sampled" pairs)
        , traceId = tid
        , spanId = sid
        , isRemote = True
        , traceState = TS.empty
        }
    , baggage = Baggage.decode pairs
    }
 where
  -- Decode a Root value of the shape @1-{epoch}-{unique}@ into a trace id.
  decodeRoot rootValue = case bsSplitOn '-' rootValue of
    ["1", epoch, unique] ->
      let
        -- AWS may trim leading zeros from epoch; they must be restored for
        -- the result to be a valid OTel trace id.
        epochUnique = bsLeftPad 8 '0' epoch <> unique
        errorPrefix =
          "Root epoch+unique ("
            <> show epochUnique
            <> ") is not a valid TraceId"
      in prefix errorPrefix (baseEncodedToTraceId Base16 epochUnique)
    _ -> Left "Splitting on - did not produce exactly 3 parts"

  -- Only an explicit @Sampled=1@ turns sampling on.
  flagsFor sampledValue
    | sampledValue == Just "1" = setSampled defaultTraceFlags
    | otherwise = unsetSampled defaultTraceFlags

-- | Render a 'TraceInfo' as an @X-Amzn-TraceId@ header value.
--
-- @Root@, @Parent@ and @Sampled@ are emitted first, in that order. They are
-- followed by whatever pairs 'Baggage.encode' produces when baggage is
-- present. The base-16 trace id is split after its first 8 characters to
-- form the @epoch@ and @unique@ parts of @Root@.
toXRayHeader :: TraceInfo -> ByteString
toXRayHeader info = bsFromKeyValues (standardPairs <> baggagePairs)
 where
  ctx = spanContext info

  encodedTraceId = traceIdBaseEncodedByteString Base16 (traceId ctx)
  (epochPart, uniquePart) = BS.splitAt 8 encodedTraceId

  sampledFlag
    | isSampled (traceFlags ctx) = "1"
    | otherwise = "0"

  standardPairs =
    [ ("Root", mconcat ["1-", epochPart, "-", uniquePart])
    , ("Parent", spanIdBaseEncodedByteString Base16 (spanId ctx))
    , ("Sampled", sampledFlag)
    ]

  -- 'Nothing' contributes no pairs.
  baggagePairs = foldMap Baggage.encode (baggage info)

-- | Add context to a failure. An error @e@ becomes @p ++ ": " ++ e@, and a
-- successful result passes through untouched.
prefix :: String -> Either String a -> Either String a
prefix p = either (Left . ((p <> ": ") <>)) Right

-- | Split a header value into its @;@-separated @key=value@ pieces.
--
-- Each piece is split on its /first/ @=@ only. Values that themselves
-- contain @=@ are therefore kept intact, instead of having those characters
-- silently dropped.
--
-- * A piece with no @=@ at all yields the whole piece as the key and an
--   empty value.
-- * An empty piece (for example, the one produced by a trailing @;@) is
--   rejected.
-- * An empty input yields an empty list.
bsToKeyValues :: ByteString -> Either String [(ByteString, ByteString)]
bsToKeyValues = traverse toPair . BS8.split ';'
 where
  toPair :: ByteString -> Either String (ByteString, ByteString)
  toPair piece
    | BS.null piece = Left "No = found in key-value piece"
    | otherwise =
        let (key, rest) = BS8.break (== '=') piece
        -- rest is either empty (no '=') or starts with the separator.
        in Right (key, BS.drop 1 rest)

-- | Join key-value pairs into a header value. Each pair becomes
-- @key=value@, and the pieces are separated by @;@.
bsFromKeyValues :: [(ByteString, ByteString)] -> ByteString
bsFromKeyValues pairs =
  BS.intercalate ";" [key <> "=" <> value | (key, value) <- pairs]

-- | Left-pad a 'ByteString' with character @c@ until it is @n@ bytes long.
--
-- Inputs that are already @n@ bytes or longer are returned unchanged.
-- The header parser uses this to restore leading zeros that AWS may trim
-- from the X-Ray epoch.
bsLeftPad :: Int -> Char -> ByteString -> ByteString
bsLeftPad n c bs
  | missing > 0 = BS8.replicate missing c <> bs
  | otherwise = bs
 where
  -- How many characters are needed to reach the target width. This was
  -- previously computed the wrong way round, so short input was never
  -- padded.
  missing = n - BS.length bs

-- | Split a 'ByteString' on every occurrence of a delimiter character.
--
-- Adjacent delimiters produce empty pieces, and an empty input produces
-- an empty list.
bsSplitOn :: Char -> ByteString -> [ByteString]
bsSplitOn = BS8.split