module Network.Wai.Middleware.Logging
  ( addThreadContext
  , addThreadContextFromRequest
  , requestLogger
  , requestLoggerWith

    -- * Configuration
  , Config
  , defaultConfig
  , setConfigLogSource
  , setConfigGetClientIp
  , setConfigGetDestinationIp
  ) where

import Prelude

import Blammo.Logging
import Control.Applicative ((<|>))
import Control.Arrow ((***))
import Control.Monad.IO.Unlift (withRunInIO)
import Data.Aeson
import qualified Data.Aeson.Compat as Key
import qualified Data.Aeson.Compat as KeyMap
import Data.ByteString (ByteString)
import qualified Data.CaseInsensitive as CI
import Data.List (find)
import Data.Maybe (fromMaybe)
import Data.Text (Text, pack)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import Network.HTTP.Types.Header (Header, HeaderName)
import Network.HTTP.Types.Status (Status (..))
import Network.Wai
  ( Middleware
  , Request
  , Response
  , rawPathInfo
  , rawQueryString
  , remoteHost
  , requestHeaders
  , requestMethod
  , responseHeaders
  , responseStatus
  )
import qualified System.Clock as Clock

-- | Add context to any logging done from the request-handling thread
--
-- The same fixed pairs are used for every request. Use
-- 'addThreadContextFromRequest' when the context depends on the 'Request'.
addThreadContext :: [Pair] -> Middleware
addThreadContext = addThreadContextFromRequest . const

-- | 'addThreadContext', but have the 'Request' available
--
-- The pairs computed from the incoming 'Request' are installed with
-- 'withThreadContext' around the wrapped application's handling of that
-- request.
addThreadContextFromRequest :: (Request -> [Pair]) -> Middleware
addThreadContextFromRequest toContext app request respond =
  withThreadContext (toContext request) $ app request respond

-- | Log requests (more accurately, responses) as they happen
--
-- This is 'requestLoggerWith' applied to 'defaultConfig'.
--
-- In JSON format, logged messages look like:
--
-- @
-- {
--   ...
--   message: {
--     text: "GET /foo/bar => 200 OK",
--     meta: {
--       method: "GET",
--       path: "/foo/bar",
--       query: "?baz=bat&quix=quo",
--       status: {
--         code: 200,
--         message: "OK"
--       },
--       clientIp: "203.0.113.7",
--       destinationIp: "10.0.0.12",
--       durationMs: 1322.2,
--       requestHeaders: {
--         authorization: "***",
--         accept: "text/html",
--         cookie: "***"
--       },
--       responseHeaders: {
--         set-cookie: "***",
--         expires: "never"
--       }
--     }
--   }
-- }
-- @
--
-- Header names appear lower-cased, and the values of @authorization@,
-- @cookie@ and @set-cookie@ are replaced by @***@.
requestLogger :: HasLogger env => env -> Middleware
requestLogger = requestLoggerWith defaultConfig

-- | Settings for 'requestLoggerWith'
--
-- The constructor is not exported: start from 'defaultConfig' and adjust it
-- with 'setConfigLogSource', 'setConfigGetClientIp' and
-- 'setConfigGetDestinationIp'.
data Config = Config
  { cLogSource :: LogSource
  -- ^ Source attached to every message this middleware logs
  , cGetClientIp :: Request -> Text
  -- ^ Produces the value logged as @clientIp@
  , cGetDestinationIp :: Request -> Maybe Text
  -- ^ Produces the value logged as @destinationIp@
  }

-- | The 'Config' used by 'requestLogger'
--
-- * Log source: @requestLogger@
-- * Client IP: the first non-blank, comma-separated value of
--   @x-forwarded-for@, else the @x-real-ip@ header, else 'show' of
--   'remoteHost'
-- * Destination IP: the @x-real-ip@ header, when present
defaultConfig :: Config
defaultConfig =
  Config
    { cLogSource = "requestLogger"
    , cGetClientIp = \req ->
        fromMaybe (pack $ show $ remoteHost req) $
          (firstValue =<< lookupRequestHeader "x-forwarded-for" req)
            <|> lookupRequestHeader "x-real-ip" req
    , cGetDestinationIp = lookupRequestHeader "x-real-ip"
    }
 where
  -- Split on commas, trim each piece, and keep the first one that is not
  -- empty; 'Nothing' when every piece is blank.
  firstValue :: Text -> Maybe Text
  firstValue = find (not . T.null) . map T.strip . T.splitOn ","

-- | Look up a request header by name and decode its value as 'Text'
--
-- Returns 'Nothing' when the header is absent. When the header occurs more
-- than once, the first occurrence in 'requestHeaders' is used. Bytes that are
-- not valid UTF-8 are decoded leniently rather than raising.
lookupRequestHeader :: HeaderName -> Request -> Maybe Text
lookupRequestHeader h = fmap decodeUtf8 . lookup h . requestHeaders

-- | Change the source used for log messages
--
-- Default is @requestLogger@.
setConfigLogSource :: LogSource -> Config -> Config
setConfigLogSource x c = c {cLogSource = x}

-- | Change how the @clientIp@ field is determined
--
-- Default is looking up the first value in @x-forwarded-for@, then the
-- @x-real-ip@ header, then finally falling back to 'Network.Wai.remoteHost'.
setConfigGetClientIp :: (Request -> Text) -> Config -> Config
setConfigGetClientIp x c = c {cGetClientIp = x}

-- | Change how the @destinationIp@ field is determined
--
-- Default is looking up the @x-real-ip@ header.
--
-- __NOTE__: Our default uses a somewhat loose definition of /destination/. It
-- would be more accurate to report the resolved IP address of the @Host@
-- header, but we don't have that available. Our default of @x-real-ip@ favors
-- containerized Warp on AWS/ECS, where this value holds the ECS target
-- container's IP address. This is valuable debugging information and could, if
-- you squint, be considered a /destination/.
setConfigGetDestinationIp :: (Request -> Maybe Text) -> Config -> Config
setConfigGetDestinationIp x c = c {cGetDestinationIp = x}

-- | 'requestLogger', but with an explicit 'Config'
--
-- A monotonic timestamp is taken before the wrapped application runs. When
-- the application hands over its 'Response', that response is first passed on
-- to the real @respond@ callback, and only then is the elapsed time computed
-- and the message logged through the given @env@. The reported @durationMs@
-- therefore includes the time spent inside @respond@.
--
-- Logging happens inside the response callback, so nothing is logged for a
-- request whose application never invokes it.
requestLoggerWith :: HasLogger env => Config -> env -> Middleware
requestLoggerWith config env app req respond =
  withRunInIO $ \runInIO -> do
    begin <- getTime
    app req $ \resp -> do
      recvd <- respond resp
      duration <- toMillis . subtract begin <$> getTime
      -- Log, then return the token obtained from the real @respond@
      recvd <$ runInIO (runWithLogger env $ logResponse config duration req resp)
 where
  getTime :: IO Clock.TimeSpec
  getTime = Clock.getTime Clock.Monotonic

  -- Elapsed time as fractional milliseconds
  toMillis :: Clock.TimeSpec -> Double
  toMillis x = fromIntegral (Clock.toNanoSecs x) / nsPerMs

-- | Log a single request/response pair
--
-- The level is chosen from the response status code:
--
-- * @>= 500@: error
-- * @404@: debug
-- * any other @>= 400@: warn
-- * everything else: debug
--
-- The message text is @METHOD PATH => CODE MESSAGE@ and the given duration
-- (in milliseconds) is reported as @durationMs@ alongside the request and
-- response details.
logResponse :: MonadLogger m => Config -> Double -> Request -> Response -> m ()
logResponse Config {..} duration req resp
  | statusCode status >= 500 = logErrorNS cLogSource $ message :# details
  | statusCode status == 404 = logDebugNS cLogSource $ message :# details
  | statusCode status >= 400 = logWarnNS cLogSource $ message :# details
  | otherwise = logDebugNS cLogSource $ message :# details
 where
  message =
    decodeUtf8 (requestMethod req)
      <> " "
      <> decodeUtf8 (rawPathInfo req)
      <> " => "
      <> pack (show $ statusCode status)
      <> " "
      <> decodeUtf8 (statusMessage status)

  details =
    [ "method" .= decodeUtf8 (requestMethod req)
    , "path" .= decodeUtf8 (rawPathInfo req)
    , "query" .= decodeUtf8 (rawQueryString req)
    , "status"
        .= object
          [ "code" .= statusCode status
          , "message" .= decodeUtf8 (statusMessage status)
          ]
    , "clientIp" .= cGetClientIp req
    , "destinationIp" .= cGetDestinationIp req
    , "durationMs" .= duration
    , "requestHeaders"
        .= headerObject ["authorization", "cookie"] (requestHeaders req)
    , "responseHeaders" .= headerObject ["set-cookie"] (responseHeaders resp)
    ]

  status = responseStatus resp

-- | Render headers as a JSON object, hiding the values of sensitive ones
--
-- Each key is the header name in its case-folded form and each value is the
-- header value decoded as text. Any header whose name appears in the first
-- argument has its value replaced by @***@. The result is a key-value map, so
-- a header name that occurs more than once ends up as a single entry.
headerObject :: [HeaderName] -> [Header] -> Value
headerObject redact = Object . KeyMap.fromList . map (mung . hide)
 where
  -- Convert a (name, value) header into a JSON key and string value
  mung = Key.fromText . decodeUtf8 . CI.foldedCase *** String . decodeUtf8

  -- Mask the value when the header name is in the redaction list
  hide (k, v)
    | k `elem` redact = (k, "***")
    | otherwise = (k, v)

-- | Number of nanoseconds in one millisecond
nsPerMs :: Double
nsPerMs = 1000000

-- | Decode UTF-8 bytes to 'Text' without failing on malformed input
--
-- Decoding uses 'lenientDecode', so invalid byte sequences are substituted
-- instead of raising an exception.
decodeUtf8 :: ByteString -> Text
decodeUtf8 = decodeUtf8With lenientDecode