{- |
  A Wai middleware that uniformly structures errors within a servant application.
  The library assumes that all HTTP responses with a @4xx@ or @5xx@ status code and
  no @Content-Type@ header are error responses. This assumption is derived
  from servant's server error handling implementation.

  The formatting and structuring of errors rest on the implementation of 'HasErrorBody' class instances.
  Its class parameters are a content-type, e.g. @JSON@ or @PlainText@, and a type-level list of
  @options@, e.g. @'["error", "status"]@. The library offers instances for 'JSON' and 'PlainText' content-types.

  ==Sample usage with servant

  ===A typical servant application is usually of this form:

  @
  main :: IO ()
  main = run 8001 (serve proxyApi handlers)
  @

  ===With servant-errors as an error processing middleware:

  @
  main :: IO ()
  main = run 8001
     $ errorMw \@JSON \@\'["error", "status"]
     -- ^ Structures error response as JSON objects
     -- with @error@ and @status@ strings as error object field keys
     -- note they can be changed to any other preferred strings.
     $ serve proxyApi handlers
  @
-}
{-# LANGUAGE AllowAmbiguousTypes   #-}
{-# LANGUAGE ConstraintKinds       #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE KindSignatures        #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeApplications      #-}
module Network.Wai.Middleware.Servant.Errors
  ( -- * Error Middleware
    errorMw
  , errorMwDefJson

  -- * HasErrorBody class
  , HasErrorBody (..)

  -- * Helper functions and data types
  , ErrorMsg (..)
  , StatusCode (..)
  , ErrorLabels (..)
  , getErrorLabels
  , encodeAsJsonError
  , encodeAsPlainText
  )where

import Prelude.Compat
import Data.Aeson (Value (..), encode)
import qualified Data.ByteString as B
import Data.ByteString.Builder (toLazyByteString)
import qualified Data.ByteString.Lazy as LB
import qualified Data.HashMap.Strict as H
import Data.IORef (modifyIORef', newIORef, readIORef)
import Data.Kind (Type)
import Data.List (find)
import Data.Proxy (Proxy (..))
import Data.Scientific (Scientific)
import Data.String.Conversions (cs)
import qualified Data.Text as T
import GHC.TypeLits (KnownSymbol, Symbol, symbolVal)
import qualified Network.HTTP.Media as M
import Network.HTTP.Types (Header, Status (..), hContentLength, hContentType)
import Network.Wai (Middleware, Response, responseHeaders, responseLBS, responseStatus,
                    responseToStream)
import Servant.API.ContentTypes (Accept (..), JSON, PlainText)

-- | 'StatusCode' holds the numeric HTTP status code of an error response,
-- e.g. @404@.
newtype StatusCode = StatusCode { unStatusCode :: Int }
  deriving (Eq, Ord, Show)

-- | 'ErrorMsg' holds the message placed in an HTTP error response body.
newtype ErrorMsg = ErrorMsg { unErrorMsg :: T.Text }
  deriving (Eq, Show)

-- | 'ErrorLabels' holds the two labels used when rendering an error response:
-- one for the error message and one for the status code.
data ErrorLabels = ErrorLabels
  { errName       :: T.Text -- ^ label for the error message, e.g. @error@
  , errStatusName :: T.Text -- ^ label for the status code, e.g. @status@
  }
  deriving (Eq, Show)

-- | 'HasErrorBody' describes how servant error responses are structured.
--
-- The class takes two parameters:
--
-- * @ctyp@ is an HTTP content type that has an 'Accept' instance, e.g. @JSON@.
--
-- * @opts@ is a type-level list customising the error and status labels,
--   e.g. @'["error-message", "status-code"]@.
--
-- For the library-provided 'JSON' and 'PlainText' instances, an empty @opts@
-- list defaults to @'["error", "status"]@.
class Accept ctyp
   => HasErrorBody (ctyp :: Type) (opts :: [Symbol]) where
  -- | 'encodeError' renders the body of an error response.
  --
  -- Instances obtain their labels from @opts@, typically through
  -- 'getErrorLabels', and use them to build the encoded body.
  encodeError
    :: StatusCode    -- ^ HTTP status code of the response
    -> ErrorMsg      -- ^ message to report
    -> LB.ByteString

-- | Encodes errors as 'JSON' objects whose two keys are the labels given in
-- @opts@: the first maps to the error message, the second to the status code.
instance (KnownSymbol errLabel, KnownSymbol statusLabel)
  => HasErrorBody JSON '[errLabel, statusLabel] where
  encodeError = encodeAsJsonError (getErrorLabels @errLabel @statusLabel)

-- | Encodes errors as 'JSON' using the default @\'["error", "status"]@ labels.
instance HasErrorBody JSON '[] where
  encodeError = encodeError @JSON @'["error", "status"]

-- | Encodes errors as 'PlainText', labelling the error message and the status
-- code with the two labels given in @opts@.
instance (KnownSymbol errLabel, KnownSymbol statusLabel)
  => HasErrorBody PlainText '[errLabel, statusLabel] where
  encodeError = encodeAsPlainText (getErrorLabels @errLabel @statusLabel)

-- | Encodes errors as 'PlainText' using the default @\'["error", "status"]@ labels.
instance HasErrorBody PlainText '[] where
  encodeError = encodeError @PlainText @'["error", "status"]

-- | 'errorMwDefJson' is a pre-configured 'errorMw' that encodes error
-- responses as @JSON@ objects using @error@ and @status@ as the object keys.
--
-- A resulting response body may look like this:
-- @\{ \"error\": \"failed to decode request body\", \"status\": 400 \}@
errorMwDefJson :: Middleware
errorMwDefJson = errorMw @JSON @'[]

-- | 'errorMw' provides a "Network.Wai" middleware that re-formats error
-- responses produced by a servant application.
--
-- A response is rewritten with 'newResponse' when its status code lies in
-- the range 400-599 and it carries no @Content-Type@ header. Every other
-- response is passed through untouched.
--
-- Note that this function expects you to have the @TypeApplications@
-- extension enabled:
--
-- > errorMw @JSON @'[ "error", "status"]
errorMw :: forall ctyp opts. HasErrorBody ctyp opts => Middleware
errorMw baseApp req respond =
  baseApp req $ \response -> do
    let status       = responseStatus response
        mcontentType = getContentTypeHeader response
    case (status, mcontentType) of
      -- Only untyped 4xx/5xx responses are treated as servant errors.
      (Status code _, Nothing) | code >= 400 && code < 600 ->
        newResponse @ctyp @opts status response >>= respond
      _ -> respond response
  where
    -- Looks up the first @Content-Type@ header of a response, if any.
    getContentTypeHeader :: Response -> Maybe Header
    getContentTypeHeader = find ((hContentType ==) . fst) . responseHeaders


-- | 'newResponse' builds the replacement error 'Response' using the
-- 'HasErrorBody' instance selected by @ctyp@ and @opts@.
--
-- If the original error response has an empty body (e.g. a 404 error), the
-- HTTP status message is used as the error message instead.
--
-- The status is kept. A @Content-Type@ header derived from @ctyp@ is
-- prepended to the original headers. Any original @Content-Length@ header
-- is dropped because it describes the old body, not the re-encoded one.
newResponse
  :: forall ctyp opts . HasErrorBody ctyp opts
  => Status
  -> Response
  -> IO Response
newResponse status@(Status code statusMsg) response = do
  body <- responseBody response
  let oldHeaders =
        filter ((/= hContentLength) . fst) (responseHeaders response)
      newHeaders =
        (hContentType, M.renderHeader $ contentType (Proxy @ctyp)) : oldHeaders
      -- Fall back to the status message when the handler sent no body.
      content    = ErrorMsg . cs $ if B.null body then statusMsg else body
      newContent = encodeError @ctyp @opts (StatusCode code) content
  pure $ responseLBS status newHeaders newContent

-- | 'responseBody' extracts the body of a servant server 'Response' as a
-- strict 'B.ByteString'.
--
-- The response is turned into a stream with 'responseToStream'. Every chunk
-- the stream writes is appended to an accumulator, and the flush action does
-- nothing.
responseBody :: Response -> IO B.ByteString
responseBody res =
  let (_status, _headers, streamBody) = responseToStream res
  in streamBody $ \f -> do
       content <- newIORef mempty
       f (\chunk -> modifyIORef' content (<> chunk)) (pure ())
       LB.toStrict . toLazyByteString <$> readIORef content

{-------------------------------------------------------------------------------
  Helper functions for defining instances
-------------------------------------------------------------------------------}

-- | 'encodeAsJsonError' encodes an error response as a 'JSON' object with two
-- fields. @errName@ maps to the error message as a string, and
-- @errStatusName@ maps to the status code as a number.
--
-- If both labels are equal, 'H.fromList' keeps the later entry, so only the
-- status field survives.
--
-- It is used by the library-provided 'HasErrorBody' /JSON/ instances.
encodeAsJsonError :: ErrorLabels -> StatusCode -> ErrorMsg -> LB.ByteString
encodeAsJsonError ErrorLabels {..} code content =
  encode $ Object $ H.fromList
    [ (errName,       String $ unErrorMsg content)
    , (errStatusName, Number $ toScientific code)
    ]
  where
    toScientific :: StatusCode -> Scientific
    toScientific = fromIntegral . unStatusCode

-- | 'encodeAsPlainText' encodes an error response as a 'PlainText' body made
-- of two @label: value@ lines. For the labels @error@ and @status@ it yields:
--
-- > error: Not Found
-- > status: 404
--
-- It is used by the library-provided 'HasErrorBody' /PlainText/ instances.
encodeAsPlainText :: ErrorLabels -> StatusCode -> ErrorMsg -> LB.ByteString
encodeAsPlainText ErrorLabels {..} code content =
  cs $ T.concat
    [ errName, separator, unErrorMsg content
    , T.pack "\n"
    , errStatusName, separator, T.pack (show (unStatusCode code))
    ]
  where
    -- Keeps each label visually distinct from its value; without it the
    -- output fuses into e.g. "errorNot Foundstatus404".
    separator :: T.Text
    separator = T.pack ": "

-- | 'getErrorLabels' turns the type-level label options supplied to
-- 'HasErrorBody' into an 'ErrorLabels' value.
--
-- @errLabel@ becomes 'errName' and @statusLabel@ becomes 'errStatusName'.
-- The result is consumed by the encoding helpers used in 'HasErrorBody'
-- instances, e.g. 'encodeAsJsonError' and 'encodeAsPlainText'.
getErrorLabels
  :: forall errLabel statusLabel . (KnownSymbol errLabel, KnownSymbol statusLabel)
  => ErrorLabels
getErrorLabels =
  ErrorLabels (label (Proxy @errLabel)) (label (Proxy @statusLabel))
  where
    -- Reflects a type-level symbol into a 'T.Text' value.
    label :: KnownSymbol t => Proxy t -> T.Text
    label = cs . symbolVal