-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Annotated (annotatedTest) where

import Control.Monad.State.Lazy
import Data.Data
import Data.Generics
import qualified Data.Map as M
import qualified Data.Set as S
import Data.Maybe
import Test.HUnit

import Retrie.ExactPrint
import Retrie.GHC

-- Suite grouping every test defined in this module under the "Annotated" label.
annotatedTest :: Test
annotatedTest = TestLabel "Annotated" $ TestList
  [ increasingSeedTest
  , elemsPostGraftTest
  , inverseTest
  , uniqueSrcSpanTest
  , trimTest
  ]

-- Expression fixtures; each string is handed to parseExpr by forAst.
exprs :: [String]
exprs =
  [ "3 + x"
  , "let _ = \\_ -> 4 in 7"
  , "case x of { y -> y }"
  , "let x = y in z"
  , unlines
    [ "case (3, 4) of"
    , " (x, 5) -> (5, x)"
    , " (_, y) -> (y, y)"
    ]
  , unlines
    -- The continuation lines must start in the same column as the first
    -- binding ('f', column 5). With less indentation the layout rule closes
    -- the let block before 'in' and the fixture does not parse.
    [ "let f :: Int -> Int"
    , "    f x = x + 3"
    , "    y = 4"
    , "in f y + f y"
    ]
  ]

-- Type fixtures; each string is handed to parseType by forAst.
types :: [String]
types =
  [ "Int -> Int"
  , "Data a => a -> Int"
  , "(Eq a, Eq b) => (a -> b, b -> a)"
  ]

-- Run test on all ASTs parsed from the above lists
forAst :: (forall a. Data a => Annotated a -> IO ()) -> IO ()
forAst f = do
  mapM_ (parseExpr >=> f) exprs
  mapM_ (parseType >=> f) types

-- Repeat a single transformation multiple times on an ast. The ast returned
-- from the previous transformation is passed to the next transformation.
-- The transformation is applied four times per fixture; the resulting
-- annotated AST is discarded, so only the assertions run inside 'f' matter.
testChainedTransforms :: (forall a. Data a => a -> TransformT IO a) -> IO ()
testChainedTransforms f = forAst $ \at ->
  () <$ transformA at (f >=> f >=> f >=> f)

-- Each prune-then-graft round trip must strictly increase the seed held in
-- the transform state.
increasingSeedTest :: Test
increasingSeedTest = TestLabel "graft increases seed" $ TestCase $
  testChainedTransforms transform
  where
    transform :: Data a => a -> TransformT IO a
    transform = transformWithSeedIncreaseCheck . (pruneA >=> graftA)

-- Following a graft, the annotation map in the state has the expected elements
elemsPostGraftTest :: Test
elemsPostGraftTest = TestLabel "Expected elems in map" $ TestCase $
  testChainedTransforms transform
  where
    transform :: Data a => a -> TransformT IO a
    transform t = do
      annsPreGraft <- gets fst
      at <- pruneA t
      t' <- graftA at
      annsPostGraft <- gets fst
      lift $ liftIO $ do
        assertCountMaintained annsPreGraft t annsPostGraft
        assertNoOverwrite annsPreGraft annsPostGraft
        assertExactPrintAnns annsPreGraft annsPostGraft
      return t'

-- Pruning an AST and grafting the result back must give an AST that
-- assertAstsEqual considers equal to the one we started from.
inverseTest :: Test
inverseTest = TestLabel "graftA and pruneA are inverse" $ TestCase $
  testChainedTransforms transform
  where
    transform :: Data a => a -> TransformT IO a
    transform t = do
      anns <- gets fst
      at <- pruneA t
      t' <- graftA at
      anns' <- gets fst
      lift $ liftIO $ assertAstsEqual
        "ast pre-graft is same as ast post-graft"
        (anns, t)
        (anns', t')
      return t'

-- Runs uniqueSrcSpanT once (under the seed-increase check) inside a transform
-- over an empty annotated unit, and asserts the span it returns passes
-- assertGoodSrcSpan.
uniqueSrcSpanTest :: Test
uniqueSrcSpanTest = TestLabel "unique src span" $ TestCase $
  fmap astA $ transformA (mempty :: Annotated ()) $ \() -> do
    ss <- transformWithSeedIncreaseCheck uniqueSrcSpanT
    lift $ liftIO $ assertGoodSrcSpan ss

-- After trimA, every SrcSpan-located node of the AST must pass
-- assertGoodSrcSpan.
trimTest :: Test
trimTest = TestLabel "trimA" $ TestCase $
  forAst $ \at -> assertLocsReplaced (astA (trimA at))

-- Run the given transform and assert that the seed (second component of the
-- transform state) is strictly larger afterwards. The transform's result is
-- returned unchanged.
transformWithSeedIncreaseCheck :: TransformT IO a -> TransformT IO a
transformWithSeedIncreaseCheck m = do
  seed <- gets snd
  x <- m
  seed' <- gets snd
  lift $ liftIO $ assertBool "transform increases seed" (seed' > seed)
  return x

-- Creates a query that returns the result of applying q to the argument (if the
-- argument is a GenLocated node containing a SrcSpan), otherwise returning
-- the provided default value.
locatedQ :: forall b. b -> (forall a. Data a => Located a -> b) -> GenericQ b
locatedQ defaultVal q = const defaultVal `ext2Q` query
  where
    -- Matches any GenLocated node; the location is only accepted when it can
    -- be cast to a SrcSpan.
    query :: (Data loc, Data a) => GenLocated loc a -> b
    query (L l t) =
      case cast l :: Maybe SrcSpan of
        Nothing -> defaultVal
        Just ss -> q (L ss t)

-- Structure of HsExpr AST, including constructor names and annotations
-- associated with SrcSpan locations.
data ConTree = ConNode AnnConName (Maybe Annotation) [ConTree]
  deriving (Eq, Show)

-- Assert ast equality (up to src span location labeling)
assertAstsEqual :: Data a => String -> (Anns, a) -> (Anns, a) -> IO ()
assertAstsEqual msg (anns1, t1) (anns2, t2) =
  assertEqual msg (conTree anns1 t1) (conTree anns2 t2)
  where
    -- Build the ConTree for a term, looking annotations up in the given map.
    conTree :: Data a => Anns -> a -> ConTree
    conTree anns = loop
      where
        loop :: Data a => a -> ConTree
        loop t = ConNode (annGetConstr t) (annQ t) (gmapQ loop t)

        -- Annotation stored under the node's key, if the node is located by
        -- a SrcSpan and the key is present.
        annQ :: GenericQ (Maybe Annotation)
        annQ = locatedQ Nothing $ \loc -> M.lookup (mkAnnKey loc) anns

-- Assert that all locations in the updated ast are generated by uniqueSrcSpanT
assertLocsReplaced :: Data a => a -> IO ()
assertLocsReplaced = everything (>>) assertReplaced
  where
    assertReplaced :: GenericQ (IO ())
    assertReplaced = locatedQ (return ()) $ \(L ss _) -> assertGoodSrcSpan ss

-- Assert that every location in the ast has been added to the pre-graft
-- annotations to form the post-graft annotations.
assertCountMaintained :: Data a => Anns -> a -> Anns -> IO ()
assertCountMaintained annsPreGraft t annsPostGraft =
  let numAnnsAdded = everything (+) countIfInAnns t
  in assertEqual
      "sum of pre-graft size and # of SrcSpan sites in AST equals post-graft size"
      (M.size annsPreGraft + numAnnsAdded)
      (M.size annsPostGraft)
  where
    -- 1 for a located node whose key is in the pre-graft map, 0 otherwise.
    countIfInAnns :: GenericQ Int
    countIfInAnns = locatedQ 0 $ \loc ->
      if M.member (mkAnnKey loc) annsPreGraft then 1 else 0

-- Check that no data in pre-graft map was overwritten.
-- Every pre-graft key must still be present post-graft with the same value.
assertNoOverwrite :: Anns -> Anns -> IO ()
assertNoOverwrite annsPreGraft annsPostGraft =
  assertEqual
    "pre-graft keys correspond to same data as post-graft"
    (M.toList annsPreGraft)
    survivors
  where
    -- Post-graft entries for the pre-graft keys, in pre-graft key order. A key
    -- that is missing post-graft is dropped, so the lists differ and the
    -- assertion fails.
    survivors = mapMaybe postEntry (M.keys annsPreGraft)
    postEntry k = (,) k <$> M.lookup k annsPostGraft

-- Assert that the annotation keys corresponding to newly-added data are of the
-- expected form.
assertExactPrintAnns :: Anns -> Anns -> IO ()
assertExactPrintAnns annsPreGraft annsPostGraft = mapM_ checkKey addedKeys
  where
    -- Keys present post-graft that were not present pre-graft.
    addedKeys :: S.Set AnnKey
    addedKeys = M.keysSet (M.difference annsPostGraft annsPreGraft)

    checkKey :: AnnKey -> IO ()
    checkKey (AnnKey ss _) = assertGoodSrcSpan ss

-- A good span is a RealSrcSpan whose start and end both pass assertGoodSrcLoc;
-- an UnhelpfulSpan is always a failure.
assertGoodSrcSpan :: SrcSpan -> IO ()
assertGoodSrcSpan (RealSrcSpan rss) =
  mapM_ assertGoodSrcLoc [realSrcSpanStart rss, realSrcSpanEnd rss]
assertGoodSrcSpan (UnhelpfulSpan _) =
  assertFailure "only real src spans should be generated"

-- A good location has file name "ghc-exactprint" and line number -1.
assertGoodSrcLoc :: RealSrcLoc -> IO ()
assertGoodSrcLoc srcLoc = do
  assertEqual "srcLoc file is correctly named" "ghc-exactprint" fileName
  assertEqual "srcLoc line is correctly placed" (-1) (srcLocLine srcLoc)
  where
    fileName = unpackFS (srcLocFile srcLoc)