{-# LANGUAGE ViewPatterns #-}
module Test.Hspec.Wai.Matcher (
  ResponseMatcher(..)
, MatchHeader(..)
, MatchBody(..)
, Body
, (<:>)
, bodyEquals
, match
, formatHeader
) where

import           Prelude ()
import           Prelude.Compat

import           Control.Monad
import           Data.Maybe
import           Data.String
import           Data.Text.Lazy.Encoding (encodeUtf8)
import qualified Data.Text.Lazy as T
import           Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy as LB
import           Network.HTTP.Types
import           Network.Wai.Test

import           Test.Hspec.Wai.Util

-- | The response body as seen by matchers: a lazy 'LB.ByteString'.
type Body = LB.ByteString

-- | A complete expectation about a response: status code, headers and body.
data ResponseMatcher = ResponseMatcher
  { matchStatus  :: Int
    -- ^ Expected status code.
  , matchHeaders :: [MatchHeader]
    -- ^ Header expectations; every one of them is checked.
  , matchBody    :: MatchBody
    -- ^ Expectation about the response body.
  }

-- | One header expectation. The wrapped function receives the actual response
-- headers and body and yields 'Nothing' when satisfied, or 'Just' a
-- human-readable failure message otherwise.
data MatchHeader =
  MatchHeader ([Header] -> Body -> Maybe String)

-- | A body expectation, using the same shape and 'Nothing' / 'Just' failure
-- convention as 'MatchHeader'.
data MatchBody =
  MatchBody ([Header] -> Body -> Maybe String)

-- | A body expectation that succeeds only when the actual body is
-- byte-for-byte equal to the given one. Response headers are ignored.
--
-- On mismatch the report shows both bodies as text when 'safeToString'
-- yields a rendering for both; otherwise both fall back to 'show' of the
-- raw (strict) bytes.
bodyEquals :: Body -> MatchBody
bodyEquals expectedBody = MatchBody compareBodies
  where
    compareBodies :: [Header] -> Body -> Maybe String
    compareBodies _ actualBody
      | actualBytes == expectedBytes = Nothing
      | otherwise =
          Just (actualExpected "body mismatch:" actualText expectedText)
      where
        actualBytes = toStrict actualBody
        expectedBytes = toStrict expectedBody
        -- Render both sides the same way so the two lines are comparable.
        (actualText, expectedText) =
          case (safeToString actualBytes, safeToString expectedBytes) of
            (Just a, Just e) -> (a, e)
            _ -> (show actualBytes, show expectedBytes)

-- | A body expectation that accepts every body and ignores the headers.
matchAny :: MatchBody
matchAny = MatchBody acceptAll
  where
    acceptAll _ _ = Nothing

-- | A string literal used as a body expectation means "the body equals the
-- UTF-8 encoding of this string".
instance IsString MatchBody where
  fromString str = bodyEquals (encodeUtf8 (T.pack str))

-- | A string literal used as a whole 'ResponseMatcher' expects status 200,
-- places no expectations on headers, and turns the literal into the body
-- expectation through the 'IsString' instance of 'MatchBody'.
instance IsString ResponseMatcher where
  fromString literal = ResponseMatcher 200 [] (fromString literal)

-- | Lets a bare integer literal (e.g. @200@) stand for a 'ResponseMatcher'
-- that checks only the status code: no header expectations, and a body
-- expectation ('matchAny') that accepts anything.
--
-- Arithmetic has no meaning for matchers, so every other method fails with
-- 'error' naming the unsupported operation.
instance Num ResponseMatcher where
  fromInteger n = ResponseMatcher (fromInteger n) [] matchAny
  (+)    = error "ResponseMatcher does not support (+)"
  (-)    = error "ResponseMatcher does not support (-)"
  (*)    = error "ResponseMatcher does not support (*)"
  abs    = error "ResponseMatcher does not support `abs`"
  signum = error "ResponseMatcher does not support `signum`"
  -- Defined explicitly: a negative literal such as @-404@ desugars to
  -- @negate (fromInteger 404)@, and the class default @negate x = 0 - x@
  -- would report "(-)" even though the user never wrote a subtraction.
  negate = error "ResponseMatcher does not support `negate`"

-- | Check an actual response against a 'ResponseMatcher'.
--
-- Returns 'Nothing' when the status code, every header expectation, and the
-- body expectation all hold. Otherwise returns the failure messages
-- concatenated in the order status, headers, body. The concatenation comes
-- from the 'Monoid' instance of @'Maybe' 'String'@, which skips 'Nothing's.
match :: SResponse -> ResponseMatcher -> Maybe String
match (SResponse (Status actualStatus _) actualHeaders actualBody)
      (ResponseMatcher expectedStatus expectedHeaders (MatchBody checkBody)) =
  mconcat [statusFailure, headerFailure, bodyFailure]
  where
    statusFailure
      | actualStatus /= expectedStatus =
          Just (actualExpected "status mismatch:" (show actualStatus) (show expectedStatus))
      | otherwise = Nothing
    headerFailure = checkHeaders actualHeaders actualBody expectedHeaders
    bodyFailure = checkBody actualHeaders actualBody

-- | Render a failure report: the heading, then an "expected" line and a
-- "but got" line. Every line, including the last, ends with a newline.
actualExpected :: String -> String -> String -> String
actualExpected message actual expected =
  message ++ "\n"
    ++ "  expected: " ++ expected ++ "\n"
    ++ "  but got:  " ++ actual ++ "\n"

-- | Run every header expectation against the actual headers and body.
--
-- Returns 'Nothing' when all expectations hold. Otherwise returns the failure
-- messages concatenated in order, followed by a listing of the actual headers
-- (one per line, rendered with 'formatHeader') to help diagnose the mismatch.
checkHeaders :: [Header] -> Body -> [MatchHeader] -> Maybe String
checkHeaders headers body matchers =
  case mapMaybe runMatcher matchers of
    [] -> Nothing
    failures ->
      Just $ concat failures
          ++ "the actual headers were:\n"
          ++ unlines (map formatHeader headers)
  where
    -- A matcher yields 'Just' a message only when its expectation fails.
    runMatcher :: MatchHeader -> Maybe String
    runMatcher (MatchHeader check) = check headers body

-- | @name '<:>' value@ expects the response headers to contain an entry equal
-- to the pair @(name, value)@. The response body is not consulted.
--
-- On failure the message reports the missing header using 'formatHeader'.
(<:>) :: HeaderName -> ByteString -> MatchHeader
name <:> value =
  MatchHeader (\headers _body -> missing <$ guard (expected `notElem` headers))
  where
    expected :: Header
    expected = (name, value)

    missing :: String
    missing = unlines ["missing header:", formatHeader expected]