HUnitBase.lhs -- basic definitions -- to be found assertEqual -- to be found ListAssertable -- to be found AssertionPredicable -- to be found (@?) -- to be found (@=?) -- to be found (@?=) -- to be found Path -- to be found testCaseCount -- to be found Testable -- to be found (~?) -- to be found (~=?) -- to be found (~?=) -- to be found (~:) -- to be found State -- to be found ReportProblem -- to be found testCasePaths -- to be found performTest -- to be found Test -- to be found Assertable -- to be found ListAssertable -- to be found AssertionPredicate -- to be found testCasePaths -- to be found performTest > module Test.HUnit.Base > ( > {- from Test.HUnit.Lang: -} Assertion, assertFailure, > assertString, assertBool, assertEqual, > Assertable(ftp://ftp.videolan.org/pub/videolan/x264/snapshots/x264-snapshot-20080519-2245.tar.bz2..), ListAssertable(..), > AssertionPredicate, AssertionPredicable(..), > (@?), (@=?), (@?=), > Test(..), Node(..), Path, > testCaseCount, > Testable(..), > (~?), (~=?), (~?=), (~:), > Counts(..), State(..), > ReportStart, ReportProblem, > testCasePaths, > performTest > ) > where > import Control.Monad (unless, foldM) Assertion Definition ==================== > import Test.HUnit.Lang Conditional Assertion Functions ------------------------------- > assertBool :: String -> Bool -> Assertion > assertBool msg b = unless b (assertFailure msg) > assertString :: String -> Assertion > assertString s = unless (null s) (assertFailure s) > assertEqual :: (Eq a, Show a) => String -> a -> a -> Assertion > assertEqual preface expected actual = > unless (actual == expected) (assertFailure msg) > where msg = (if null preface then "" else preface ++ "\n") ++ > "expected: " ++ show expected ++ "\n but got: " ++ show actual Overloaded `assert` Function ---------------------------- > class Assertable t > where assert :: t -> Assertion > instance Assertable () > where assert = return > instance Assertable Bool > where assert = assertBool "" > instance (ListAssertable t) => Assertable [t] > where assert = listAssert > instance (Assertable t) => Assertable (IO t) > where assert = (>>= assert) We define the assertability of `[Char]` (that is, `String`) and leave other types of list to possible user extension. > class ListAssertable t > where listAssert :: [t] -> Assertion > instance ListAssertable Char > where listAssert = assertString Overloaded `assertionPredicate` Function ---------------------------------------- > type AssertionPredicate = IO Bool > class AssertionPredicable t > where assertionPredicate :: t -> AssertionPredicate > instance AssertionPredicable Bool > where assertionPredicate = return > instance (AssertionPredicable t) => AssertionPredicable (IO t) > where assertionPredicate = (>>= assertionPredicate) Assertion Construction Operators -------------------------------- > infix 1 @?, @=?, @?= > (@?) :: (AssertionPredicable t) => t -> String -> Assertion > pred @? msg = assertionPredicate pred >>= assertBool msg > (@=?) :: (Eq a, Show a) => a -> a -> Assertion > expected @=? actual = assertEqual "" expected actual > (@?=) :: (Eq a, Show a) => a -> a -> Assertion > actual @?= expected = assertEqual "" expected actual Test Definition =============== > data Test = TestCase Assertion > | TestList [Test] > | TestLabel String Test > instance Show Test where > showsPrec p (TestCase _) = showString "TestCase _" > showsPrec p (TestList ts) = showString "TestList " . showList ts > showsPrec p (TestLabel l t) = showString "TestLabel " . showString l > . showChar ' ' . showsPrec p t > testCaseCount :: Test -> Int > testCaseCount (TestCase _) = 1 > testCaseCount (TestList ts) = sum (map testCaseCount ts) > testCaseCount (TestLabel _ t) = testCaseCount t > data Node = ListItem Int | Label String > deriving (Eq, Show, Read) > type Path = [Node] -- Node order is from test case to root. > testCasePaths :: Test -> [Path] > testCasePaths t = tcp t [] > where tcp (TestCase _) p = [p] > tcp (TestList ts) p = > concat [ tcp t (ListItem n : p) | (t,n) <- zip ts [0..] ] > tcp (TestLabel l t) p = tcp t (Label l : p) Overloaded `test` Function -------------------------- > class Testable t > where test :: t -> Test > instance Testable Test > where test = id > instance (Assertable t) => Testable (IO t) > where test = TestCase . assert > instance (Testable t) => Testable [t] > where test = TestList . map test Test Construction Operators --------------------------- > infix 1 ~?, ~=?, ~?= > infixr 0 ~: > (~?) :: (AssertionPredicable t) => t -> String -> Test > pred ~? msg = TestCase (pred @? msg) > (~=?) :: (Eq a, Show a) => a -> a -> Test > expected ~=? actual = TestCase (expected @=? actual) > (~?=) :: (Eq a, Show a) => a -> a -> Test > actual ~?= expected = TestCase (actual @?= expected) > (~:) :: (Testable t) => String -> t -> Test > label ~: t = TestLabel label (test t) Test Execution ============== > data Counts = Counts { cases, tried, errors, failures :: Int } > deriving (Eq, Show, Read) > data State = State { path :: Path, counts :: Counts } > deriving (Eq, Show, Read) > type ReportStart us = State -> us -> IO us > type ReportProblem us = String -> State -> us -> IO us Note that the counts in a start report do not include the test case being started, whereas the counts in a problem report do include the test case just finished. The principle is that the counts are sampled only between test case executions. As a result, the number of test case successes always equals the difference of test cases tried and the sum of test case errors and failures. > performTest :: ReportStart us -> ReportProblem us -> ReportProblem us > -> us -> Test -> IO (Counts, us) > performTest reportStart reportError reportFailure us t = do > (ss', us') <- pt initState us t > unless (null (path ss')) $ error "performTest: Final path is nonnull" > return (counts ss', us') > where > initState = State{ path = [], counts = initCounts } > initCounts = Counts{ cases = testCaseCount t, tried = 0, > errors = 0, failures = 0} > pt ss us (TestCase a) = do > us' <- reportStart ss us > r <- performTestCase a > case r of Nothing -> do return (ss', us') > Just (True, m) -> do usF <- reportFailure m ssF us' > return (ssF, usF) > Just (False, m) -> do usE <- reportError m ssE us' > return (ssE, usE) > where c@Counts{ tried = t } = counts ss > ss' = ss{ counts = c{ tried = t + 1 } } > ssF = ss{ counts = c{ tried = t + 1, failures = failures c + 1 } } > ssE = ss{ counts = c{ tried = t + 1, errors = errors c + 1 } } > pt ss us (TestList ts) = foldM f (ss, us) (zip ts [0..]) > where f (ss, us) (t, n) = withNode (ListItem n) ss us t > pt ss us (TestLabel label t) = withNode (Label label) ss us t > withNode node ss0 us0 t = do (ss2, us1) <- pt ss1 us0 t > return (ss2{ path = path0 }, us1) > where path0 = path ss0 > ss1 = ss0{ path = node : path0 }