module HyLo.Test ( ModuleName, TestName, TestCase, TestResult(..), UnitTest,
                   runTest, runTestWith, Config(..), defaultConfig,
                   testSuite, testModule, stopOnError, runCompletely )

where

import Test.QuickCheck

import Data.List     ( sort, group, intercalate )
import Text.Printf   ( printf )
import System.Random ( StdGen, newStdGen, split )
import HyLo.Util     ( sequenceUntil )

-- | Name of the module a group of tests belongs to; printed as a
-- "Testing module" header before that module's first test.
type ModuleName = String
-- | Name of a single test; used as its printed, dot-padded label.
type TestName   = String
-- | Column width used to pad test labels and to indent the label table
-- printed after a run.
type Align      = Int
-- | A runnable test: given its name and an alignment width it runs, reports
-- its outcome on stdout (as the cases built by 'runTest' do) and returns a
-- 'TestResult'.
type TestCase   = TestName -> Align -> IO TestResult
-- | The named tests of one module, in execution order.
type UnitTest   = [(TestName, TestCase)]


-- | Outcome of running a single property.
data TestResult
  = OK         -- ^ the configured number of tests all passed
  | EXHAUSTED  -- ^ too many generated arguments were discarded
  | FAILED     -- ^ a falsifying set of arguments was found
  deriving (Eq, Show, Enum)

-- | Turn a property into a 'TestCase' that is checked with QuickCheck's
-- 'defaultConfig'.
runTest :: Testable a => a -> TestCase
runTest prop = runTestWith defaultConfig prop

-- | Like 'runTest', but checks the property with an explicit QuickCheck
-- 'Config' (number of tests, discard limit, size function, ...).
runTestWith :: Testable a => Config -> a -> TestCase
runTestWith = mycheck

-- | Build the actions for a whole suite of modules. All labels share one
-- width derived from the longest test name across every module. The first
-- test of each module additionally prints a "Testing module" header.
--
-- NOTE(review): the width passed to 'mkTestPlan' already includes +17 and
-- 'mkTestPlan' adds another 17; confirm the double offset is intended.
testSuite :: [(ModuleName, UnitTest)] -> [IO TestResult]
testSuite test_suite =
    concat [ mkTestPlan width (prependModName m ut) | (m, ut) <- test_suite ]
  where
    -- Only demanded by a running test action, so whenever it is evaluated
    -- the list is non-empty and 'maximum' is safe.
    longest = maximum [ length name | (_, ut) <- test_suite, (name, _) <- ut ]
    width   = longest + 17

-- | Build the actions for a single module's tests, padding every label to a
-- width derived from the longest test name (no module header is printed).
--
-- NOTE(review): 'mkTestPlan' adds a further 17 on top of the +17 here.
testModule :: UnitTest -> [IO TestResult]
testModule unit_test = mkTestPlan (longest + 17) unit_test
  where
    -- Lazily evaluated: only forced when some test runs, i.e. never on [].
    longest = maximum [ length name | (name, _) <- unit_test ]

-- | Run the planned actions in order, stopping at the first result that is
-- not 'OK'. The exact stopping semantics (e.g. whether the failing result is
-- included) come from 'sequenceUntil' in HyLo.Util — confirm there.
stopOnError :: ([a] -> [IO TestResult]) -> [a] -> IO [TestResult]
stopOnError plan inputs = sequenceUntil (/= OK) (plan inputs)

-- | Run every planned action in order, whatever the individual outcomes,
-- and collect all of the results.
runCompletely :: ([a] -> [IO TestResult]) -> [a] -> IO [TestResult]
runCompletely plan = sequence . plan

-- | Turn named tests into runnable actions. Each action is applied to its
-- own name and to an alignment of @width + 17@.
mkTestPlan :: Align -> UnitTest -> [IO TestResult]
mkTestPlan width unit_test =
  [ run name (width + 17) | (name, run) <- unit_test ]

-- | Wrap the first test of a module so that it prints a
-- "Testing module <name>" header before running. The remaining tests are
-- returned untouched; an empty list stays empty.
prependModName :: ModuleName -> UnitTest -> UnitTest
prependModName modName unit_test =
  case unit_test of
    []                -> []
    (name, run) : rest -> (name, withHeader run) : rest
  where
    withHeader inner testName align =
      printf "\nTesting module %s\n" modName >> inner testName align


-- The code below is adapted from Test.QuickCheck so that it returns the
-- outcome as a 'TestResult' and formats its output for this harness.

-- | Check property @a@ under @config@. First prints the test name padded
-- with dots to @align + 10@ characters (after two leading spaces), then
-- delegates to 'tests', which prints and returns the outcome.
mycheck :: Testable a => Config -> a -> TestCase
mycheck config a testName align = do
  let padded = take (align + 10) (testName ++ repeat '.')
  gen0 <- newStdGen
  putStr ("  " ++ padded)
  tests config (evaluate a) gen0 0 0 align []

-- | Main QuickCheck loop. It generates one argument set per iteration and
-- stops when one of the following happens:
--
-- * @configMaxTest@ tests have passed, returning 'OK';
-- * @configMaxFail@ argument sets have been discarded, returning 'EXHAUSTED';
-- * a property evaluates to false, which prints the falsifying arguments
--   and returns 'FAILED'.
--
-- The signature previously labelled the discard counter as 'Align' and the
-- indentation as 'Int'. Both are 'Int' synonyms, so only the labels changed.
tests :: Config
      -> Gen Result
      -> StdGen
      -> Int         -- ^ number of tests that have passed so far
      -> Int         -- ^ number of discarded argument sets so far
      -> Align       -- ^ indentation used for the final label table
      -> [[String]]  -- ^ stamps collected from passing tests
      -> IO TestResult
tests config gen rnd0 ntest nfail indent stamps
  | ntest == configMaxTest config = do
      done "OK, passed" ntest indent stamps
      return OK
  | nfail == configMaxFail config = do
      done "Arguments exhausted after" ntest indent stamps
      return EXHAUSTED
  | otherwise = do
      putStr (configEvery config ntest (arguments result))
      case ok result of
        -- Precondition not met: discard and try a fresh argument set.
        Nothing    ->
          tests config gen rnd1 ntest (nfail + 1) indent stamps
        Just True  ->
          tests config gen rnd1 (ntest + 1) nfail indent (stamp result : stamps)
        Just False -> do
          putStr ( "Falsifiable, after "
                ++ show ntest
                ++ " tests:\n"
                ++ unlines (arguments result)
                 )
          return FAILED
  where
    -- rnd2 drives this iteration's generator; rnd1 is threaded to the next.
    result      = generate (configSize config ntest) rnd2 gen
    (rnd1,rnd2) = split rnd0

-- | Print the closing line of a run. The line contains @mesg@ and the number
-- of passed tests. When non-empty stamps were collected, it is followed by
-- their frequency as a percentage of @ntest@, most frequent first.
--
-- A single entry is printed inline in parentheses. Several entries are
-- printed one per line, each indented by @indent@ spaces.
done :: String -> Int -> Int -> [[String]] -> IO ()
done mesg ntest indent stamps =
    putStr (mesg ++ " " ++ show ntest ++ " tests" ++ render rows)
  where
    -- Identical non-empty stamps are counted together, biggest count first.
    -- 'group' never yields an empty group, so the pattern below always matches.
    rows :: [(Int, [String])]
    rows    = reverse (sort counted)
    counted = [ (length grp, labels)
              | grp@(labels : _) <- group (sort (filter (not . null) stamps)) ]

    describe (count, labels) =
      show ((100 * count) `div` ntest) ++ "%" ++ " " ++ intercalate ", " labels

    render entries =
      case map describe entries of
        []      -> ".\n"
        [only]  -> " (" ++ only ++ ").\n"
        several ->
          ".\n" ++ concatMap (\s -> replicate indent ' ' ++ s ++ ".\n") several