{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}
module OpenTelemetry.Instrumentation.Wai 
  ( newOpenTelemetryWaiMiddleware
  , newOpenTelemetryWaiMiddleware'
  , requestContext
  ) where

import Control.Exception (bracket)
import Control.Monad
import System.IO.Unsafe

import Data.IP (fromHostAddress, fromHostAddress6)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Data.Text.Encoding.Error (lenientDecode)
import qualified Data.Vault.Lazy as Vault
import Network.HTTP.Types
import Network.Socket
import Network.Wai
import OpenTelemetry.Attributes (lookupAttribute)
import qualified OpenTelemetry.Context as Context
import OpenTelemetry.Context.ThreadLocal
import OpenTelemetry.Propagator
import OpenTelemetry.Trace.Core

-- | Build the tracing 'Middleware' using the globally registered
-- 'TracerProvider'.
--
-- This fetches the provider with 'getGlobalTracerProvider' and hands it to
-- 'newOpenTelemetryWaiMiddleware''.
newOpenTelemetryWaiMiddleware :: IO Middleware
newOpenTelemetryWaiMiddleware = getGlobalTracerProvider >>= newOpenTelemetryWaiMiddleware'

-- | Build a WAI 'Middleware' that traces every request with a tracer obtained
-- from the given 'TracerProvider'.
--
-- For each request the middleware:
--
-- * extracts a context from the request headers with the provider's
--   propagators and attaches it while the request is being handled;
-- * runs the application inside a 'Server' span named after the raw request
--   path (renamed to the value of the @http.route@ attribute if that attribute
--   is present on the span when the response is produced);
-- * records HTTP and peer attributes on the span, and sets an error status
--   when the response status code is 500 or above;
-- * stores the current context in the request vault (readable through
--   'requestContext') and prepends injected propagation headers to the
--   response.
newOpenTelemetryWaiMiddleware'
  :: TracerProvider
  -> IO Middleware
newOpenTelemetryWaiMiddleware' tp = do
  waiTracer <- getTracer
    tp
    "opentelemetry-instrumentation-wai"
    (TracerOptions Nothing)
  pure $ middleware waiTracer
  where
    -- Request bytes are client-controlled and need not be valid UTF-8.
    -- Decode leniently (invalid sequences become U+FFFD) so that a malformed
    -- path, method or header cannot make the middleware throw.
    decodeLenient = T.decodeUtf8With lenientDecode

    middleware :: Tracer -> Middleware
    middleware tracer app req sendResp = do
      let propagator = getTracerProviderPropagators $ getTracerTracerProvider tracer
      -- Extract the incoming context from the request headers and attach it.
      -- The result of 'attachContext' is handed to the bracket's release
      -- action below (presumably the previously attached context, if any;
      -- confirm against OpenTelemetry.Context.ThreadLocal).
      let parentContextM = do
            ctx <- getContext
            ctxt <- extract propagator (requestHeaders req) ctx
            attachContext ctxt
      let path_ = decodeLenient $ rawPathInfo req
          -- peer = remoteHost req
      bracket
        parentContextM
        -- Undo the attach on the way out, also when the handler throws:
        -- detach if nothing was returned, otherwise re-attach what was.
        (\case
          Nothing -> void detachContext
          Just p -> void (attachContext p)
        )
        $ \_ -> do
          inSpan' tracer path_ (defaultSpanArguments { kind = Server }) $ \requestSpan -> do
            ctxt <- getContext
            addAttributes requestSpan
              [ ( "http.method", toAttribute $ decodeLenient $ requestMethod req)
              -- , ( "http.url",
              --     toAttribute $
              --     T.decodeUtf8
              --     ((if secure req then "https://" else "http://") <> host req <> ":" <> B.pack (show $ port req) <> path req <> queryString req)
              --   )
              , ( "http.target", toAttribute $ decodeLenient (rawPathInfo req <> rawQueryString req))
              -- , ( "http.host", toAttribute $ T.decodeUtf8 $ host req)
              -- , ( "http.scheme", toAttribute $ TextAttribute $ if secure req then "https" else "http")
              , ( "http.flavor"
                , toAttribute $ case httpVersion req of
                    (HttpVersion major minor) -> T.pack (show major <> "." <> show minor)
                )
              , ( "http.user_agent"
                , toAttribute $ maybe "" decodeLenient (lookup hUserAgent $ requestHeaders req)
                )
              -- TODO HTTP/3 will require detecting this dynamically
              , ( "net.transport", toAttribute ("ip_tcp" :: T.Text))
              ]

            -- TODO this is warp dependent, probably.
            -- , ( "net.host.ip")
            -- , ( "net.host.port")
            -- , ( "net.host.name")
            -- Peer attributes depend on the kind of socket address.
            addAttributes requestSpan $ case remoteHost req of
              SockAddrInet port addr ->
                [ ("net.peer.port", toAttribute (fromIntegral port :: Int))
                , ("net.peer.ip", toAttribute $ T.pack $ show $ fromHostAddress addr)
                ]
              SockAddrInet6 port _ addr _ ->
                [ ("net.peer.port", toAttribute (fromIntegral port :: Int))
                , ("net.peer.ip", toAttribute $ T.pack $ show $ fromHostAddress6 addr)
                ]
              SockAddrUnix path ->
                [ ("net.peer.name", toAttribute $ T.pack path)
                ]
            -- Expose the current context to the application via the vault.
            let req' = req
                  { vault = Vault.insert
                      contextKey
                      ctxt
                      (vault req)
                  }
            app req' $ \resp -> do
              ctxt' <- getContext
              -- Inject propagation headers for the request span and prepend
              -- them to the application's response headers.
              hs <- inject propagator (Context.insertSpan requestSpan ctxt') []
              let resp' = mapResponseHeaders (hs ++) resp
              attrs <- spanGetAttributes requestSpan
              -- If an "http.route" text attribute is present on the span at
              -- this point, use it as the span name instead of the raw path.
              forM_ (lookupAttribute attrs "http.route") $ \case
                AttributeValue (TextAttribute route) -> updateName requestSpan route
                _ -> pure ()

              addAttributes requestSpan
                [ ( "http.status_code", toAttribute $ statusCode $ responseStatus resp)
                ]
              when (statusCode (responseStatus resp) >= 500) $ do
                setStatus requestSpan (Error "")
              respReceived <- sendResp resp'
              -- End the span with a timestamp taken right after the response
              -- has been handed to the server.
              ts <- getTimestamp
              endSpan requestSpan (Just ts)
              pure respReceived

-- | Vault key under which the middleware stores the request's
-- 'Context.Context'.
--
-- 'Vault.newKey' is an 'IO' action, so the key is created once with
-- 'unsafePerformIO'. The @NOINLINE@ pragma keeps this a single shared
-- top-level value; inlining could create distinct keys at different use
-- sites, and lookups would then miss what the middleware inserted.
contextKey :: Vault.Key Context.Context
contextKey = unsafePerformIO Vault.newKey
{-# NOINLINE contextKey #-}

-- | Look up the 'Context.Context' stored in a request's vault under this
-- module's vault key.
--
-- Returns 'Nothing' when the vault holds no value for that key.
requestContext :: Request -> Maybe Context.Context
requestContext = Vault.lookup contextKey . vault