module Test.HTTP (httpTest, session, get, getJSON, postForm, assert, Program, Session) where
import Control.Concurrent
import Control.Exception hiding (assert)
import Control.Monad
import Data.Char
import Data.IORef
import Data.List
import Data.Maybe
import GHC.Conc
import System.Console.GetOpt
import System.Environment
import System.Exit
import System.IO
import System.IO.Error

import Control.Concurrent.STM
import Control.Monad.Error
import Control.Monad.Reader
import qualified Control.Monad.State.Strict as State
import Control.Monad.Trans

import qualified Data.Aeson as Ae
import qualified Data.ByteString as BS
import Data.ByteString.Lazy (fromStrict)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Network.Curl hiding (curlGetString)
import Safe (readMay)
-- | A whole test program. It can read the shared variable into which every
-- finished 'session' deposits its list of results.
type Program = ReaderT (TVar [Results]) IO

-- | State threaded through a single 'Session'.
data SessionState = SessionState
  { sessionResults :: Results  -- ^ checks recorded so far (newest first)
  , sessionBaseUrl :: String   -- ^ prefix prepended to every request path
  , sessionCurl    :: Curl     -- ^ curl handle shared by all requests
  }

-- | A sequence of HTTP requests and checks sharing one curl handle.
type Session = State.StateT SessionState IO

-- | Check outcomes: the check name paired with 'Nothing' on success or
-- @Just reason@ on failure.
type Results = [(String, Maybe String)]
-- | Run a test 'Program' inside 'withCurlDo', print one line per recorded
-- check, then terminate the process. The exit status is @ExitFailure 1@ if
-- any check failed and 'ExitSuccess' otherwise.
httpTest :: Program () -> IO ()
httpTest m = withCurlDo $ do
  resultsVar <- newTVarIO []
  runReaderT m resultsVar
  allResults <- concat <$> readTVarIO resultsVar
  forM_ allResults (putStrLn . ppRes)
  let anyFailed = any (isJust . snd) allResults
  exitWith (if anyFailed then ExitFailure 1 else ExitSuccess)
-- | Render one check result as a report line: @Pass: name@ on success,
-- @FAIL: name; reason@ on failure.
ppRes :: (String, Maybe String) -> String
ppRes (nm, outcome) = maybe passed failed outcome
  where
    passed = "Pass: " ++ nm
    failed reason = "FAIL: " ++ nm ++ "; " ++ reason
-- | Run one named 'Session' against a base URL and publish its results.
--
-- A fresh curl handle is created for the session. It follows redirects, and
-- its cookie jar file is named after the session. When the session finishes,
-- its results are appended, as one group, to the program's shared result list.
session :: String      -- ^ session name (also names the cookie-jar file)
        -> String      -- ^ base URL prepended to every request path
        -> Session ()  -- ^ the requests and checks to run
        -> Program ()
session sessionName baseURL m = do
  c <- liftIO initialize
  let state0 = SessionState [] baseURL c
  liftIO $ setopts c [ CurlCookieJar (sessionName ++ "_cookies")
                     , CurlFollowLocation True
                     ]
  SessionState res _ _ <- liftIO $ State.execStateT m state0
  resTV <- ask
  -- Results are accumulated newest-first (each check is consed onto the
  -- front), so reverse them once here so the report lists checks in the
  -- order they actually ran.
  liftIO $ atomically $ modifyTVar' resTV (++ [reverse res])
-- | GET @url@ (relative to the session base URL) and return the response
-- body. A non-'CurlOK' result is recorded as a failing check named
-- @GET url@; the (possibly partial) body is still returned.
get :: String
    -> Session String
get url = do
  (status, body) <- getRaw url
  unless (status == CurlOK) $
    failTest ("GET " ++ url) (show status)
  return body
-- | Issue a GET for @url@ (appended to the session base URL) on the session's
-- curl handle, returning the curl status and body without recording a check.
getRaw :: String -> Session (CurlCode, String)
getRaw url = do
  base <- State.gets sessionBaseUrl
  conn <- State.gets sessionCurl
  liftIO $ curlGetString conn (base ++ url) []
-- | GET @url@ (relative to the session base URL) and decode the body as JSON.
--
-- A failed transfer is recorded as a failing check by 'get'. If the body
-- cannot be decoded into the requested type, an 'IOError' is thrown at this
-- call. Its message names the URL and includes aeson's parse error. This
-- avoids an anonymous pattern-match failure surfacing later, wherever the
-- value happens to be forced.
getJSON :: Ae.FromJSON a =>
           String
        -> Session a
getJSON url = do
  body <- get url
  -- NOTE(review): assumes the curl binding hands back the body already
  -- decoded to Chars, so it is re-encoded as UTF-8 for aeson — confirm.
  case Ae.eitherDecode' (fromStrict (encodeUtf8 (T.pack body))) of
    Right x  -> return x
    Left err -> liftIO $ ioError $ userError $
      "getJSON " ++ url ++ ": " ++ err
-- | POST the given form @fields@ to @url@ (relative to the session base URL)
-- and return the response body. A non-'CurlOK' result is recorded as a
-- failing check named @POST url@.
postForm :: String
         -> [(String,String)]
         -> Session String
postForm url fields = do
  st <- State.get
  let target = sessionBaseUrl st ++ url
  (status, body) <- liftIO $ curlPostString (sessionCurl st) target [] fields
  unless (status == CurlOK) $
    failTest ("POST " ++ url) (show status)
  return body
-- | Record a named check: it passes when the condition is 'True'. Otherwise
-- it fails with the reason @"fail"@.
assert :: String
       -> Bool
       -> Session ()
assert assName ok
  | ok        = passTest assName
  | otherwise = failTest assName "fail"
-- | Record one check outcome in the session state. New outcomes are consed
-- onto the front, so 'sessionResults' holds them newest-first.
addTestResult :: (String, Maybe String) -> Session ()
addTestResult p =
  State.modify $ \s -> s { sessionResults = p : sessionResults s }

-- | Record a passing check named @tstNm@.
passTest :: String -> Session ()
passTest tstNm = addTestResult (tstNm, Nothing)

-- | Record a failing check named @tstNm@ with a human-readable @reason@.
failTest :: String -> String -> Session ()
failTest tstNm reason = addTestResult (tstNm, Just reason)
-- | Perform a GET of @url@ on the (possibly reused) curl handle @h@.
--
-- Extra @opts@ are applied last, so they override the defaults set here.
-- The body chunks delivered to the write callback are collected and joined
-- in arrival order. Returns the curl result code and the body.
curlGetString :: Curl -> URLString
              -> [CurlOption]
              -> IO (CurlCode, String)
curlGetString h url opts = do
  ref <- newIORef []
  -- The handle may have been used for a POST earlier in the session: clear
  -- the POST state, then force the request method back to GET explicitly.
  -- libcurl documents CURLOPT_HTTPGET for exactly this.
  setopt h (CurlPostFields [])
  setopt h (CurlPost False)
  setopt h (CurlHttpGet True)
  setopt h (CurlFailOnError True)
  setDefaultSSLOpts h url
  setopt h (CurlURL url)
  setopt h (CurlWriteFunction (gatherOutput ref))
  mapM_ (setopt h) opts
  rc <- perform h
  -- Chunks are accumulated newest-first, hence the reverse.
  lss <- readIORef ref
  return (rc, concat (reverse lss))
-- | POST @fields@ to @url@ on the (possibly reused) curl handle @h@.
--
-- Each field is sent as @name=value@, with both halves encoded by
-- 'formUrlEncode'. Extra @opts@ are applied last. The body chunks delivered
-- to the write callback are collected and joined in arrival order.
-- Returns the curl result code and the body.
curlPostString :: Curl -> URLString -> [CurlOption] -> [(String, String)] -> IO (CurlCode, String)
curlPostString h url opts fields = do
  ref <- newIORef []
  setopt h (CurlFollowLocation True)
  setopt h (CurlFailOnError True)
  setopt h (CurlPost True)
  setopt h (CurlPostFields encodedFields)
  setDefaultSSLOpts h url
  setopt h (CurlURL url)
  setopt h (CurlWriteFunction (gatherOutput ref))
  mapM_ (setopt h) opts
  rc <- perform h
  -- Chunks are accumulated newest-first, hence the reverse.
  chunks <- readIORef ref
  return (rc, concat (reverse chunks))
  where
    -- libcurl sends POSTFIELDS as application/x-www-form-urlencoded without
    -- escaping them. A raw '&', '=', '+', space or non-ASCII character in a
    -- name or value would otherwise corrupt the request body.
    encodedFields = [formUrlEncode k ++ '=' : formUrlEncode v | (k, v) <- fields]

-- | Percent-encode a string for an @application/x-www-form-urlencoded@ body.
-- The string is encoded as UTF-8 first. ASCII alphanumerics and @*-._@ are
-- kept as-is, space becomes @+@, and every other byte becomes @%XX@
-- (upper-case hex).
formUrlEncode :: String -> String
formUrlEncode = concatMap escapeByte . BS.unpack . encodeUtf8 . T.pack
  where
    escapeByte w
      | isAscii c && isAlphaNum c = [c]
      | c `elem` "*-._"           = [c]
      | c == ' '                  = "+"
      | otherwise                 = ['%', hexDigit (n `div` 16), hexDigit (n `mod` 16)]
      where
        n = fromIntegral w :: Int
        c = chr n
    hexDigit = toUpper . intToDigit
-- | Unescape JSON-style backslash escapes and a few HTML entities.
--
-- Recognised escapes:
--
-- * @\\uXXXX@, including a high/low UTF-16 surrogate pair written as two
--   consecutive @\\u@ escapes
-- * @\\n@, @\\t@, @\\r@, @\\\"@, @\\\\@ and @\\/@
-- * the entities @&quot;@, @&gt;@, @&lt;@ and @&amp;@
--
-- Any other character is copied through unchanged.
decode :: String -> String
decode [] = []
decode ('\\':'u':a:b:c:d:'\\':'u':e:f:g:h:xs)
  | all isHexDigit [a, b, c, d, e, f, g, h]
  , isHighSurrogate hi
  , isLowSurrogate lo
  = chr (0x10000 + (hi - 0xD800) * 0x400 + (lo - 0xDC00)) : decode xs
  where
    -- Only evaluated after the hex-digit guard has succeeded.
    hi = hexToInt [a, b, c, d]
    lo = hexToInt [e, f, g, h]
    isHighSurrogate n = n >= 0xD800 && n <= 0xDBFF
    isLowSurrogate n = n >= 0xDC00 && n <= 0xDFFF
decode ('\\':'u':a:b:c:d:xs)
  | all isHexDigit [a, b, c, d]
  = chr (hexToInt [a, b, c, d]) : decode xs
decode ('\\':'n':xs) = '\n' : decode xs
decode ('\\':'t':xs) = '\t' : decode xs
decode ('\\':'r':xs) = '\r' : decode xs
decode ('\\':'"':xs) = '"' : decode xs
-- An escaped backslash must be consumed as a unit; otherwise "\\n" would be
-- misread as a backslash followed by a newline escape.
decode ('\\':'\\':xs) = '\\' : decode xs
decode ('\\':'/':xs) = '/' : decode xs
decode ('&':'q':'u':'o':'t':';':xs) = '"' : decode xs
decode ('&':'g':'t':';':xs) = '>' : decode xs
decode ('&':'l':'t':';':xs) = '<' : decode xs
decode ('&':'a':'m':'p':';':xs) = '&' : decode xs
decode (x : xs) = x : decode xs

-- | Numeric value of a string of hexadecimal digits, most significant digit
-- first. The empty string yields 0. Every character must satisfy
-- 'isHexDigit', because 'digitToInt' errors otherwise.
hexToInt :: String -> Int
hexToInt = foldl' (\acc digit -> acc * 16 + digitToInt digit) 0