{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE FlexibleInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module OpenTelemetry.Instrumentation.Persistent
  ( wrapSqlBackend
  ) where
import OpenTelemetry.Trace.Core
import OpenTelemetry.Context
import Data.Acquire.Internal
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Database.Persist.Sql
import Database.Persist.SqlBackend (setConnHooks, emptySqlBackendHooks, MkSqlBackendArgs (connRDBMS), getRDBMS, getConnVault, modifyConnVault)
import Database.Persist.SqlBackend.Internal
import Control.Monad.IO.Class
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.Vault.Strict as Vault
import OpenTelemetry.Attributes (Attributes)
import OpenTelemetry.Resource
import UnliftIO.Exception
import OpenTelemetry.Trace.Monad (MonadTracer(..))
import Control.Monad.Reader
import qualified Data.Text as T
import OpenTelemetry.Context.ThreadLocal (getContext, adjustContext)

-- | Make the tracer of the underlying monad available inside persistent's
-- @ReaderT SqlBackend@ query monad.
instance {-# OVERLAPS #-} MonadTracer m => MonadTracer (ReaderT SqlBackend m) where
  -- The backend environment is ignored; the tracer comes from the inner monad.
  getTracer = ReaderT (const OpenTelemetry.Trace.Monad.getTracer)
-- | Make the tracer of the underlying monad available inside a
-- @ReaderT SqlReadBackend@ query monad.
instance {-# OVERLAPS #-} MonadTracer m => MonadTracer (ReaderT SqlReadBackend m) where
  -- The backend environment is ignored; the tracer comes from the inner monad.
  getTracer = ReaderT (const OpenTelemetry.Trace.Monad.getTracer)
-- | Make the tracer of the underlying monad available inside a
-- @ReaderT SqlWriteBackend@ query monad.
instance {-# OVERLAPS #-} MonadTracer m => MonadTracer (ReaderT SqlWriteBackend m) where
  -- The backend environment is ignored; the tracer comes from the inner monad.
  getTracer = ReaderT (const OpenTelemetry.Trace.Monad.getTracer)

-- | Vault key under which a wrapped 'SqlBackend' keeps the unwrapped
-- backend it was built from.
--
-- The key is created once, process-wide, via 'unsafePerformIO'. This is
-- sound only because the binding is a top-level CAF marked @NOINLINE@:
-- it is evaluated at most once, so every use site shares the same key.
-- Without the pragma, inlining could mint several distinct keys, and a
-- lookup with one key would never see a value inserted under another.
originalConnectionKey :: Vault.Key SqlBackend
originalConnectionKey = unsafePerformIO Vault.newKey
{-# NOINLINE originalConnectionKey #-}

-- | Return @conn@ with @original@ stored in its vault under
-- 'originalConnectionKey', overwriting any backend stored there before.
insertOriginalConnection :: SqlBackend -> SqlBackend -> SqlBackend
insertOriginalConnection conn original = modifyConnVault stash conn
  where
    -- Vault transformer that records the unwrapped backend.
    stash = Vault.insert originalConnectionKey original

-- | Fetch the unwrapped backend previously stored with
-- 'insertOriginalConnection', or 'Nothing' when none was stored.
lookupOriginalConnection :: SqlBackend -> Maybe SqlBackend
lookupOriginalConnection backend =
  Vault.lookup originalConnectionKey (getConnVault backend)

-- | Process-wide vault key for a list of connection-level span attributes.
--
-- Created once via 'unsafePerformIO'; this is sound only because the
-- binding is a top-level CAF marked @NOINLINE@, so every use site shares
-- the same key rather than each inlined copy minting its own.
--
-- NOTE(review): nothing in the visible part of this module reads or writes
-- this key — confirm whether it is still needed.
connectionLevelAttributesKey :: Vault.Key [(Text, Attribute)]
connectionLevelAttributesKey = unsafePerformIO Vault.newKey
{-# NOINLINE connectionLevelAttributesKey #-}

-- | Wrap a 'SqlBackend' with appropriate tracing context and attributes
-- so that queries are tracked appropriately in the tracing hierarchy.
--
-- The returned backend opens a client span around each of these:
--
-- * statement query and execute (span named after the SQL text);
-- * transaction begin, commit and rollback;
-- * connection close.
--
-- Spans come from a tracer named @hs-opentelemetry-persistent@ obtained from
-- the global 'TracerProvider'. Every span carries the supplied
-- provider-specific attributes and a @db.system@ attribute. All spans except
-- the \"close connection\" span also carry @db.statement@.
--
-- An already-wrapped backend may be passed in. The original backend is
-- recovered from the connection vault and wrapped once, so tracing is never
-- layered on top of tracing. Any 'hookGetStatement' already installed on
-- the original backend is kept and runs before the tracing wrapper.
wrapSqlBackend
  :: MonadIO m
  => [(Text, Attribute)]
  -- ^ Attributes that are specific to providers like MySQL, PostgreSQL, etc.
  -> SqlBackend
  -> m SqlBackend
wrapSqlBackend attrs conn_ = do
  tp <- getGlobalTracerProvider
  -- Unwrap first so that re-wrapping starts from the untraced backend.
  let conn = Data.Maybe.fromMaybe conn_ (lookupOriginalConnection conn_)
  -- TODO add schema to tracerOptions?
  let t = makeTracer tp "hs-opentelemetry-persistent" tracerOptions

      -- Client-kind span arguments tagged with the given statement text.
      statementSpanArgs :: Text -> SpanArguments
      statementSpanArgs stmtText =
        defaultSpanArguments
          { kind = Client
          , attributes = ("db.statement", toAttribute stmtText) : attrs
          }

      hooks =
        emptySqlBackendHooks
          { hookGetStatement = \backend sql prepared -> do
              -- Run the hook the backend already had, so wrapping composes
              -- with user-installed hooks instead of discarding them.
              stmt <- hookGetStatement (connHooks conn) backend sql prepared
              pure $
                Statement
                  { stmtQuery = \ps -> do
                      ctxt <- getContext
                      let spanCreator = do
                            s <- createSpan t ctxt sql (statementSpanArgs sql)
                            adjustContext (insertSpan s)
                            pure (lookupSpan ctxt, s)
                          spanCleanup (parent, s) = do
                            endSpan s Nothing
                            -- Restore the caller's active span. If there was
                            -- none, remove the child explicitly; otherwise
                            -- the ended span would stay current in the
                            -- thread-local context and parent later spans.
                            adjustContext $ \ctx ->
                              maybe (removeSpan ctx) (`insertSpan` ctx) parent
                      (_, child) <- mkAcquire spanCreator spanCleanup
                      annotateBasics child backend
                      case stmtQuery stmt ps of
                        Acquire runStmtQuery ->
                          Acquire
                            ( \restore ->
                                handleAny
                                  ( \ex@(SomeException err) -> do
                                      -- Record the concrete exception type on
                                      -- the span, then rethrow the original
                                      -- 'SomeException' unchanged.
                                      recordException child [] Nothing err
                                      endSpan child Nothing
                                      throwIO ex
                                  )
                                  (runStmtQuery restore)
                            )
                  , stmtExecute = \ps ->
                      inSpan' t sql (statementSpanArgs sql) $ \s -> do
                        annotateBasics s backend
                        stmtExecute stmt ps
                  , stmtReset = stmtReset stmt
                  , stmtFinalize = stmtFinalize stmt
                  }
          }

  let conn' =
        conn
          { connHooks = hooks
          , connBegin = \f mIso -> do
              let isolationClause :: Text
                  isolationClause = case mIso of
                    Nothing -> mempty
                    Just ReadUncommitted -> " isolation level read uncommitted"
                    Just ReadCommitted -> " isolation level read committed"
                    Just RepeatableRead -> " isolation level repeatable read"
                    Just Serializable -> " isolation level serializable"
                  statement = "begin transaction" <> isolationClause
              inSpan' t statement (statementSpanArgs statement) $ \s -> do
                annotateBasics s conn
                connBegin conn f mIso
          , connCommit = \f ->
              inSpan' t "commit" (statementSpanArgs "commit") $ \s -> do
                annotateBasics s conn
                connCommit conn f
          , connRollback = \f ->
              inSpan' t "rollback" (statementSpanArgs "rollback") $ \s -> do
                annotateBasics s conn
                connRollback conn f
          , connClose =
              inSpan'
                t
                "close connection"
                (defaultSpanArguments {kind = Client, attributes = attrs})
                $ \s -> do
                  annotateBasics s conn
                  connClose conn
          }
  -- Remember the untraced backend so a later re-wrap can start from it.
  pure $ insertOriginalConnection conn' conn

-- | Add a @db.system@ attribute to @sp@, set to the RDBMS name that
-- 'getRDBMS' reports for @backend@.
annotateBasics :: MonadIO m => Span -> SqlBackend -> m ()
annotateBasics sp backend = addAttributes sp [dbSystem]
  where
    dbSystem :: (Text, Attribute)
    dbSystem = ("db.system", toAttribute (getRDBMS backend))