module Yesod.Test (
runTests, describe, it, Specs, OneSpec,
post, post_, get, get_, doRequest,
byName, fileByName,
byLabel, fileByLabel,
addNonce, addNonce_,
runDB,
assertEqual, assertHeader, assertNoHeader, statusIs, bodyEquals, bodyContains,
htmlAllContain, htmlCount,
printBody, printMatches,
htmlQuery, parseHTML, withResponse
)
where
import qualified Test.Hspec.Core as Core
import qualified Test.Hspec.Runner as Runner
import qualified Data.List as DL
import qualified Data.Maybe as DY
import qualified Data.ByteString.Char8 as BS8
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.ByteString.Lazy.Char8 as BSL8
import qualified Test.HUnit as HUnit
import qualified Test.Hspec.HUnit ()
import qualified Network.HTTP.Types as H
import qualified Network.Socket.Internal as Sock
import Data.CaseInsensitive (CI)
import Text.XML.HXT.Core hiding (app, err)
import Network.Wai
import Network.Wai.Test hiding (assertHeader, assertNoHeader)
import qualified Control.Monad.Trans.State as ST
import Control.Monad.IO.Class
import System.IO
import Yesod.Test.TransversingCSS
import Database.Persist.GenericSql
import Data.Monoid (mappend)
import qualified Data.Text.Lazy as TL
import Data.Text.Lazy.Encoding (encodeUtf8, decodeUtf8)
-- | State threaded through a 'Specs' computation: the application under
-- test, the database connection pool handed to every test case, and the
-- specs collected so far ('it' prepends, so the newest spec comes first).
data SpecsData = SpecsData Application ConnectionPool [Core.Spec]

-- | Monad in which a test suite is declared with 'describe' and 'it'.
type Specs = ST.StateT SpecsData IO ()

-- | State of a single test case: the application, the connection pool,
-- the cookie value sent with the next request, and the response to the
-- most recent request ('Nothing' until a request has been made).
data OneSpecData = OneSpecData Application ConnectionPool CookieValue (Maybe SResponse)

-- | Monad in which the body of a single test case runs.
type OneSpec = ST.StateT OneSpecData IO

-- | State of a request being assembled: the form parts added so far
-- (newest first, since 'byName' and 'fileByName' prepend) and the response
-- to the previous request, so that builders such as 'byLabel' and
-- 'addNonce' can inspect the current page.
data RequestBuilderData = RequestBuilderData [RequestPart] (Maybe SResponse)

-- | A single form field of a request.
data RequestPart
  = ReqPlainPart String String
    -- ^ Plain field: name and value.
  | ReqFilePart String FilePath BSL8.ByteString String
    -- ^ File upload: field name, file path (sent as the filename),
    -- file contents and MIME type.

-- | Monad in which request parameters are assembled before 'doRequest'.
type RequestBuilder = ST.StateT RequestBuilderData IO
-- | States that remember the response to the most recent request, so the
-- response assertions work both inside a test case ('OneSpecData') and
-- while building the next request ('RequestBuilderData').
class HoldsResponse a where
  -- | The most recent response, or 'Nothing' if no request was made yet.
  readResponse :: a -> Maybe SResponse

instance HoldsResponse OneSpecData where
  readResponse specData = case specData of
    OneSpecData _ _ _ lastResponse -> lastResponse

instance HoldsResponse RequestBuilderData where
  readResponse builderData = case builderData of
    RequestBuilderData _ previousResponse -> previousResponse
-- | Value sent in the @Cookie@ header of every request; 'doRequest'
-- updates it from the @Set-Cookie@ header of each response carrying one.
type CookieValue = H.Ascii
-- | Build the spec tree declared by @specsDef@ and run it with hspec.
--
-- The application and connection pool are handed to every 'it' block.
-- Specs are accumulated newest-first (see 'it'), so the list is reversed
-- here to run them in declaration order.
runTests :: Application -> ConnectionPool -> Specs -> IO a
runTests app connection specsDef = do
  SpecsData _ _ specs <- ST.execStateT specsDef (SpecsData app connection [])
  Runner.hspecX (reverse specs)

-- | Group the specs declared by @action@ under @label@.
--
-- The nested action starts from an empty spec list, so specs declared
-- earlier (e.g. sibling 'describe' groups) stay siblings instead of being
-- swallowed into this group. The group is then added next to them.
describe :: String -> Specs -> Specs
describe label action = do
  SpecsData app conn specs <- ST.get
  SpecsData _ _ inner <- liftIO $ ST.execStateT action (SpecsData app conn [])
  -- 'inner' is newest-first; present the group's children in declaration order.
  ST.put $ SpecsData app conn (Core.describe label (reverse inner) : specs)

-- | Declare a single test case. Each case starts with an empty cookie and
-- no previous response. The spec is prepended (newest-first) to the list
-- that 'describe' and 'runTests' reverse.
it :: String -> OneSpec () -> Specs
it label action = do
  SpecsData app conn specs <- ST.get
  let spec = Core.it label $ ST.evalStateT action (OneSpecData app conn "" Nothing)
  ST.put $ SpecsData app conn (spec : specs)
-- | Apply @f@ to the response of the most recent request. Fails the test
-- when no request has been made yet.
withResponse :: HoldsResponse a => (SResponse -> ST.StateT a IO b) -> ST.StateT a IO b
withResponse f = do
  current <- ST.get
  case readResponse current of
    Nothing -> failure "There was no response, you should make a request"
    Just res -> f res
-- | Decode @html@ as UTF-8, parse it with HXT's 'hread' and run the arrow
-- @p@ over the resulting trees, collecting every result.
parseHTML :: Html -> LA XmlTree a -> [a]
parseHTML html p = runLA (hread >>> p) source
  where
    source = TL.unpack (decodeUtf8 html)
-- | Run the CSS selector @query@ against the body of the most recent
-- response. Every matching element is returned as UTF-8 encoded HTML.
-- Fails the test when the selector cannot be parsed.
htmlQuery :: HoldsResponse a => Query -> ST.StateT a IO [Html]
htmlQuery query = withResponse $ \res ->
  either parseFailed found (findBySelector (simpleBody res) query)
  where
    parseFailed err = failure $ T.unpack query ++ " did not parse: " ++ show err
    found matches = return (map (encodeUtf8 . TL.pack) matches)
-- | Fail the test with message @msg@ unless both values are equal.
assertEqual :: (Eq a) => String -> a -> a -> OneSpec ()
assertEqual msg expected actual =
  liftIO . HUnit.assertBool msg $ expected == actual
-- | Assert that the most recent response has HTTP status code @number@.
statusIs :: HoldsResponse a => Int -> ST.StateT a IO ()
statusIs number = withResponse $ \res -> do
  let received = H.statusCode (simpleStatus res)
      message = "Expected status was " ++ show number
             ++ " but received status was " ++ show received
  liftIO $ HUnit.assertBool message (received == number)
-- | Assert that the most recent response carries header @header@ (looked
-- up by case-insensitive name) with exactly the value @value@.
assertHeader :: HoldsResponse a => CI BS8.ByteString -> BS8.ByteString -> ST.StateT a IO ()
assertHeader header value = withResponse $ \res ->
  maybe missing check (lookup header (simpleHeaders res))
  where
    expectation = "Expected header " ++ show header ++ " to be " ++ show value
    missing = failure $ expectation ++ ", but it was not present"
    check actual = liftIO $ HUnit.assertBool
      (expectation ++ ", but received " ++ show actual)
      (value == actual)
-- | Assert that the most recent response does not carry header @header@.
assertNoHeader :: HoldsResponse a => CI BS8.ByteString -> ST.StateT a IO ()
assertNoHeader header = withResponse $ \res ->
  maybe (return ()) unexpected (lookup header (simpleHeaders res))
  where
    unexpected found = failure $
      "Unexpected header " ++ show header ++ " containing " ++ show found
-- | Assert that the body of the most recent response is exactly @text@.
--
-- @text@ is UTF-8 encoded before the bytewise comparison. The previous
-- 'BSL8.pack' truncated every character to 8 bits, so any non-ASCII
-- expectation could never match a UTF-8 body.
bodyEquals :: HoldsResponse a => String -> ST.StateT a IO ()
bodyEquals text = withResponse $ \res ->
  liftIO $ HUnit.assertBool ("Expected body to equal " ++ text) $
    simpleBody res == encodeUtf8 (TL.pack text)
-- | Assert that the body of the most recent response contains @text@
-- (see 'contains').
bodyContains :: HoldsResponse a => String -> ST.StateT a IO ()
bodyContains text = withResponse $ \res -> do
  let found = simpleBody res `contains` text
  liftIO $ HUnit.assertBool ("Expected body to contain " ++ text) found
-- | @body `contains` needle@ holds when the UTF-8 encoding of @needle@
-- occurs somewhere in @body@.
--
-- The needle is encoded rather than the body decoded, so the comparison is
-- bytewise and still works for bodies that are not valid UTF-8. For ASCII
-- needles the result is the same as a plain substring test.
contains :: BSL8.ByteString -> String -> Bool
contains a b = DL.isInfixOf (BSL8.unpack (encodeUtf8 (TL.pack b))) (BSL8.unpack a)
-- | Assert that the CSS selector @query@ matches at least one element and
-- that every matched element's HTML contains @search@.
htmlAllContain :: HoldsResponse a => Query -> String -> ST.StateT a IO ()
htmlAllContain query search = htmlQuery query >>= check
  where
    queryStr = T.unpack query
    check [] = failure $ "Nothing matched css query: " ++ queryStr
    check matches = liftIO $
      HUnit.assertBool ("Not all " ++ queryStr ++ " contain " ++ search) $
        DL.all (DL.isInfixOf search . TL.unpack . decodeUtf8) matches
-- | Assert that the CSS selector @query@ matches exactly @count@ elements.
htmlCount :: HoldsResponse a => Query -> Int -> ST.StateT a IO ()
htmlCount query count = do
  found <- fmap DL.length (htmlQuery query)
  let message = "Expected " ++ show count ++ " elements to match "
             ++ T.unpack query ++ ", found " ++ show found
  liftIO $ HUnit.assertBool message (found == count)
-- | Debugging aid: write the body of the most recent response to stderr.
printBody :: HoldsResponse a => ST.StateT a IO ()
printBody = withResponse (liftIO . hPutStrLn stderr . BSL8.unpack . simpleBody)
-- | Debugging aid: write the elements matched by CSS selector @query@
-- (shown as a list) to stderr.
printMatches :: HoldsResponse a => Query -> ST.StateT a IO ()
printMatches query = htmlQuery query >>= liftIO . hPrint stderr
-- | Add a plain form field with the given @name@ and @value@ to the
-- request being built.
byName :: String -> String -> RequestBuilder ()
byName name value = ST.modify addPart
  where
    addPart (RequestBuilderData parts r) =
      RequestBuilderData (ReqPlainPart name value : parts) r
-- | Add a file-upload field named @name@ to the request. Its contents are
-- read from @path@ and it is labelled with MIME type @mimetype@.
--
-- The file is read strictly. I/O errors therefore surface here and the
-- handle is closed at once, instead of staying open until the request body
-- is serialised (as it would with lazy 'BSL8.readFile').
fileByName :: String -> FilePath -> String -> RequestBuilder ()
fileByName name path mimetype = do
  contents <- liftIO $ BS8.readFile path
  let part = ReqFilePart name path (BSL8.fromChunks [contents]) mimetype
  RequestBuilderData parts r <- ST.get
  ST.put $ RequestBuilderData (part : parts) r
-- | Find the @name@ attribute of the form control tied to a @<label>@.
--
-- The label is the one whose serialised markup contains @label@, which is
-- HTML-escaped before matching. The control is the element whose @id@
-- equals that label's @for@ attribute. Fails the test unless exactly one
-- label matches, it has a @for@ attribute, exactly one element carries
-- that id, and that element has a non-empty @name@.
nameFromLabel :: String -> RequestBuilder String
nameFromLabel label = withResponse $ \res -> do
  let body = simpleBody res
      escaped = escapeHtmlEntities label
      mfor = parseHTML body $ deep $ hasName "label"
        >>> filterA (xshow this >>> mkText >>> hasText (DL.isInfixOf escaped))
        >>> getAttrValue "for"
  case mfor of
    -- getAttrValue yields "" when the attribute is absent.
    [""] -> failure $ "Label " ++ label ++ " has no for attribute"
    [forId] -> do
      let mname = parseHTML body $ deep $ hasAttrValue "id" (== forId) >>> getAttrValue "name"
      case mname of
        [] -> failure $ "Label " ++ label ++ " resolved to id " ++ forId ++ " which was not found. "
        [""] -> failure $ "Label " ++ label ++ " resolved to id " ++ forId ++ " which has no name attribute. "
        [name] -> return name
        _ -> failure $ "More than one input with id " ++ forId
    [] -> failure $ "No label contained: " ++ label
    _ -> failure $ "More than one label contained " ++ label
-- | Replace @<@, @>@, @&@, @\"@ and @'@ with @&lt;@, @&gt;@, @&amp;@,
-- @&quot;@ and @&#39;@ respectively, leaving all other characters as they
-- are. 'nameFromLabel' uses it to match label text against serialised
-- markup.
escapeHtmlEntities :: String -> String
escapeHtmlEntities = concatMap escapeChar
  where
    escapeChar '<' = "&lt;"
    escapeChar '>' = "&gt;"
    escapeChar '&' = "&amp;"
    escapeChar '"' = "&quot;"
    escapeChar '\'' = "&#39;"
    escapeChar c = [c]
-- | Add a plain form field to the request. The field name is resolved from
-- the label containing @label@ on the current page (see 'nameFromLabel').
byLabel :: String -> String -> RequestBuilder ()
byLabel label value = nameFromLabel label >>= flip byName value
-- | Add a file-upload field to the request. The field name is resolved from
-- the label containing @label@ on the current page (see 'nameFromLabel').
fileByLabel :: String -> FilePath -> String -> RequestBuilder ()
fileByLabel label path mime =
  nameFromLabel label >>= \fieldName -> fileByName fieldName path mime
-- | Add the nonce found on the current page (the value of the hidden
-- @_token@ input) to the request as field @_token@.
--
-- @scope@ is a CSS selector restricting the search, e.g. @"#login-form"@.
-- Pass @""@ to search the whole page. Fails the test unless exactly one
-- matching input is found and its value can be read.
addNonce_ :: Query -> RequestBuilder ()
addNonce_ scope = do
  matches <- htmlQuery selector
  case matches of
    [] -> failure "No nonce found in the current page"
    [element] ->
      case parseHTML element (getAttrValue "value") of
        (nonce:_) -> byName "_token" nonce
        [] -> failure "Could not read the value of the nonce input"
    _ -> failure "More than one nonce found in the page"
  where
    tokenInput :: Query
    tokenInput = "input[name=_token][type=hidden][value]"
    -- Join with a space (the descendant combinator). Plain concatenation
    -- would turn "#form" into the bogus selector "#forminput[...]".
    selector :: Query
    selector
      | T.null scope = tokenInput
      | otherwise = scope `mappend` " " `mappend` tokenInput
-- | 'addNonce_' with an empty scope, i.e. search the whole page.
addNonce :: RequestBuilder ()
addNonce = addNonce_ T.empty
-- | Send a POST request to @url@ with the form fields built by the given
-- 'RequestBuilder'.
post :: BS8.ByteString -> RequestBuilder () -> OneSpec ()
post = doRequest "POST"
-- | Send a POST request to @url@ without any form fields.
post_ :: BS8.ByteString -> OneSpec ()
post_ url = post url (return ())
-- | Send a GET request to @url@ with the form fields built by the given
-- 'RequestBuilder'.
get :: BS8.ByteString -> RequestBuilder () -> OneSpec ()
get = doRequest "GET"
-- | Send a GET request to @url@ without any form fields.
get_ :: BS8.ByteString -> OneSpec ()
get_ url = get url (return ())
-- | Send a request to the application under test and record its response.
--
-- @method@ is the HTTP method and @url@ the raw request path. @paramsBuild@
-- assembles the form fields and may inspect the previous response (e.g.
-- via 'byLabel' or 'addNonce'). If any field is a file upload, the body is
-- sent as @multipart/form-data@; otherwise it is sent as
-- @application/x-www-form-urlencoded@. The current cookie value goes out
-- with the request. It is then replaced by the name=value pair of the
-- response's first @Set-Cookie@ header, if the response has one.
doRequest :: H.Method -> BS8.ByteString -> RequestBuilder a -> OneSpec ()
doRequest method url paramsBuild = do
  OneSpecData app conn cookie mRes <- ST.get
  RequestBuilderData newestFirst _ <- liftIO $
    ST.execStateT paramsBuild (RequestBuilderData [] mRes)
  -- Builders prepend each field; send them in the order the test declared them.
  let parts = reverse newestFirst
      req
        | DL.any isFile parts = makeMultipart cookie parts
        | otherwise = makeSinglepart cookie parts
  response <- liftIO $ runSession (srequest req) app
  -- A Cookie header carries only name=value pairs, so drop the attributes
  -- (path, expires, ...) that follow the first ';' of Set-Cookie.
  let setCookie = DL.find (("Set-Cookie" ==) . fst) (simpleHeaders response)
      cookie' = DY.fromMaybe cookie $ fmap (BS8.takeWhile (/= ';') . snd) setCookie
  ST.put $ OneSpecData app conn cookie' (Just response)
  where
    isFile ReqFilePart{} = True
    isFile _ = False

    -- Encode text as UTF-8. BS8.pack would keep only the low 8 bits of each
    -- character and mangle any non-ASCII input.
    utf8 :: String -> BS8.ByteString
    utf8 = TE.encodeUtf8 . T.pack

    boundary :: String
    boundary = "*******noneedtomakethisrandom"

    delimiter :: BS8.ByteString
    delimiter = BS8.concat ["--", BS8.pack boundary]

    makeMultipart cookie parts =
      flip SRequest (BSL8.fromChunks [multiPartBody parts]) $ mkRequest
        [ ("Cookie", cookie)
        , ("Content-Type", BS8.pack $ "multipart/form-data; boundary=" ++ boundary)
        ]

    -- Every part is preceded by a delimiter line, and the body ends with the
    -- close delimiter "--boundary--" that RFC 2046 requires.
    multiPartBody parts = BS8.concat $
      [BS8.concat [delimiter, "\r\n", multipartPart p] | p <- parts]
        ++ [delimiter, "--\r\n"]

    multipartPart (ReqPlainPart k v) = BS8.concat
      [ "Content-Disposition: form-data; "
      , "name=\"", utf8 k, "\"\r\n\r\n"
      , utf8 v, "\r\n"
      ]
    multipartPart (ReqFilePart k v bytes mime) = BS8.concat
      [ "Content-Disposition: form-data; "
      , "name=\"", utf8 k, "\"; "
      , "filename=\"", utf8 v, "\"\r\n"
      , "Content-Type: ", BS8.pack mime, "\r\n\r\n"
      , BS8.concat (BSL8.toChunks bytes), "\r\n"
      ]

    makeSinglepart cookie parts = SRequest
      (mkRequest
        [ ("Cookie", cookie)
        , ("Content-Type", "application/x-www-form-urlencoded")
        ])
      (BSL8.fromChunks [BS8.intercalate "&" [formPair k v | ReqPlainPart k v <- parts]])

    formPair k v = BS8.concat [formEncode k, "=", formEncode v]

    -- Percent-encode the UTF-8 bytes of a form key or value. Only RFC 3986
    -- unreserved characters stay literal, so '&', '=', '+', '%', spaces and
    -- non-ASCII text reach the server's form parser intact.
    formEncode :: String -> BS8.ByteString
    formEncode = BS8.concatMap escapeByte . utf8
      where
        escapeByte c
          | isUnreserved c = BS8.singleton c
          | otherwise = BS8.pack
              ['%', hexDigit (fromEnum c `div` 16), hexDigit (fromEnum c `mod` 16)]
        isUnreserved c =
          (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
            || (c >= '0' && c <= '9') || c `elem` ("-._~" :: String)
        hexDigit :: Int -> Char
        hexDigit d
          | d < 10 = toEnum (fromEnum '0' + d)
          | otherwise = toEnum (fromEnum 'A' + d - 10)

    mkRequest headers = defaultRequest
      { requestMethod = method
      , remoteHost = Sock.SockAddrInet 1 2
      , requestHeaders = headers
      , rawPathInfo = url
      , pathInfo = DL.filter (/= "") $ T.split (== '/') $ TE.decodeUtf8 url
      }
-- | Run a Persistent SQL action against the connection pool of the
-- current test case.
runDB :: SqlPersist IO a -> OneSpec a
runDB query = ST.get >>= \(OneSpecData _ pool _ _) -> liftIO (runSqlPool query pool)
-- | Fail the current test with @reason@.
--
-- 'HUnit.assertFailure' throws, so the trailing 'error' is never reached.
-- It only gives 'failure' an arbitrary result type. Its message names the
-- function, so the cause is traceable if that assumption ever breaks.
failure :: (MonadIO a) => String -> a b
failure reason = liftIO (HUnit.assertFailure reason) >> error unreachable
  where
    unreachable = "Yesod.Test.failure: unreachable, assertFailure returned for: " ++ reason