{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}

module OpenTelemetry.Instrumentation.Wai (
  newOpenTelemetryWaiMiddleware,
  newOpenTelemetryWaiMiddleware',
  requestContext,
) where

import Control.Exception (bracket)
import Control.Monad
import Data.IP (fromHostAddress, fromHostAddress6)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Data.Text.Encoding.Error (lenientDecode)
import qualified Data.Vault.Lazy as Vault
import GHC.Stack (HasCallStack, callStack, popCallStack)
import Network.HTTP.Types
import Network.Socket
import Network.Wai
import OpenTelemetry.Attributes (lookupAttribute)
import qualified OpenTelemetry.Context as Context
import OpenTelemetry.Context.ThreadLocal
import OpenTelemetry.Propagator
import OpenTelemetry.Trace.Core
import System.IO.Unsafe


-- | Build the OpenTelemetry WAI middleware from the process-wide global
-- 'TracerProvider'.
--
-- The global provider is read once, when this action runs; the resulting
-- 'Middleware' keeps using that provider for every request. Use
-- 'newOpenTelemetryWaiMiddleware'' to pass a provider explicitly.
newOpenTelemetryWaiMiddleware :: (HasCallStack) => IO Middleware
newOpenTelemetryWaiMiddleware =
  fmap newOpenTelemetryWaiMiddleware' getGlobalTracerProvider

-- | Build a WAI 'Middleware' that traces every request with a tracer obtained
-- from the given 'TracerProvider'.
--
-- Per request, the middleware:
--
-- * extracts a parent context from the request headers using the provider's
--   propagators and attaches it to the current thread for the duration of the
--   request, restoring the thread's previous context afterwards;
-- * runs the application inside a 'Server' span named after the raw request
--   path, recording method, target, HTTP version, user agent, transport and
--   peer-address attributes;
-- * stores the thread context seen inside the span in the request 'vault'
--   (retrievable with 'requestContext');
-- * when the application responds: injects the span context into the
--   response headers, renames the span from an @http.route@ text attribute if
--   one is present on it, records @http.status_code@, marks status codes
--   >= 500 with an 'Error' status, sends the response and ends the span.
newOpenTelemetryWaiMiddleware'
  :: (HasCallStack)
  => TracerProvider
  -> Middleware
newOpenTelemetryWaiMiddleware' tp =
  let waiTracer =
        makeTracer
          tp
          "opentelemetry-instrumentation-wai"
          (TracerOptions Nothing)
   in middleware waiTracer
  where
    -- Call-site attributes attached to every request span. Left without a
    -- type signature on purpose: an argument-less binding is not generalised
    -- over HasCallStack (monomorphism restriction), so the call stack is the
    -- one supplied to this function's own HasCallStack constraint.
    usefulCallsite = callerAttributes

    -- Request bytes come straight off the wire and need not be valid UTF-8.
    -- Strict 'T.decodeUtf8' throws on malformed input, which would abort the
    -- request from inside the instrumentation; substitute U+FFFD instead.
    decodeLenient = T.decodeUtf8With lenientDecode

    middleware :: Tracer -> Middleware
    middleware tracer app req sendResp = do
      let propagator = getTracerProviderPropagators $ getTracerTracerProvider tracer
          -- Attach the context extracted from the request headers to this
          -- thread; yields the context that was attached before (if any).
          attachParentContext = do
            ctx <- getContext
            ctxt <- extract propagator (requestHeaders req) ctx
            attachContext ctxt
          -- Undo 'attachParentContext'. Without this the extracted parent
          -- stays attached and is picked up as the starting context of the
          -- next request served on the same thread.
          restoreContext previous =
            maybe (void detachContext) (void . attachContext) previous
          path_ = decodeLenient $ rawPathInfo req
          spanArgs =
            defaultSpanArguments {kind = Server, attributes = usefulCallsite}
      -- peer = remoteHost req
      bracket attachParentContext restoreContext $ \_ ->
        inSpan'' tracer path_ spanArgs $ \requestSpan -> do
          ctxt <- getContext
          addAttributes
            requestSpan
            [ ("http.method", toAttribute $ decodeLenient $ requestMethod req)
            , -- , ( "http.url",
              --     toAttribute $
              --     T.decodeUtf8
              --     ((if secure req then "https://" else "http://") <> host req <> ":" <> B.pack (show $ port req) <> path req <> queryString req)
              --   )
              ( "http.target"
              , toAttribute $ decodeLenient (rawPathInfo req <> rawQueryString req)
              )
            , -- , ( "http.host", toAttribute $ T.decodeUtf8 $ host req)
              -- , ( "http.scheme", toAttribute $ TextAttribute $ if secure req then "https" else "http")
              ( "http.flavor"
              , toAttribute $ case httpVersion req of
                  HttpVersion major minor -> T.pack (show major <> "." <> show minor)
              )
            , ( "http.user_agent"
              , toAttribute $ maybe "" decodeLenient (lookup hUserAgent $ requestHeaders req)
              )
            , -- TODO HTTP/3 will require detecting this dynamically
              ("net.transport", toAttribute ("ip_tcp" :: T.Text))
            ]

          -- TODO this is warp dependent, probably.
          -- , ( "net.host.ip")
          -- , ( "net.host.port")
          -- , ( "net.host.name")
          addAttributes requestSpan $ case remoteHost req of
            SockAddrInet port addr ->
              [ ("net.peer.port", toAttribute (fromIntegral port :: Int))
              , ("net.peer.ip", toAttribute $ T.pack $ show $ fromHostAddress addr)
              ]
            SockAddrInet6 port _ addr _ ->
              [ ("net.peer.port", toAttribute (fromIntegral port :: Int))
              , ("net.peer.ip", toAttribute $ T.pack $ show $ fromHostAddress6 addr)
              ]
            SockAddrUnix path ->
              [ ("net.peer.name", toAttribute $ T.pack path)
              ]
          -- Hand the in-span context to the application through the vault.
          let req' = req {vault = Vault.insert contextKey ctxt (vault req)}
          app req' $ \resp -> do
            ctxt' <- getContext
            hs <- inject propagator (Context.insertSpan requestSpan ctxt') []
            let resp' = mapResponseHeaders (hs ++) resp
                status = statusCode (responseStatus resp)
            -- If an "http.route" text attribute is on the span by the time
            -- the application responds, use it as the span name instead of
            -- the raw path.
            attrs <- spanGetAttributes requestSpan
            forM_ (lookupAttribute attrs "http.route") $ \case
              AttributeValue (TextAttribute route) -> updateName requestSpan route
              _ -> pure ()
            addAttributes requestSpan [("http.status_code", toAttribute status)]
            -- Status codes >= 500 mark the span as errored; lower codes leave
            -- the span status untouched.
            when (status >= 500) $
              setStatus requestSpan (Error "")
            respReceived <- sendResp resp'
            ts <- getTimestamp
            endSpan requestSpan (Just ts)
            pure respReceived


-- | Vault key under which the middleware stores each request's tracing
-- context (read back by 'requestContext').
--
-- Created with 'unsafePerformIO'. This is acceptable only because the key is a
-- top-level constant marked @NOINLINE@: it is evaluated once, so the insert in
-- the middleware and the lookup in 'requestContext' share the same key. If the
-- definition were inlined, each use site could allocate its own fresh key and
-- lookups would silently return 'Nothing'.
contextKey :: Vault.Key Context.Context
contextKey = unsafePerformIO Vault.newKey
{-# NOINLINE contextKey #-}


-- | Retrieve the tracing context that the OpenTelemetry WAI middleware placed
-- in this request's 'vault'.
--
-- Yields 'Nothing' when no context was stored under this module's key, i.e.
-- when the request was not handled by the middleware.
requestContext :: Request -> Maybe Context.Context
requestContext request = Vault.lookup contextKey (vault request)