{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}

module OpenTelemetry.Instrumentation.Wai (
  newOpenTelemetryWaiMiddleware,
  newOpenTelemetryWaiMiddleware',
  requestContext,
) where

import Control.Exception (bracket)
import Control.Monad
import Data.IP (fromHostAddress, fromHostAddress6)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Data.Text.Encoding.Error (lenientDecode)
import qualified Data.Vault.Lazy as Vault
import Network.HTTP.Types
import Network.Socket
import Network.Wai
import OpenTelemetry.Attributes (lookupAttribute)
import qualified OpenTelemetry.Context as Context
import OpenTelemetry.Context.ThreadLocal
import OpenTelemetry.Propagator
import OpenTelemetry.Trace.Core
import System.IO.Unsafe


-- | Build the tracing middleware from the process-global 'TracerProvider'.
--
-- Equivalent to @'getGlobalTracerProvider' >>= 'newOpenTelemetryWaiMiddleware''@.
newOpenTelemetryWaiMiddleware :: IO Middleware
newOpenTelemetryWaiMiddleware = do
  globalProvider <- getGlobalTracerProvider
  newOpenTelemetryWaiMiddleware' globalProvider


-- | Build a WAI 'Middleware' that traces every request using the given
-- 'TracerProvider'.
--
-- For each request the middleware:
--
-- * extracts a parent context from the request headers with the provider's
--   propagators and attaches it to the current thread while the request is
--   handled, then restores whatever context was attached before;
-- * runs the application inside a 'Server' span named after the raw request
--   path, recording HTTP and peer-address attributes on it;
-- * stores the span-bearing context in the request 'vault' (retrievable via
--   'requestContext');
-- * on response, injects propagation headers, renames the span to the
--   @http.route@ attribute if one was set, records @http.status_code@,
--   marks 5xx responses with an 'Error' status and ends the span.
--
-- Request bytes (path, query, method, user agent) are decoded leniently:
-- invalid UTF-8 becomes U+FFFD instead of throwing, so a malformed request
-- cannot make the middleware itself fail.
newOpenTelemetryWaiMiddleware'
  :: TracerProvider
  -> IO Middleware
newOpenTelemetryWaiMiddleware' tp = do
  let waiTracer =
        makeTracer
          tp
          "opentelemetry-instrumentation-wai"
          (TracerOptions Nothing)
  pure $ middleware waiTracer
  where
    -- Request bytes are untrusted: replace invalid UTF-8 rather than
    -- throwing (the strict 'T.decodeUtf8' raises on malformed input).
    decodeLenient = T.decodeUtf8With lenientDecode

    middleware :: Tracer -> Middleware
    middleware tracer app req sendResp =
      -- Scope the extracted parent to this request. Previously it was
      -- attached and never removed. If the server reused the thread (e.g.
      -- for the next request on a keep-alive connection), a request without
      -- propagation headers inherited the earlier request's parent.
      bracket attachRemoteParent restorePrevious $ \_ ->
        inSpan' tracer path_ (defaultSpanArguments {kind = Server}) $ \requestSpan -> do
          ctxt <- getContext
          addAttributes
            requestSpan
            [ ("http.method", toAttribute $ decodeLenient $ requestMethod req)
            , ("http.target", toAttribute $ decodeLenient (rawPathInfo req <> rawQueryString req))
            , ( "http.flavor"
              , toAttribute $ case httpVersion req of
                  HttpVersion major minor -> T.pack (show major <> "." <> show minor)
              )
            , ( "http.user_agent"
              , toAttribute $ maybe "" decodeLenient (lookup hUserAgent $ requestHeaders req)
              )
            , -- TODO HTTP/3 will require detecting this dynamically
              ("net.transport", toAttribute ("ip_tcp" :: T.Text))
            ]
          -- TODO this is warp dependent, probably.
          -- , ( "net.host.ip")
          -- , ( "net.host.port")
          -- , ( "net.host.name")
          addAttributes requestSpan $ case remoteHost req of
            SockAddrInet peerPort addr ->
              [ ("net.peer.port", toAttribute (fromIntegral peerPort :: Int))
              , ("net.peer.ip", toAttribute $ T.pack $ show $ fromHostAddress addr)
              ]
            SockAddrInet6 peerPort _ addr _ ->
              [ ("net.peer.port", toAttribute (fromIntegral peerPort :: Int))
              , ("net.peer.ip", toAttribute $ T.pack $ show $ fromHostAddress6 addr)
              ]
            SockAddrUnix sockPath ->
              [("net.peer.name", toAttribute $ T.pack sockPath)]
          -- Expose the span-bearing context to downstream handlers.
          let req' = req {vault = Vault.insert contextKey ctxt (vault req)}
          app req' $ \resp -> do
            ctxt' <- getContext
            hs <- inject propagator (Context.insertSpan requestSpan ctxt') []
            let resp' = mapResponseHeaders (hs ++) resp
            -- If something downstream recorded an @http.route@ attribute,
            -- use it as the span name instead of the raw path.
            attrs <- spanGetAttributes requestSpan
            forM_ (lookupAttribute attrs "http.route") $ \case
              AttributeValue (TextAttribute route) -> updateName requestSpan route
              _ -> pure ()
            let httpStatus = statusCode (responseStatus resp)
            addAttributes requestSpan [("http.status_code", toAttribute httpStatus)]
            when (httpStatus >= 500) $
              setStatus requestSpan (Error "")
            respReceived <- sendResp resp'
            ts <- getTimestamp
            endSpan requestSpan (Just ts)
            pure respReceived
      where
        propagator = getTracerProviderPropagators $ getTracerTracerProvider tracer
        path_ = decodeLenient $ rawPathInfo req
        -- Extract the remote parent from the request headers on top of the
        -- thread's current context and attach it. Yields the previously
        -- attached context (if any) so it can be restored afterwards.
        attachRemoteParent = do
          ctx <- getContext
          ctxt <- extract propagator (requestHeaders req) ctx
          attachContext ctxt
        -- Put the thread back exactly as it was: re-attach the prior context,
        -- or detach entirely if none was attached before this request.
        restorePrevious = \case
          Just previous -> void $ attachContext previous
          Nothing -> void detachContext


-- | Vault key under which the middleware stores each request's tracing
-- 'Context.Context'; read it back with 'requestContext'.
--
-- 'Vault.newKey' only runs in 'IO', so one process-wide key is created via
-- 'unsafePerformIO'. The NOINLINE pragma is what keeps this sound: without
-- it GHC could inline the definition and mint a distinct key at each use
-- site, making values stored under one key invisible to lookups via another.
contextKey :: Vault.Key Context.Context
contextKey = unsafePerformIO Vault.newKey
{-# NOINLINE contextKey #-}


-- | Fetch the tracing context stored in this request's vault under
-- 'contextKey'. Returns 'Nothing' if no context was stored, e.g. when the
-- request did not pass through the OpenTelemetry middleware.
requestContext :: Request -> Maybe Context.Context
requestContext request = Vault.lookup contextKey (vault request)