{-# OPTIONS_GHC -Wno-deprecations #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections     #-}

module Network.Wai.Middleware.Validation where

import           Data.Aeson                                 (ToJSON, encode, object, toJSON, (.=))
import           Data.ByteString.Builder                    (toLazyByteString)
import qualified Data.ByteString.Char8                      as S8
import qualified Data.ByteString.Lazy                       as L
import           Data.IORef                                 (atomicModifyIORef, newIORef, readIORef)
import           Network.HTTP.Types                         (ResponseHeaders, StdMethod,
                                                             badRequest400, hContentType,
                                                             internalServerError500, parseMethod,
                                                             statusCode, statusIsSuccessful)
import           Network.Wai                                (Middleware, Request, Response,
                                                             rawPathInfo, requestBody,
                                                             requestMethod, responseLBS,
                                                             responseStatus, responseToStream,
                                                             strictRequestBody)

import           Network.Wai.Middleware.Validation.Internal (ApiDefinition, getRequestBodySchema,
                                                             getResponseBodySchema, toApiDefinition,
                                                             validateJsonDocument)


-- | Default JSON error body used by the primed validator constructors:
-- a short title plus a free-form detail message.
data DefaultErrorJson = DefaultErrorJson
    { title  :: String  -- ^ Short summary of the failure.
    , detail :: String  -- ^ Detailed message (e.g. the validation errors).
    }
    deriving Show

-- | Encodes as a JSON object with exactly two keys, @title@ and @detail@.
instance ToJSON DefaultErrorJson where
    toJSON (DefaultErrorJson errTitle errDetail) =
        object
            [ "title"  .= errTitle
            , "detail" .= errDetail
            ]

-- | Wrap an error message in a 'DefaultErrorJson'. The title is always
-- @"Validation failed"@; the message becomes the detail.
mkDefaultErrorJson :: String -> DefaultErrorJson
mkDefaultErrorJson msg = DefaultErrorJson { title = "Validation failed", detail = msg }


-- | Make a middleware for Request/Response validation that reports failures
-- with 'DefaultErrorJson' bodies. See 'mkValidator'.
mkValidator' :: L.ByteString -> Maybe Middleware
mkValidator' apiJson = mkValidator mkDefaultErrorJson apiJson

-- | Make a middleware that validates both requests and responses against an
-- OpenAPI document. Failures are reported with JSON built by @mkErrorJson@.
--
-- The document is parsed once with 'toApiDefinition'. The result is
-- 'Nothing' when that parse fails. The response validator wraps the request
-- validator, so a request rejected with 400 flows back through the response
-- validator, which only inspects successful (2xx) responses.
mkValidator :: ToJSON a => (String -> a) -> L.ByteString -> Maybe Middleware
mkValidator mkErrorJson apiJson = do
    apiDef <- toApiDefinition apiJson
    let validateRequest  = requestValidator mkErrorJson apiDef
        validateResponse = responseValidator mkErrorJson apiDef
    pure (validateResponse . validateRequest)

-- | Make a middleware for Request validation that reports failures with
-- 'DefaultErrorJson' bodies. See 'mkRequestValidator'.
mkRequestValidator' :: L.ByteString -> Maybe Middleware
mkRequestValidator' apiJson = mkRequestValidator mkDefaultErrorJson apiJson

-- | Make a request-only validation middleware with a custom error body.
--
-- Yields 'Nothing' when 'toApiDefinition' cannot parse the OpenAPI document.
mkRequestValidator :: ToJSON a => (String -> a) -> L.ByteString -> Maybe Middleware
mkRequestValidator mkErrorJson apiJson =
    fmap (requestValidator mkErrorJson) (toApiDefinition apiJson)

-- | Make a middleware for Response validation that reports failures with
-- 'DefaultErrorJson' bodies. See 'mkResponseValidator'.
mkResponseValidator' :: L.ByteString -> Maybe Middleware
mkResponseValidator' apiJson = mkResponseValidator mkDefaultErrorJson apiJson

-- | Make a response-only validation middleware with a custom error body.
--
-- Yields 'Nothing' when 'toApiDefinition' cannot parse the OpenAPI document.
mkResponseValidator :: ToJSON a => (String -> a) -> L.ByteString -> Maybe Middleware
mkResponseValidator mkErrorJson apiJson =
    fmap (responseValidator mkErrorJson) (toApiDefinition apiJson)

-- | Middleware that validates request bodies against the OpenAPI definition.
--
-- A request is validated only when its method parses to a standard method and
-- 'getRequestBodySchema' finds a body schema for that method and the raw path.
-- Any other request is passed to the wrapped application untouched.
--
-- On validation failure the wrapped application is not called. Instead a
-- @400 Bad Request@ is sent, with a JSON body built by @mkErrorJson@ from the
-- error message(s).
requestValidator :: ToJSON a => (String -> a) -> ApiDefinition -> Middleware
requestValidator mkErrorJson apiDef app req sendResponse =
    case mBodySchema of
        Nothing         -> app req sendResponse
        Just bodySchema -> do
            -- Reading the body drains it, so the application must receive the
            -- refilled request returned alongside it.
            (body, newReq) <- getRequestBody req
            case validateJsonDocument apiDef bodySchema body of
                Right []   -> app newReq sendResponse
                Right errs -> respondError $ unlines errs
                Left err   -> respondError err
  where
    path :: String
    path = S8.unpack $ rawPathInfo req

    -- Requests with a non-standard method have no schema and are not validated.
    mBodySchema = case parseMethod (requestMethod req) of
        Right method -> getRequestBodySchema apiDef method path
        Left _       -> Nothing

    respondError msg = sendResponse $
        responseLBS badRequest400 responseHeaders $ encode $ mkErrorJson msg

-- | Middleware that validates successful response bodies against the OpenAPI
-- definition.
--
-- Only responses whose status satisfies 'statusIsSuccessful' are inspected,
-- and only when a body schema exists for the request's method, raw path and
-- status code. A response that fails validation is replaced with a
-- @500 Internal Server Error@, whose JSON body is built by @mkErrorJson@.
responseValidator :: ToJSON a => (String -> a) -> ApiDefinition -> Middleware
responseValidator mkErrorJson apiDef app req sendResponse = app req $ \res -> do
    let status = responseStatus res
    -- Validate only the success response.
    if not (statusIsSuccessful status)
        then sendResponse res
        else case lookupSchema status of
            Nothing         -> sendResponse res
            Just bodySchema -> do
                -- Capturing the body runs the response's streaming body. A
                -- 'responseStream' body is not guaranteed to be re-runnable,
                -- so forward a response rebuilt from the captured bytes (same
                -- status and headers) instead of re-sending the original.
                body <- getResponseBody res
                let (_, headers, _) = responseToStream res
                case validateJsonDocument apiDef bodySchema body of
                    Right []   -> sendResponse $ responseLBS status headers body
                    -- REVIEW: It may be better not to include the error details in the response.
                    -- _ -> respondError "Invalid response body"
                    Right errs -> respondError $ unlines errs
                    Left err   -> respondError err
  where
    path :: String
    path = S8.unpack $ rawPathInfo req

    -- Requests with a non-standard method have no schema and are not validated.
    lookupSchema status = case parseMethod (requestMethod req) of
        Right method -> getResponseBodySchema apiDef method path (statusCode status)
        Left _       -> Nothing

    respondError msg = sendResponse $
        responseLBS internalServerError500 responseHeaders $ encode $ mkErrorJson msg

-- | Read the whole request body. Return it together with a copy of the
-- request whose body can be read again.
--
-- 'strictRequestBody' drains the original body, so the returned request gets a
-- fresh 'requestBody' action. Its first call yields the complete body as one
-- strict chunk; every later call yields an empty chunk.
getRequestBody :: Request -> IO (L.ByteString, Request)
getRequestBody req = do
    bodyBytes <- strictRequestBody req
    remaining <- newIORef bodyBytes
    let popAll   = atomicModifyIORef remaining (\rest -> (L.empty, rest))
        refilled = req { requestBody = fmap L.toStrict popAll }
    pure (bodyBytes, refilled)

-- | Collect the complete body of a 'Response' as a lazy 'L.ByteString'.
--
-- The response is converted with 'responseToStream' and its streaming body is
-- run once. Every emitted builder chunk is appended to an accumulator, and
-- flush requests are ignored.
getResponseBody :: Response -> IO L.ByteString
getResponseBody res = withBody collect
  where
    (_, _, withBody) = responseToStream res

    collect streamingBody = do
        acc <- newIORef mempty
        let push chunk = atomicModifyIORef acc $ \built -> (built <> chunk, ())
        streamingBody push (pure ())
        toLazyByteString <$> readIORef acc

-- | A single @Content-Type: application/json@ header, for responses whose
-- body is JSON.
responseHeaders :: ResponseHeaders
responseHeaders =
    [ (hContentType, "application/json")
    ]

--
-- for non-middleware use
--

-- | Validate a request body against an OpenAPI document without a middleware.
--
-- Returns @Left "Invalid OpenAPI document"@ when the document cannot be parsed.
-- Returns @Left "Schema not found"@ when no request body schema exists for the
-- method and path. Otherwise returns the result of 'validateJsonDocument'.
validateRequestBody :: StdMethod -> FilePath -> L.ByteString -> L.ByteString -> Either String [String]
validateRequestBody method path apiJson body = do
    apiDef     <- maybe (Left "Invalid OpenAPI document") Right (toApiDefinition apiJson)
    bodySchema <- maybe (Left "Schema not found") Right (getRequestBodySchema apiDef method path)
    validateJsonDocument apiDef bodySchema body

-- | Validate a response body against an OpenAPI document without a middleware.
--
-- Returns @Left "Invalid OpenAPI document"@ when the document cannot be parsed.
-- Returns @Left "Schema not found"@ when no response body schema exists for the
-- method, path and status code. Otherwise returns the result of
-- 'validateJsonDocument'.
validateResponseBody :: StdMethod -> FilePath -> Int -> L.ByteString -> L.ByteString -> Either String [String]
validateResponseBody method path statusCode' apiJson body = do
    apiDef     <- maybe (Left "Invalid OpenAPI document") Right (toApiDefinition apiJson)
    bodySchema <- maybe (Left "Schema not found") Right
                      (getResponseBodySchema apiDef method path statusCode')
    validateJsonDocument apiDef bodySchema body