module Test.Framework.HUnitWrapper (
HU.Assertion,
assertBool_, assertEqual_, assertEqualNoShow_, assertNotNull_, assertNull_,
assertSetEqual_, assertThrows_,
assertFailure,
HU.Test(..), runTestTT,
joinPathElems, showPathLabel
) where
import System.IO ( stderr )
import Data.List ( (\\) )
import Control.Exception
import Data.Version
import qualified Test.HUnit as HU
import Test.Framework.Location
import Test.Framework.Configuration
-- | Signal an assertion failure carrying the message @msg@.
--
-- When 'ghcVersion' is at or below 6.4.1 the failure is raised with
-- 'error', using a message prefixed by @"HUnit:"@. Otherwise it is
-- delegated to HUnit's own 'HU.assertFailure'.
-- NOTE(review): presumably the special case works around an HUnit/GHC
-- incompatibility in those old compilers — confirm before removing.
assertFailure :: String -> IO ()
assertFailure msg
  | ghcVersion <= lastBuggyGhc = error (prefix ++ msg)
  | otherwise = HU.assertFailure msg
  where
    prefix = "HUnit:"
    lastBuggyGhc = Version [6,4,1] []
-- | Succeed when the condition holds. Otherwise fail with a message naming
-- the source location @loc@.
assertBool_ :: Location -> Bool -> HU.Assertion
assertBool_ loc ok
  | ok = return ()
  | otherwise = assertFailure ("assert failed at " ++ showLoc loc)
-- | Fail, citing @loc@, when @expected /= actual@. The failure message
-- shows both values using their 'Show' instances.
assertEqual_ :: (Eq a, Show a) => Location -> a -> a -> HU.Assertion
assertEqual_ loc expected actual
  | expected /= actual = assertFailure report
  | otherwise = return ()
  where
    report = concat
      [ "assertEqual failed at ", showLoc loc
      , "\n expected: ", show expected
      , "\n but got: ", show actual
      ]
-- | Fail, citing @loc@, when @expected /= actual@. Only an 'Eq' instance
-- is required, so the failure message contains the location but not the
-- compared values.
assertEqualNoShow_ :: Eq a => Location -> a -> a -> HU.Assertion
assertEqualNoShow_ loc expected actual
  | expected /= actual =
      assertFailure ("assertEqualNoShow failed at " ++ showLoc loc)
  | otherwise = return ()
-- | Fail, citing @loc@, unless both lists contain the same elements with
-- the same multiplicities, in any order.
--
-- The lengths are compared first, and a mismatch is reported as such.
-- With equal lengths, equality is checked by requiring both list
-- differences ('\\') to be empty. This needs only 'Eq' but is quadratic
-- in the list length.
assertSetEqual_ :: (Eq a, Show a) => Location -> [a] -> [a] -> HU.Assertion
assertSetEqual_ loc expected actual
  | lenExpected /= lenActual =
      failWith [ "\n expected length: " ++ show lenExpected
               , "\n actual length: " ++ show lenActual
               ]
  | not sameElems =
      failWith [ "\n expected: " ++ show expected
               , "\n actual: " ++ show actual
               ]
  | otherwise = return ()
  where
    lenExpected = length expected
    lenActual = length actual
    sameElems = null (expected \\ actual) && null (actual \\ expected)
    -- Prepend the common header to the case-specific detail lines.
    failWith details =
      assertFailure
        (concat (("assertSetEqual failed at " ++ showLoc loc) : details))
-- | Fail, citing @loc@, if the given list is empty.
assertNotNull_ :: Location -> [a] -> HU.Assertion
assertNotNull_ loc xs
  | null xs = assertFailure ("assertNotNull failed at " ++ showLoc loc)
  | otherwise = return ()
-- | Fail, citing @loc@, if the given list has at least one element.
assertNull_ :: Location -> [a] -> HU.Assertion
assertNull_ loc xs
  | null xs = return ()
  | otherwise = assertFailure ("assertNull failed at " ++ showLoc loc)
-- | Run @io@ and require that it throws an exception accepted by @f@.
--
-- The assertion fails, citing @loc@, when:
--
-- * @io@ completes without throwing, or
-- * it throws an exception that @f@ rejects; that exception is 'show'n
--   in the message.
--
-- Only exceptions raised while running @io@ are caught. The action's
-- result value is not forced.
-- NOTE(review): 'Exception' is used as a concrete type here, which
-- matches the pre-extensible-exceptions Control.Exception API — confirm
-- the targeted base version.
assertThrows_ :: Location -> IO a -> (Exception -> Bool) -> HU.Assertion
assertThrows_ loc io f = try io >>= either onThrown (const onReturned)
  where
    header = "assertThrows failed at " ++ showLoc loc
    onReturned = assertFailure (header ++ ": no exception was thrown")
    onThrown e
      | f e = return ()
      | otherwise =
          assertFailure (header ++ ": wrong exception was thrown: " ++ show e)
-- | Variant of HUnit's @runTestText@ that first prints the path of every
-- test case under an "All tests:" header.
--
-- It then runs the tests with 'HU.performTest', reporting progress counts
-- (non-persistent) and each error/failure (persistent, prefixed by
-- @"### "@). It finishes with the final counts. It returns those counts
-- together with the final 'HU.PutText' state.
--
-- The state returned by every @put@ is threaded into the next step.
-- Earlier, the state produced by writing the header was dropped, which
-- loses the header for a 'HU.PutText' whose state carries its output.
runTestText :: HU.PutText st -> HU.Test -> IO (HU.Counts, st)
runTestText (HU.PutText put us0) t = do
  us1 <- put allTestsStr True us0
  (counts, us2) <- HU.performTest reportStart reportError reportFailure us1 t
  us3 <- put (HU.showCounts counts) True us2
  return (counts, us3)
  where
    allTestsStr =
      unlines ("All tests:" : map (\p -> " " ++ showPath p) (HU.testCasePaths t))
    reportStart ss us = put (HU.showCounts (HU.counts ss)) False us
    reportError = reportProblem "Error:" "Error in: "
    reportFailure = reportProblem "Failure:" "Failure in: "
    -- p0 is used when the failing test has an empty rendered path,
    -- p1 (followed by the path) otherwise.
    reportProblem p0 p1 msg ss us = put line True us
      where
        line = "### " ++ kind ++ path' ++ '\n' : msg ++ "\n"
        kind = if null path' then p0 else p1
        path' = showPath (HU.path ss)
-- | Render an HUnit test path as a colon-separated string.
--
-- The path is reversed before rendering. HUnit stores nodes from the test
-- case up to the root, so the output reads from the root down.
-- An index node directly followed (after reversal) by a label node is
-- dropped, so labelled tests are named by label rather than by position.
-- Index nodes are 'show'n and labels go through 'showPathLabel'.
-- An empty path renders as the empty string.
showPath :: HU.Path -> String
showPath nodes =
  case collapse (reverse nodes) of
    [] -> ""
    parts -> foldr1 joinPathElems (map render parts)
  where
    render node = case node of
      HU.ListItem n -> show n
      HU.Label label -> showPathLabel label
    collapse xs = case xs of
      HU.ListItem _ : lbl@(HU.Label _) : more -> lbl : collapse more
      x : more -> x : collapse more
      [] -> []
-- | Join two rendered path elements with a single colon.
joinPathElems :: String -> String -> String
joinPathElems prefix suffix = prefix ++ ':' : suffix
-- | Render a test label for use inside a path.
--
-- The label is emitted verbatim when that is unambiguous. It is emitted in
-- its 'show'n (quoted, escaped) form when either:
--
-- * it contains a colon, which is the path separator, or
-- * 'show' would escape some of its characters.
showPathLabel :: String -> String
showPathLabel s
  | ':' `elem` s = quoted
  | quoted /= '"' : s ++ "\"" = quoted
  | otherwise = s
  where
    quoted = show s
-- | Run the tests, writing the report to standard error without progress
-- output, and return the final counts.
runTestTT :: HU.Test -> IO HU.Counts
runTestTT t = fmap fst (runTestText (HU.putTextToHandle stderr False) t)