{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE PackageImports #-}
--
-- Copyright (c) 2009-2014 Stefan Wehr - http://www.stefanwehr.de
--
-- This library is free software; you can redistribute it and/or
-- modify it under the terms of the GNU Lesser General Public
-- License as published by the Free Software Foundation; either
-- version 2.1 of the License, or (at your option) any later version.
--
-- This library is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-- Lesser General Public License for more details.
--
-- You should have received a copy of the GNU Lesser General Public
-- License along with this library; if not, write to the Free Software
-- Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA
--

module Test.Framework.Preprocessor ( transform, progName, preprocessorTests ) where

-- import Debug.Trace
import Control.Monad
import Data.Char
import "haskell-lexer" Language.Haskell.Lexer
import Language.Preprocessor.Cpphs ( runCpphs,
                                     CpphsOptions(..),
                                     BoolOptions(..),
                                     defaultCpphsOptions)
import System.IO ( hPutStrLn, stderr )
import Test.HUnit hiding (State)
import Control.Monad.State.Strict
import qualified Data.List as List
import Data.Maybe

import Test.Framework.Location

_DEBUG_ :: Bool
_DEBUG_ = False

-- Name of the preprocessor executable.
progName :: String
progName = "htfpp"

-- Name of the module whose import determines the prefix used in
-- generated code.
htfModule :: String
htfModule = "Test.Framework"

-- Builds the module-specific name for a generated top-level variable:
-- "htf_" followed by the module name and the variable name (without a
-- leading "htf_"), with every '.' replaced by '_'.
mkName :: String -> String -> String
mkName varName fullModuleName =
    "htf_" ++ map dotToUnderscore (fullModuleName ++ "." ++ shortVarName)
    where
      dotToUnderscore c = if c == '.' then '_' else c
      shortVarName =
          case varName of
            'h':'t':'f':'_':s -> s
            s -> s

thisModulesTestsFullName :: String -> String
thisModulesTestsFullName = mkName thisModulesTestsName

importedTestListFullName :: String -> String
importedTestListFullName = mkName importedTestListName

thisModulesTestsName :: String
thisModulesTestsName = "htf_thisModulesTests"

importedTestListName :: String
importedTestListName = "htf_importedTests"

-- Macro definitions mapping the generic names htf_thisModulesTests and
-- htf_importedTests to the module-specific names.
nameDefines :: ModuleInfo -> [(String, String)]
nameDefines info =
    [(thisModulesTestsName, thisModulesTestsFullName modName),
     (importedTestListName, importedTestListFullName modName)]
    where
      modName = mi_moduleNameWithDefault info

-- Names of all assertions for which macros are defined. Note that
-- withGs (which adds a 'g'-prefixed variant of every name) is applied
-- to the first list only; the assertThrows* names get no 'g' variant.
allAsserts :: [String]
allAsserts =
    withGs ["assertBool"
           ,"assertEqual"
           ,"assertEqualPretty"
           ,"assertEqualNoShow"
           ,"assertNotEqual"
           ,"assertNotEqualPretty"
           ,"assertNotEqualNoShow"
           ,"assertListsEqualAsSets"
           ,"assertElem"
           ,"assertEmpty"
           ,"assertNotEmpty"
           ,"assertLeft"
           ,"assertLeftNoShow"
           ,"assertRight"
           ,"assertRightNoShow"
           ,"assertJust"
           ,"assertNothing"
           ,"assertNothingNoShow"
           ,"subAssert"
           ,"subAssertVerbose"
           ] ++
    ["assertThrows"
    ,"assertThrowsSome"
    ,"assertThrowsIO"
    ,"assertThrowsSomeIO"
    ,"assertThrowsM"
    ,"assertThrowsSomeM"]
    where
      withGs l = concatMap (\s -> [s, 'g':s]) l

-- Macro definitions for all assertions. Every macro expands to the
-- prefixed assertion function applied to the location of the call site.
-- The first argument selects which of the two macro names (plain or
-- suffixed with "HTF" / "Verbose") maps to the "Verbose_" variant.
assertDefines :: Bool -> String -> [(String, String)]
assertDefines hunitBackwardsCompat prefix =
    concatMap fun allAsserts ++ [("assertFailure", expansion "assertFailure" "_")]
    where
      fun a =
          if hunitBackwardsCompat
             then [(a, expansion a "Verbose_"), (a ++ "HTF", expansion a "_")]
             else [(a, expansion a "_"), (a ++ "Verbose", expansion a "Verbose_")]
      expansion a suffix =
          "(" ++ prefix ++ a ++ suffix ++ " (" ++ prefix ++ "makeLoc __FILE__ __LINE__))"

-- Everything the preprocessor learns about a module by scanning its tokens.
data ModuleInfo = ModuleInfo { mi_htfPrefix  :: String        -- prefix for names from Test.Framework ("" or ending in '.')
                             , mi_htfImports :: [ImportDecl]  -- imports marked with {-@ HTF_TESTS @-}
                             , mi_defs       :: [Definition]  -- test_ and prop_ definitions
                             , mi_moduleName :: Maybe String }
                  deriving (Show, Eq)

mi_moduleNameWithDefault :: ModuleInfo -> String
mi_moduleNameWithDefault = fromMaybe "Main" . mi_moduleName

data ImportDecl = ImportDecl { imp_moduleName :: Name
                             , imp_qualified :: Bool
                             , imp_alias :: Maybe Name
                             , imp_loc :: Location }
                  deriving (Show, Eq)

-- A test definition: short name (without the test_ / prop_ prefix),
-- location, and full name.
data Definition = TestDef String Location String
                | PropDef String Location String
                  deriving (Eq, Show)

type Name = String

type PMA a = State ModuleInfo a

-- Records the module name; only the first name seen is kept.
setModName :: String -> PMA ()
setModName name =
    do oldName <- gets mi_moduleName
       when (isNothing oldName) $ modify $ \mi -> mi { mi_moduleName = Just name }

addTestDef :: String -> String -> Location -> PMA ()
addTestDef name fullName loc =
    modify $ \mi -> mi { mi_defs = (TestDef name loc fullName) : mi_defs mi }

addPropDef :: String -> String -> Location -> PMA ()
addPropDef name fullName loc =
    modify $ \mi -> mi { mi_defs = (PropDef name loc fullName) : mi_defs mi }

addHtfImport :: ImportDecl -> PMA ()
addHtfImport decl =
    modify $ \mi -> mi { mi_htfImports = decl : mi_htfImports mi }

setTestFrameworkImport :: String -> PMA ()
setTestFrameworkImport name =
    modify $ \mi -> mi { mi_htfPrefix = name }

-- Scans the token stream for the module name, for identifiers starting
-- in column 1 whose name begins with test_ or prop_, for imports marked
-- with the special import_HTF_TESTS token, and for an import of
-- Test.Framework (which determines mi_htfPrefix).
--
-- Definitions and imports are accumulated in reverse order and reversed
-- at the end. Duplicate definitions are removed on the reversed list, so
-- of several occurrences of the same name the last one in the source
-- is kept.
poorManAnalyzeTokens :: [LocToken] -> ModuleInfo
poorManAnalyzeTokens toks =
    -- show toks `trace`
    let revRes = execState (loop toks) initialInfo
    in revRes { mi_htfImports = reverse (mi_htfImports revRes)
              , mi_defs = reverse $ List.nubBy defEqByName (mi_defs revRes) }
    where
      initialInfo =
          ModuleInfo { mi_htfPrefix = htfModule ++ "."
                     , mi_htfImports = []
                     , mi_defs = []
                     , mi_moduleName = Nothing }
      defEqByName (TestDef n1 _ _) (TestDef n2 _ _) = n1 == n2
      defEqByName (PropDef n1 _ _) (PropDef n2 _ _) = n1 == n2
      defEqByName _ _ = False
      -- Text of a token that can denote a module name (plain or dotted).
      moduleNameTok :: LocToken -> Maybe String
      moduleNameTok (Conid, (_, n)) = Just n
      moduleNameTok (Qconid, (_, n)) = Just n
      moduleNameTok _ = Nothing
      loop :: [LocToken] -> PMA ()
      loop ts =
          case ts of
            (Reservedid, (_, "module")) : rest ->
                case rest of
                  t : rest2
                      | Just name <- moduleNameTok t ->
                          do setModName name
                             loop rest2
                  _ -> loop rest
            (Varid, (loc, name)) : rest
                | isStartOfLine loc ->
                    case name of
                      't':'e':'s':'t':'_':shortName ->
                          do addTestDef shortName name (locToLocation loc)
                             loop rest
                      'p':'r':'o':'p':'_':shortName ->
                          do addPropDef shortName name (locToLocation loc)
                             loop rest
                      _ -> loop rest
                | otherwise -> loop rest
            (Special, (loc, "import_HTF_TESTS")) : rest ->
                case parseImport loc rest of
                  Just (imp, rest2) ->
                      do addHtfImport imp
                         loop rest2
                  Nothing -> loop rest
            (Reservedid, (loc, "import")) : rest ->
                case parseImport loc rest of
                  Nothing -> loop rest
                  Just (imp, rest2) ->
                      do when (imp_moduleName imp == htfModule) $
                           -- Only a qualified import requires a prefix.
                           let prefix = case (imp_alias imp, imp_qualified imp) of
                                          (Just alias, True) -> alias
                                          (Nothing, True) -> imp_moduleName imp
                                          _ -> ""
                           in setTestFrameworkImport
                                  (if null prefix then prefix else prefix ++ ".")
                         loop rest2
            _ : rest -> loop rest
            [] -> return ()
      -- Parses "[qualified] ModName [as Alias]" from the front of the token
      -- list; returns the declaration and the remaining tokens. The alias
      -- may itself be a dotted module name (e.g. "as T.F").
      parseImport :: Loc -> [LocToken] -> Maybe (ImportDecl, [LocToken])
      parseImport loc ts =
          do let (qualified, afterQualified) =
                     case ts of
                       (Varid, (_, "qualified")) : rest -> (True, rest)
                       _ -> (False, ts)
             (name, afterName) <-
                 case afterQualified of
                   t : rest
                       | Just n <- moduleNameTok t -> Just (n, rest)
                   _ -> Nothing
             let (mAlias, remaining) =
                     case afterName of
                       (Varid, (_, "as")) : t : rest
                           | Just alias <- moduleNameTok t -> (Just alias, rest)
                       _ -> (Nothing, afterName)
                 decl = ImportDecl { imp_moduleName = name
                                   , imp_qualified = qualified
                                   , imp_alias = mAlias
                                   , imp_loc = locToLocation loc }
             return (decl, remaining)
      locToLocation :: Loc -> Location
      locToLocation loc = makeLoc (l_file loc) (l_line loc)
      isStartOfLine :: Loc -> Bool
      isStartOfLine loc = l_column loc == 1

-- Drops whitespace tokens and replaces the nested comment
-- '{-@ HTF_TESTS @-}' with a single token Special with value
-- "import_HTF_TESTS". All other tokens (including other comment tokens)
-- are passed through unchanged.
cleanupTokens :: [PosToken] -> [PosToken]
cleanupTokens toks =
    case toks of
      (Whitespace, _) : rest -> cleanupTokens rest
      (NestedComment, (loc, "{-@ HTF_TESTS @-}")) : rest ->
          (Special, (loc, "import_HTF_TESTS")) : cleanupTokens rest
      tok : rest -> tok : cleanupTokens rest
      [] -> []

-- char        ->  ' (graphic<' | \> | space | escape<\&>) '
-- graphic     ->  small | large | symbol | digit | special | : | " | '
-- escape      ->  \ ( charesc | ascii | decimal | o octal | x hexadecimal )
-- charesc     ->  a | b | f | n | r | t | v | \ | " | ' | &
-- ascii       ->  ^cntrl | NUL | SOH | STX | ETX | EOT | ENQ | ACK
--                 | BEL | BS | HT | LF | VT | FF | CR | SO | SI | DLE
--                 | DC1 | DC2 | DC3 | DC4 | NAK | SYN | ETB | CAN
--                 | EM | SUB | ESC | FS | GS | RS | US | SP | DEL
-- cntrl       ->  ascLarge | @ | [ | \ | ] | ^ | _
-- decimal     ->  digit{digit}
-- octal       ->  octit{octit}
-- hexadecimal ->  hexit{hexit}
-- octit       ->  0 | 1 | ... | 7
-- hexit       ->  digit | A | ... | F | a | ... | f
-- special     ->  ( | ) | , | ; | [ | ] | `| { | }

-- Purpose of cleanupInputString: filter out template Haskell quotes.
-- The result is only used for analysing the module (see 'analyze').
cleanupInputString :: String -> String
cleanupInputString s =
    case s of
      c:'\'':'\'':x:rest
          | isSpace c && isUpper x ->            -- TH type quote
              c : x : cleanupInputString rest
      c:'\'':'\'':d:rest
          | not (isAlphaNum c) || not (isAlphaNum d) ->
              -- Replace the two quotes with the character literal 'x' and
              -- keep cleaning the remaining input (starting at d).
              c : '\'' : 'x' : '\'' : cleanupInputString (d:rest)
      '\'':rest ->
          case characterLitRest rest of
            Just (restLit, rest') ->
                '\'' : restLit ++ cleanupInputString rest'
            Nothing ->
                '\'' : cleanupInputString rest
      c:'\'':x:rest                              -- TH name quote
          | isSpace c && isNothing (characterLitRest (x:rest)) && isLower x ->
              c : x : cleanupInputString rest
      c:rest
          | not (isSpace c) ->
              -- Quotes directly after a non-space character are copied as is.
              case span (== '\'') rest of
                (quotes, rest') -> c : quotes ++ cleanupInputString rest'
      c:rest -> c : cleanupInputString rest
      [] -> []
    where
      -- Expects that the character before the argument is a single quote.
      -- Returns the rest of the character literal (including the closing
      -- quote) and the input following it.
      characterLitRest :: String -> Maybe (String, String)
      characterLitRest str =
          case str of
            '\\':'\'':'\'':rest -> Just ("\\''", rest)  -- '\''
            c:'\'':rest -> Just (c:"'", rest)          -- regular character lit
            '\\':rest ->
                case span (/= '\'') rest of
                  (esc, '\'':rest') -> Just (('\\':esc) ++ "'", rest')
                  _ -> Just (str, "")                  -- should not happen
            _ -> Nothing

cleanupInputStringTest :: IO ()
cleanupInputStringTest =
    do forM_ untouched $ \s ->
           let cleanedUp = cleanupInputString s
           in when (s /= cleanedUp) $
                  assertFailure ("Cleanup of " ++ show s ++ " is wrong: " ++ show cleanedUp ++
                                 ", expected that input and output are the same")
       forM_ touched $ \(input, output) ->
           let cleanedUp = cleanupInputString input
           in when (output /= cleanedUp) $
                  assertFailure ("Cleanup of " ++ show input ++ " is wrong: " ++ show cleanedUp ++
                                 ", expected " ++ show output)
    where
      untouched = [" '0'", " '\\''", " ' '", " '\o761'", " '\BEL'", " '\^@' ", "' '", "fixed' ' '"]
      touched = [(" 'foo abc", " foo abc")
                ,(" ''T ", " T ")
                 -- cleanup must continue after two quotes were replaced
                ,("(''T) 'foo x", "('x'T) foo x")]

type LocToken = (Token,(Loc,String))

data Loc
    = Loc
      { l_file :: FilePath
      , l_line :: Int
      , l_column :: Int
      }
      deriving (Eq, Show)

-- Converts lexer positions into locations that take '#line' directives
-- into account. The tokens of a directive itself and all tokens before
-- the first directive are attributed to originalFileName; tokens after a
-- directive are attributed to the file and line given by the directive.
--
-- token stream should not contain whitespace
--
-- NOTE(review): the file name is taken verbatim from the text of the
-- StringLit token; if the lexer keeps the surrounding double quotes in
-- that text they end up in l_file -- confirm.
fixPositions :: FilePath -> [PosToken] -> [LocToken]
fixPositions originalFileName = loop Nothing
    where
      loop :: Maybe (Int, FilePath, Int) -> [PosToken] -> [LocToken]
      loop mPragma toks =
          case toks of
            [] -> []
            (Varsym, (pos, "#")) : (Varid, (_, "line")) : (IntLit, (_, lineNoText)) : (StringLit, (_, pragmaFile)) : rest
                | column pos == 1
                -- A line number that cannot be read makes this case fall
                -- through to the generic token case below.
                , [(lineNo, "")] <- reads lineNoText ->
                    map (\(tt, (p, x)) -> (tt, (fixPos Nothing p, x))) (take 4 toks) ++
                    loop (Just (line pos, pragmaFile, lineNo)) rest
            (tt, (pos, x)) : rest ->
                (tt, (fixPos mPragma pos, x)) : loop mPragma rest
      fixPos :: Maybe (Int, FilePath, Int) -> Pos -> Loc
      fixPos mPragma pos =
          case mPragma of
            Nothing ->
                Loc { l_column = column pos
                    , l_file = originalFileName
                    , l_line = line pos }
            Just (lineActivated, pragmaFile, lineNo) ->
                -- The line directly after the directive gets number lineNo.
                let offset = line pos - lineActivated - 1
                in Loc { l_column = column pos
                       , l_file = pragmaFile
                       , l_line = lineNo + offset }

fixPositionsTest :: IO ()
fixPositionsTest =
    let toks = concatMap (\(f, i) -> f i)
                   (zip [tok, linePragma "bar" 10, tok, tok, linePragma "foo" 99, tok] [1..])
        fixedToks = fixPositions origFileName toks
        expectedToks = concat $
                       [tok' origFileName 1
                       ,linePragma' "bar" 10 2
                       ,tok' "bar" 10
                       ,tok' "bar" 11
                       ,linePragma' "foo" 99 5
                       ,tok' "foo" 99]
    in -- The message shows the expected and the actual (fixed) tokens.
       assertEqual (show expectedToks ++ "\n\n /= \n\n" ++ show fixedToks) expectedToks fixedToks
    where
      origFileName = "spam"
      tok lineHere = [(Varid, (Pos 0 lineHere 1, "_"))]
      linePragma fname lineNo lineHere =
          let pos = Pos 0 lineHere 1
          in [(Varsym, (pos, "#"))
             ,(Varid, (pos, "line"))
             ,(IntLit, (pos, show lineNo))
             ,(StringLit, (pos, fname))]
      tok' fname lineHere =
          let loc = Loc fname lineHere 1
          in [(Varid, (loc, "_"))]
      linePragma' fname lineNo lineHere =
          let loc = Loc origFileName lineHere 1
          in [(Varsym, (loc, "#"))
             ,(Varid, (loc, "line"))
             ,(IntLit, (loc, show lineNo))
             ,(StringLit,(loc, fname))]

-- Lexes the input (after filtering out template Haskell quotes) and
-- extracts the module information from the resulting tokens.
analyze :: FilePath -> String -> ModuleInfo
analyze originalFileName input =
    poorManAnalyzeTokens (fixPositions originalFileName (cleanupTokens (lexerPass0 (cleanupInputString input))))

analyzeTests :: [(String, ModuleInfo)]
analyzeTests =
    [(unlines ["module FOO where"
              ,"import Test.Framework"
              ,"import {-@ HTF_TESTS @-} qualified Foo as Bar"
              ,"import {-@ HTF_TESTS @-} qualified Foo.X as Egg"
              ,"import {-@ HTF_TESTS @-} Foo.Y as Spam"
              ,"import {-@ HTF_TESTS @-} Foo.Z"
              ,"import {-@ HTF_TESTS @-} Baz"
              ,"deriveSafeCopy 1 'base ''T"
              ,"$(deriveSafeCopy 2 'extension ''T)"
              ,"test_blub test_foo = 1"
              ,"test_blah test_foo = 1"
              ,"prop_abc prop_foo = 2"
              ,"prop_xyz = True"]
     ,ModuleInfo { mi_htfPrefix = ""
                 , mi_htfImports =
                     [ImportDecl { imp_moduleName = "Foo"
                                 , imp_qualified = True
                                 , imp_alias = Just "Bar"
                                 , imp_loc = makeLoc "" 3}
                     ,ImportDecl { imp_moduleName = "Foo.X"
                                 , imp_qualified = True
                                 , imp_alias = Just "Egg"
                                 , imp_loc = makeLoc "" 4}
                     ,ImportDecl { imp_moduleName = "Foo.Y"
                                 , imp_qualified = False
                                 , imp_alias = Just "Spam"
                                 , imp_loc = makeLoc "" 5}
                     ,ImportDecl { imp_moduleName = "Foo.Z"
                                 , imp_qualified = False
                                 , imp_alias = Nothing
                                 , imp_loc = makeLoc "" 6}
                     ,ImportDecl { imp_moduleName = "Baz"
                                 , imp_qualified = False
                                 , imp_alias = Nothing
                                 , imp_loc = makeLoc "" 7}]
                 , mi_moduleName = Just "FOO"
                 , mi_defs = [TestDef "blub" (makeLoc "" 10) "test_blub"
                             ,TestDef "blah" (makeLoc "" 11) "test_blah"
                             ,PropDef "abc" (makeLoc "" 12) "prop_abc"
                             ,PropDef "xyz" (makeLoc "" 13) "prop_xyz"]
                 })
    ,(unlines ["module Foo.Bar where"
              ,"import Test.Framework as Blub"
              ,"prop_xyz = True"]
     ,ModuleInfo { mi_htfPrefix = ""
                 , mi_htfImports = []
                 , mi_moduleName = Just "Foo.Bar"
                 , mi_defs = [PropDef "xyz" (makeLoc "" 3) "prop_xyz"]
                 })
    ,(unlines ["module Foo.Bar where"
              ,"import qualified Test.Framework as Blub"
              ,"prop_xyz = True"]
     ,ModuleInfo { mi_htfPrefix = "Blub."
                 , mi_htfImports = []
                 , mi_moduleName = Just "Foo.Bar"
                 , mi_defs = [PropDef "xyz" (makeLoc "" 3) "prop_xyz"]
                 })
    ,(unlines ["module Foo.Bar where"
              ,"import qualified Test.Framework"
              ,"prop_xyz = True"]
     ,ModuleInfo { mi_htfPrefix = "Test.Framework."
                 , mi_htfImports = []
                 , mi_moduleName = Just "Foo.Bar"
                 , mi_defs = [PropDef "xyz" (makeLoc "" 3) "prop_xyz"]
                 })
     -- the alias of an import may be a dotted module name
    ,(unlines ["module Foo.Bar where"
              ,"import qualified Test.Framework as T.F"
              ,"prop_xyz = True"]
     ,ModuleInfo { mi_htfPrefix = "T.F."
                 , mi_htfImports = []
                 , mi_moduleName = Just "Foo.Bar"
                 , mi_defs = [PropDef "xyz" (makeLoc "" 3) "prop_xyz"]
                 })]

testAnalyze :: IO ()
testAnalyze =
    mapM_ runTest (zip [(1::Int)..] analyzeTests)
    where
      runTest (i, (src, mi)) =
          let givenMi = analyze "" src
          in when (givenMi /= mi) $
                 assertFailure ("Error in test " ++ show i ++
                                ", expected:\n" ++ show mi ++
                                "\nGiven:\n" ++ show givenMi ++
                                "\nSrc:\n" ++ src)

-- Runs the preprocessor on the given input: analyses the module, runs
-- cpphs with macro definitions for the assertions and the generated
-- names, and appends the code defining the module's test suite (and the
-- list of imported test suites, if any) to the cpphs output.
--
-- Arguments: whether the HUnit backwards-compatible assertion macros are
-- used (see 'assertDefines'), whether the module information is printed
-- to stderr, the name of the original file, and the content of the file.
transform :: Bool -> Bool -> FilePath -> String -> IO String
transform hunitBackwardsCompat debug originalFileName input =
    let info = analyze originalFileName fixedInput
    in preprocess info
    where
      preprocess :: ModuleInfo -> IO String
      preprocess info =
          do when debug $
                  hPutStrLn stderr ("Module info:\n" ++ show info)
             preProcessedInput <- runCpphs (cpphsOptions info) originalFileName
                                           fixedInput
             return $ preProcessedInput ++ "\n\n" ++ additionalCode info ++ "\n"
      -- fixedInput serves two purposes:
      -- 1. add a trailing \n
      -- 2. turn lines of the form '# <line> "<file>"' into line directives
      --    '#line <line> "<file>"'
      -- (see http://gcc.gnu.org/onlinedocs/cpp/Preprocessor-Output.html#Preprocessor-Output).
      fixedInput :: String
      fixedInput = (unlines . map fixLine . lines) input
          where
            fixLine s =
                case parseCppLineInfoOut s of
                  Just (lineNo, quotedFile) -> "#line " ++ lineNo ++ " " ++ quotedFile
                  _ -> s
      cpphsOptions :: ModuleInfo -> CpphsOptions
      cpphsOptions info =
          defaultCpphsOptions { defines =
                                    defines defaultCpphsOptions ++
                                    assertDefines hunitBackwardsCompat (mi_htfPrefix info) ++
                                    nameDefines info
                              , boolopts = (boolopts defaultCpphsOptions) { lang = True } -- lex as haskell
                              }
      -- Code defining the test suite of this module, followed by the code
      -- for the list of imported test suites.
      additionalCode :: ModuleInfo -> String
      additionalCode info =
          let prefix = mi_htfPrefix info
              modName = mi_moduleNameWithDefault info
              suiteName = thisModulesTestsFullName modName
          in suiteName ++ " :: " ++ prefix ++ "TestSuite\n" ++
             suiteName ++ " = " ++ prefix ++ "makeTestSuite" ++ " " ++ show modName ++ " [\n" ++
             List.intercalate ",\n" (map (codeForDef prefix) (mi_defs info)) ++
             "\n ]\n" ++ importedTestListCode info
      codeForDef :: String -> Definition -> String
      codeForDef pref (TestDef s loc name) =
          locPragma loc ++ pref ++ "makeUnitTest " ++ (show s) ++ " " ++ codeForLoc pref loc ++
          " " ++ name
      codeForDef pref (PropDef s loc name) =
          locPragma loc ++ pref ++ "makeQuickCheckTest " ++ (show s) ++ " " ++
          codeForLoc pref loc ++ " (" ++ pref ++ "qcAssertion " ++ name ++ ")"
      -- LINE pragma pointing to the given location.
      locPragma :: Location -> String
      locPragma loc =
          "{-# LINE " ++ show (lineNumber loc) ++ " " ++ show (fileName loc) ++ " #-}\n "
      codeForLoc :: String -> Location -> String
      codeForLoc pref loc =
          "(" ++ pref ++ "makeLoc " ++ show (fileName loc) ++ " " ++ show (lineNumber loc) ++ ")"
      -- Code for the list of test suites imported via HTF_TESTS imports;
      -- empty if there is no such import.
      importedTestListCode :: ModuleInfo -> String
      importedTestListCode info =
          case mi_htfImports info of
            [] -> ""
            l -> (importedTestListFullName (mi_moduleNameWithDefault info) ++
                  " :: [" ++ mi_htfPrefix info ++ "TestSuite]\n" ++
                  importedTestListFullName (mi_moduleNameWithDefault info) ++
                  " = [\n " ++
                  List.intercalate ",\n " (map htfTestsInModule l) ++
                  "\n ]\n")
      htfTestsInModule :: ImportDecl -> String
      htfTestsInModule imp = qualify imp (thisModulesTestsFullName (imp_moduleName imp))
      -- Qualifies the name with the alias or the module name of the import
      -- if (and only if) the import is qualified.
      qualify :: ImportDecl -> String -> String
      qualify imp name =
          case (imp_qualified imp, imp_alias imp) of
            (False, _) -> name
            (True, Just alias) -> alias ++ "." ++ name
            (True, _) -> imp_moduleName imp ++ "." ++ name

-- Returns for lines of the form '# <line> "<file>"', optionally followed
-- by further text after the closing double quote
-- (see http://gcc.gnu.org/onlinedocs/cpp/Preprocessor-Output.html#Preprocessor-Output)
-- the value 'Just (<line>, "<file>")'. The file name in the result is
-- enclosed in double quotes; it extends up to the last double quote of
-- the line. Returns Nothing for all other lines.
parseCppLineInfoOut :: String -> Maybe (String, String)
parseCppLineInfoOut inputLine =
    case inputLine of
      '#':' ':c:rest
          | isDigit c ->
              case List.span isDigit rest of
                (restDigits, ' ' : '"' : afterQuote) ->
                    case dropWhile (/= '"') (reverse afterQuote) of
                      '"' : fileNameRev ->
                          let lineNo = (c:restDigits)
                              file = "\"" ++ reverse fileNameRev ++ "\""
                          in Just (lineNo, file)
                      _ -> Nothing
                _ -> Nothing
      _ -> Nothing

-- Self tests of the preprocessor, as pairs of a name and an action that
-- fails if the test fails.
preprocessorTests :: [(String, IO ())]
preprocessorTests =
    [("testAnalyze", testAnalyze)
    ,("fixPositionsTest", fixPositionsTest)
    ,("cleanupInputStringTest", cleanupInputStringTest)]