module Test.Hspec.Wai.Matcher (
ResponseMatcher(..)
, MatchHeader(..)
, (<:>)
, match
) where
import Prelude ()
import Prelude.Compat
import Control.Monad
import Data.Maybe
import Data.String
import Data.Text.Lazy.Encoding (encodeUtf8)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy as LB
import Network.HTTP.Types
import Network.Wai.Test
import Test.Hspec.Wai.Util
-- | Expectations that 'match' checks against a response.
data ResponseMatcher = ResponseMatcher
  { matchStatus  :: Int
    -- ^ expected HTTP status code; 'match' reports a mismatch when it differs
  , matchHeaders :: [MatchHeader]
    -- ^ header expectations, all of which are run by 'match'
  , matchBody    :: Maybe LB.ByteString
    -- ^ expected body; 'Nothing' means the body is not checked
  }
-- | A single header expectation.  The wrapped function receives the actual
-- response headers and yields 'Nothing' when the expectation holds, or
-- 'Just' a failure message otherwise.
data MatchHeader
  = MatchHeader ([Header] -> Maybe String)
-- | A string literal expects status 200, places no header expectations, and
-- expects the literal (UTF-8 encoded) as the exact response body.
instance IsString ResponseMatcher where
  fromString str = ResponseMatcher
    { matchStatus  = 200
    , matchHeaders = []
    , matchBody    = Just (encodeUtf8 (fromString str))
    }
-- | A numeric literal denotes a status-only matcher: e.g. @404@ expects
-- status 404, with no header expectations and no body check.
--
-- Arithmetic has no meaning for matchers, so every other method raises an
-- error naming the unsupported operation.
instance Num ResponseMatcher where
  fromInteger n = ResponseMatcher (fromInteger n) [] Nothing
  (+) = error "ResponseMatcher does not support (+)"
  (-) = error "ResponseMatcher does not support (-)"
  (*) = error "ResponseMatcher does not support (*)"
  -- Defined explicitly so that a negative literal such as @-404@ reports
  -- 'negate' rather than going through the default @negate x = 0 - x@.
  negate = error "ResponseMatcher does not support `negate`"
  abs = error "ResponseMatcher does not support `abs`"
  signum = error "ResponseMatcher does not support `signum`"
-- | Check a response against a 'ResponseMatcher'.
--
-- Returns 'Nothing' when everything matches. Otherwise returns 'Just' the
-- failure messages concatenated in this order: status, headers, body. The
-- body is compared only when the matcher carries an expected body.
match :: SResponse -> ResponseMatcher -> Maybe String
match (SResponse (Status status _) headers body) (ResponseMatcher expectedStatus expectedHeaders expectedBody) =
  mconcat
    [ actualExpected "status mismatch:" (show status) (show expectedStatus) <$ guard (status /= expectedStatus)
    , checkHeaders headers expectedHeaders
    , expectedBody >>= matchBody_ body
    ]
  where
    -- Compare bodies after converting both with 'toStrict'.  The conversion
    -- is done with explicit bindings rather than view patterns so this
    -- module does not depend on the ViewPatterns extension.
    matchBody_ actualLazy expectedLazy =
      actualExpected "body mismatch:" actual_ expected_ <$ guard (actual /= expected)
      where
        actual = toStrict actualLazy
        expected = toStrict expectedLazy
        -- Render both bodies with 'safeToString' when it succeeds for both;
        -- otherwise fall back to 'show' of the raw bytes.
        (actual_, expected_) = case (safeToString actual, safeToString expected) of
          (Just x, Just y) -> (x, y)
          _ -> (show actual, show expected)
-- | Build a three-line mismatch report: the heading, then the expected value,
-- then the actual value.  Each line is terminated by a newline.
actualExpected :: String -> String -> String -> String
actualExpected message actual expected =
  concatMap (++ "\n") reportLines
  where
    reportLines =
      [ message
      , " expected: " ++ expected
      , " but got: " ++ actual
      ]
-- | Run every header matcher against the actual response headers.
--
-- Returns 'Nothing' when all matchers pass. Otherwise returns 'Just' the
-- concatenated failure messages, followed by a listing of the actual
-- headers, one per line, as rendered by 'formatHeader'.
checkHeaders :: [Header] -> [MatchHeader] -> Maybe String
checkHeaders headers m = case mapMaybe runMatcher m of
  [] -> Nothing
  xs -> Just (mconcat xs ++ "the actual headers were:\n" ++ unlines (map formatHeader headers))
  where
    -- Apply one matcher to the actual headers; 'Just' means it failed.
    runMatcher (MatchHeader p) = p headers
-- | @name \<:\> value@ expects the exact header @(name, value)@ to be present
-- among the response headers.  On failure the message names the missing
-- header as rendered by 'formatHeader'.
(<:>) :: HeaderName -> ByteString -> MatchHeader
name <:> value = MatchHeader check
  where
    wanted = (name, value)
    check actualHeaders
      | wanted `elem` actualHeaders = Nothing
      | otherwise = Just (unlines ["missing header:", formatHeader wanted])