{-# LANGUAGE CPP #-}

module Network.AWS.XRayClient.Persistent
  ( xraySqlBackend
  ) where

import Prelude

import Conduit
import Control.Lens
import Control.Monad (void)
import Data.Foldable (for_)
import Data.IORef
import qualified Data.Map as Map
import Data.Text (Text)
import Data.Time.Clock.POSIX
import Database.Persist
import Database.Persist.Sql
import Database.Persist.Sql.Types.Internal
#if MIN_VERSION_persistent(2,13,3)
import Database.Persist.SqlBackend.StatementCache
  (StatementCache, mkSimpleStatementCache, mkStatementCache)
#endif
import Network.AWS.XRayClient.Segment
import Network.AWS.XRayClient.TraceId
import System.Random
import System.Random.XRayCustom

-- | Wrap a SQL backend so that statement and transaction timings are reported
-- to X-Ray.
--
-- The returned backend hands a subsegment to the supplied callback:
--
-- * when a statement that was started through 'stmtQuery' is reset (the
--   subsegment carries the given subsegment name), and
-- * after each @BEGIN@, @COMMIT@ and @ROLLBACK@ (the subsegment is named
--   after the command).
--
-- Only 'stmtQuery' records a start time; 'stmtExecute' is left untouched.
--
-- The wrapped backend starts with an empty prepared-statement cache; see the
-- note in @modifyBackend@ below.
--
-- >>> xraySqlBackend sendToDaemon stdGenIORef "my-query" backend >>= runSqlConn sql
xraySqlBackend
  :: (IsPersistBackend backend, BaseBackend backend ~ SqlBackend)
  => (XRaySegment -> IO ())
  -- ^ Callback that receives every finished subsegment
  -> IORef StdGen
  -- ^ Random generator state used to create segment ids
  -> Text
  -- ^ Subsegment name used for query traces
  -> backend
  -- ^ Backend to instrument
  -> IO backend
xraySqlBackend sendTrace stdGenIORef subsegmentName =
  fmap mkPersistBackend . modifyBackend . persistBackend
 where
  -- Replace the hooks of the underlying 'SqlBackend' with timed versions.
  modifyBackend :: SqlBackend -> IO SqlBackend
  modifyBackend backend = do
    -- N.B. by default persistent caches a Map Text Statement for each
    -- SqlBackend, where Text is a SQL query. When we wrap a backend to run it
    -- with XRay, we have to modify each Statement to record query timing. If
    -- backends are long-lived, then this poses a problem because we will
    -- continually wrap the same Statement. Therefore, we clear this cache each
    -- time we want to monitor things with XRay.
    newConnStmtMap <- newIORef Map.empty
    pure backend
      { connPrepare = connPrepare' (connPrepare backend)
#if MIN_VERSION_persistent(2,9,0)
      , connBegin = binaryTimerWrapper "BEGIN" (connBegin backend)
#else
      , connBegin = unaryTimerWrapper "BEGIN" (connBegin backend)
#endif
      , connCommit = unaryTimerWrapper "COMMIT" (connCommit backend)
      , connRollback = unaryTimerWrapper "ROLLBACK" (connRollback backend)
      , connStmtMap = mkCache newConnStmtMap
      }

  -- Prepare a statement with the base backend, then override its query and
  -- reset hooks so that the pair measures how long the query was open.
  connPrepare' :: (Text -> IO Statement) -> Text -> IO Statement
  connPrepare' baseConnPrepare sql = do
    -- Create an IORef to store the start time. This is populated when a query
    -- begins in 'stmtQuery', and is then used in stmtReset to compute the
    -- total time.
    startTimeIORef <- newIORef Nothing

    statement <- baseConnPrepare sql
    pure statement
      { stmtQuery = stmtQuery' statement startTimeIORef
      , stmtReset = stmtReset' statement startTimeIORef sql
      }

  -- Remember when the query started, then defer to the original 'stmtQuery'.
  stmtQuery'
    :: forall m
     . MonadIO m
    => Statement
    -> IORef (Maybe POSIXTime)
    -> [PersistValue]
    -> Acquire (ConduitT () [PersistValue] m ())
  stmtQuery' statement startTimeIORef vals = do
    -- Record start time in IORef
    liftIO $ getPOSIXTime >>= writeIORef startTimeIORef . Just

    -- Create the Source and return it
    stmtQuery statement vals

  -- Run the original 'stmtReset', then report the query if one was started.
  stmtReset' :: Statement -> IORef (Maybe POSIXTime) -> Text -> IO ()
  stmtReset' statement startTimeIORef sql = do
    stmtReset statement

    -- Take the start time and clear it in one step. Without the clear, a
    -- later reset that was not preceded by 'stmtQuery' would report a second
    -- trace measured from the stale start time.
    mStartTime <-
      atomicModifyIORef' startTimeIORef (\recorded -> (Nothing, recorded))
    for_ mStartTime $ \startTime ->
      sendQueryTrace sendTrace subsegmentName startTime stdGenIORef sql

  -- Run an action and afterwards report a trace that uses the label both as
  -- the subsegment name and as the SQL text. Nothing is reported if the
  -- action throws.
  timed :: Text -> IO r -> IO r
  timed label action = do
    startTime <- getPOSIXTime
    result <- action
    sendQueryTrace sendTrace label startTime stdGenIORef label
    pure result

  -- Timed version of a one-argument backend hook (commit, rollback, and
  -- begin on persistent versions before 2.9).
  unaryTimerWrapper :: Text -> (a -> IO r) -> a -> IO r
  unaryTimerWrapper label action x = timed label (action x)

#if MIN_VERSION_persistent(2,9,0)
  -- Timed version of a two-argument backend hook (begin, which also takes an
  -- isolation level from persistent 2.9 on).
  binaryTimerWrapper :: Text -> (a -> b -> IO r) -> a -> b -> IO r
  binaryTimerWrapper label action x y = timed label (action x y)
#endif

-- | Build a finished subsegment for one SQL operation and pass it to the
-- trace callback.
--
-- The end time is sampled when this function runs, so call it right after the
-- operation completes. A fresh segment id is drawn from the shared random
-- generator, and the SQL text is stored in the subsegment's sanitized-query
-- field.
sendQueryTrace
  :: (XRaySegment -> IO ())
  -- ^ Callback that receives the subsegment
  -> Text
  -- ^ Subsegment name
  -> POSIXTime
  -- ^ Time at which the operation started
  -> IORef StdGen
  -- ^ Random generator state used to create the segment id
  -> Text
  -- ^ SQL text to attach to the subsegment
  -> IO ()
sendQueryTrace sendTrace subsegmentName startTime stdGenIORef sql = do
  -- Record end time
  endTime <- getPOSIXTime

  -- Generate trace and send it off
  segmentId <- withRandomGenIORef stdGenIORef generateXRaySegmentId
  let
    sqlDetails = xraySegmentSqlDef & xraySegmentSqlSanitizedQuery ?~ sql
    subsegment =
      xraySubsegment subsegmentName segmentId startTime (Just endTime)
        & xraySegmentSql ?~ sqlDetails
  void $ sendTrace subsegment

#if MIN_VERSION_persistent(2,13,3)
-- | Convert a statement map into the value stored in the backend's
-- 'connStmtMap' field. With persistent 2.13.3 and later that field holds a
-- 'StatementCache', so the map is wrapped in a simple cache.
mkCache :: IORef (Map.Map Text Statement) -> StatementCache
mkCache = mkStatementCache . mkSimpleStatementCache
#else
-- | Convert a statement map into the value stored in the backend's
-- 'connStmtMap' field. Before persistent 2.13.3 the field holds the map
-- reference itself, so no conversion is needed.
mkCache :: IORef (Map.Map Text Statement) -> IORef (Map.Map Text Statement)
mkCache = id
#endif