{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
-- | Refactoring that rewrites a function whose body is a do block ending in
-- a return statement into a chain of applicative operators.
module Language.Haskell.Refact.Refactoring.GenApplicative
  (genApplicative, compGenApplicative) where

import Language.Haskell.Refact.API
import qualified GhcMod as GM
import qualified GhcMod.Types as GM
import qualified GHC as GHC
import qualified RdrName as GHC
import System.Directory
import FastString
import Data.Map as Map (union)
import Data.Generics as SYB
import GHC.SYB.Utils as SYB
import Data.List
import Control.Monad
import Language.Haskell.GHC.ExactPrint
import Language.Haskell.GHC.ExactPrint.Types
import Language.Haskell.GHC.ExactPrint.Print
import Language.Haskell.GHC.ExactPrint.Parsers

-- | Entry point of the refactoring. Canonicalises the given file path and
-- runs 'compGenApplicative' inside a refactoring session.
genApplicative :: RefactSettings -> GM.Options -> FilePath -> SimpPos -> IO [FilePath]
genApplicative settings cradle fileName pos = do
  absFileName <- canonicalizePath fileName
  runRefacSession settings cradle (compGenApplicative absFileName pos)

-- | Applies 'doGenApplicative' to the given file and returns the single
-- refactoring result. Raises an error if the file was left unmodified.
compGenApplicative :: FilePath -> SimpPos -> RefactGhc [ApplyRefacResult]
compGenApplicative fileName pos = do
  (refRes@((_fp,ismod),_), ()) <- applyRefac (doGenApplicative fileName pos) (RSFile fileName)
  case ismod of
    RefacUnmodifed -> error "Generalise to Applicative failed"
    RefacModified -> return ()
  return [refRes]

-- | The refactoring proper.
--
-- The general operation is to transpose the located expressions found in the
-- list of do statements into a single OpApp expression. The start of the
-- applicative chain is derived from the shape of the return statement.
--
-- Raises an error if the position is not a function binding, or if the
-- binding does not contain both a do block and a return expression.
doGenApplicative :: FilePath -> SimpPos -> RefactGhc ()
doGenApplicative fileName pos = do
  parsed <- getRefactParsed
  let funBind = case getHsBind pos parsed of
        (Just fb) -> fb
        Nothing -> error "That location is not the start of a function"
      (retRhs, doStmts) = case (getReturnRhs funBind, getDoStmts funBind) of
        ((Just rets), (Just stmt)) -> (rets, stmt)
        _ -> error "The function needs to consist of set of do statements with a return value."
      boundVars = findBoundVars doStmts
  checkPreconditions retRhs doStmts boundVars
  appChain <- constructAppChain retRhs doStmts
  replaceFunRhs pos appChain

-- | Checks that the do block can be expressed applicatively:
--
--   * no statement expression mentions a variable bound in the block, and
--
--   * the bound variables used by the return expression line up, one per
--     bind statement, in the order in which they are bound.
--
-- Raises an error describing the violated precondition otherwise.
checkPreconditions :: ParsedExpr -> [GHC.ExprLStmt GHC.RdrName] -> [GHC.RdrName] -> RefactGhc ()
checkPreconditions retRhs doStmts boundVars = do
  let boundVarsPrecon = checkBVars doStmts boundVars
      retVarsOrder = varOrdering boundVars retRhs
      orderingPrecon = checkOrdering retVarsOrder doStmts
  unless boundVarsPrecon $
    error "GenApplicative Precondition: The function given uses a bound variable in a RHS expression."
  unless orderingPrecon $
    error "GenApplicative Precondition: Variables are not bound in the order that they appear in the return statement."
  where
    -- True when no statement expression mentions any of the given names.
    checkBVars [] _ = True
    checkBVars (stmt:stmts) vars = case stmt of
      (GHC.L _ (GHC.BodyStmt body _ _ _)) -> (not (lexprContainsVars vars body)) && (checkBVars stmts vars)
#if __GLASGOW_HASKELL__ <= 710
      (GHC.L _ (GHC.BindStmt _ body _ _)) -> (not (lexprContainsVars vars body)) && (checkBVars stmts vars)
#else
      (GHC.L _ (GHC.BindStmt _ body _ _ _)) -> (not (lexprContainsVars vars body)) && (checkBVars stmts vars)
#endif
      -- Any other statement form (for example a let statement) cannot be
      -- handled by the rest of the refactoring, so reject it explicitly
      -- instead of failing with an incomplete pattern match.
      _ -> error "GenApplicative Precondition: The do block may only contain bind and body statements."

    -- True when the expression mentions at least one of the given names.
    lexprContainsVars :: [GHC.RdrName] -> ParsedLExpr -> Bool
    lexprContainsVars vars = SYB.everything (||) (False `SYB.mkQ` (\nm -> elem nm vars))

    -- The bound variables occurring in the expression, in traversal order.
    varOrdering :: [GHC.RdrName] -> ParsedExpr -> [GHC.RdrName]
    varOrdering bound = SYB.everything (++) ([] `SYB.mkQ` (\nm -> if (elem nm bound) then [nm] else []))

    -- Walks the variables and the statements in step: body statements are
    -- skipped and every bind statement must bind the next variable.
    checkOrdering :: [GHC.RdrName] -> [GHC.ExprLStmt GHC.RdrName] -> Bool
    checkOrdering [] [] = True
    checkOrdering vars ((GHC.L _ (GHC.BodyStmt _ _ _ _)):stmts) = checkOrdering vars stmts
#if __GLASGOW_HASKELL__ <= 710
    checkOrdering (var:vars) ((GHC.L _ (GHC.BindStmt pat _ _ _)):stmts) =
#else
    checkOrdering (var:vars) ((GHC.L _ (GHC.BindStmt pat _ _ _ _)):stmts) =
#endif
      checkPat var pat && checkOrdering vars stmts
    -- Variables left over without a bind, binds left over without a
    -- variable, or an unsupported statement: the precondition does not hold.
    checkOrdering _ _ = False

    checkPat var pat = gContains var pat

-- | Generic membership test: True when a value equal to the item occurs
-- anywhere inside the given structure.
gContains :: (Data t, Eq a, Data a) => a -> t -> Bool
gContains item t = SYB.everything (||) (False `SYB.mkQ` (\b -> item == b)) t

-- | Replaces the right hand side of the function bound at the given position
-- with the new expression and stores the updated parsed source.
--
-- Raises an error if the position does not correspond to a binding.
-- NOTE(review): the helpers below match on exactly one equation with exactly
-- one guarded right hand side; other shapes fail the pattern match.
replaceFunRhs :: SimpPos -> ParsedLExpr -> RefactGhc ()
replaceFunRhs pos newRhs = do
  parsed <- getRefactParsed
  let rdrNm = locToRdrName pos parsed
  case rdrNm of
    Nothing -> error "replaceFunRhs: Position does not correspond to a binding."
    (Just (GHC.L _ rNm)) -> do
      newParsed <- everywhereMStaged SYB.Parser (SYB.mkM (worker rNm)) parsed
      putRefactParsed newParsed emptyAnns
      logParsedSource "GenApplicative.replaceFunRhs"
  where
    -- Rewrites the match group of the function binding with the wanted name
    -- and leaves every other binding untouched.
    worker :: GHC.RdrName -> ParsedBind -> RefactGhc (GHC.HsBind GHC.RdrName)
#if __GLASGOW_HASKELL__ <= 710
    worker rNm fBind@(GHC.FunBind (GHC.L _ fNm) _ mg _ _ _)
#else
    worker rNm fBind@(GHC.FunBind (GHC.L _ fNm) mg _ _ _)
#endif
      | fNm == rNm = do
          newMg <- replaceMG mg
          return $ fBind{GHC.fun_matches = newMg}
      | otherwise = return fBind
    worker _ bind = return bind

    -- Swaps the right hand side of the single match in the group.
    replaceMG :: ParsedMatchGroup -> RefactGhc ParsedMatchGroup
    replaceMG mg = do
#if __GLASGOW_HASKELL__ <= 710
      let [(GHC.L l match)] = GHC.mg_alts mg
#else
      let (GHC.L _ [(GHC.L l match)]) = GHC.mg_alts mg
#endif
          oldGrhss = GHC.m_grhss match
          newGrhss = mkGrhss oldGrhss newRhs
          newLMatch = (GHC.L l (match{GHC.m_grhss = newGrhss}))
#if __GLASGOW_HASKELL__ <= 710
      return mg{GHC.mg_alts = [newLMatch]}
#else
      lMatchLst <- locate [newLMatch]
      return mg{GHC.mg_alts = lMatchLst}
#endif

    -- Keeps the location and guards of the single GRHS, replacing its body.
    mkGrhss old newExpr =
      let [(GHC.L l (GHC.GRHS lst _))] = GHC.grhssGRHSs old
      in old{GHC.grhssGRHSs = [(GHC.L l (GHC.GRHS lst newExpr))]}

-- | Computes the pure function that heads the applicative chain, if any.
--
--   * A return expression that is just a bound variable yields Nothing.
--
--   * An explicit tuple yields the tuple constructor of matching arity,
--     obtained by parsing its source text; a parse failure is logged and
--     yields Nothing.
--
--   * Anything else is the return expression with its bound variable
--     arguments stripped.
processReturnStatement :: ParsedExpr -> [GHC.RdrName] -> RefactGhc (Maybe ParsedLExpr)
processReturnStatement retExpr boundVars
  | isJustBoundVar retExpr boundVars = return Nothing
  | otherwise =
      case retExpr of
        (GHC.ExplicitTuple lst _) -> do
          dFlags <- GHC.getSessionDynFlags
          let constr = "(" ++ replicate (length lst - 1) ',' ++ ")"
              parseRes = parseExpr dFlags "hare" constr
          case parseRes of
            (Left _) -> do
              logm "processReturnStatement: error parsing tuple constructor"
              return Nothing
            (Right (anns, expr)) -> do
              mergeRefactAnns anns
              return (Just expr)
        _ -> do
          lRet <- locate retExpr
          stripBoundVars lRet boundVars
  where
    -- Removes applications to bound variables from an application spine.
    -- NOTE(review): only variables and applications are handled; any other
    -- expression form fails the pattern match.
    stripBoundVars :: ParsedLExpr -> [GHC.RdrName] -> RefactGhc (Maybe ParsedLExpr)
#if __GLASGOW_HASKELL__ <= 710
    stripBoundVars le@(GHC.L _ (GHC.HsVar nm)) names
#else
    stripBoundVars le@(GHC.L _ (GHC.HsVar (GHC.L _ nm))) names
#endif
      | nm `elem` names = return Nothing
      | otherwise = return (Just le)
    stripBoundVars (GHC.L l (GHC.HsApp expr1 expr2)) names = do
      ne1 <- stripBoundVars expr1 names
      ne2 <- stripBoundVars expr2 names
      return $ case ne2 of
        Nothing -> ne1
        Just e2 -> fmap (\e1 -> GHC.L l (GHC.HsApp e1 e2)) ne1

-- | True when the expression is a single variable that is one of the names.
isJustBoundVar :: ParsedExpr -> [GHC.RdrName] -> Bool
#if __GLASGOW_HASKELL__ <= 710
isJustBoundVar (GHC.HsVar nm) names = elem nm names
#else
isJustBoundVar (GHC.HsVar (GHC.L _ nm)) names = elem nm names
#endif
isJustBoundVar _ _ = False

-- | Finds the first do expression in the binding and returns all of its
-- statements except the last one.
getDoStmts :: GHC.HsBind GHC.RdrName -> Maybe [GHC.ExprLStmt GHC.RdrName]
getDoStmts funBind = SYB.something (Nothing `SYB.mkQ` stmtLst) funBind
  where
    stmtLst :: GHC.HsExpr GHC.RdrName -> Maybe [GHC.ExprLStmt GHC.RdrName]
#if __GLASGOW_HASKELL__ <= 710
    stmtLst (GHC.HsDo _ stmts _) = Just (init stmts)
#else
    stmtLst (GHC.HsDo _ (GHC.L _ stmts) _) = Just (init stmts)
#endif
    stmtLst _ = Nothing

-- | Collects every name bound by a variable pattern in the statements.
findBoundVars :: [GHC.ExprLStmt GHC.RdrName] -> [GHC.RdrName]
findBoundVars = SYB.everything (++) ([] `SYB.mkQ` findVarPats)
  where
    findVarPats :: GHC.Pat GHC.RdrName -> [GHC.RdrName]
#if __GLASGOW_HASKELL__ <= 710
    findVarPats (GHC.VarPat rdr) = [rdr]
#else
    findVarPats (GHC.VarPat (GHC.L _ rdr)) = [rdr]
#endif
    findVarPats _ = []

-- | Finds the argument of the first return in the binding, written either as
-- a body statement applying return to an argument or as return applied with
-- the dollar operator.
getReturnRhs :: UnlocParsedHsBind -> Maybe ParsedExpr
getReturnRhs funBind = SYB.something (Nothing `SYB.mkQ` retStmt `SYB.extQ` dollarRet) funBind
  where
    -- A body statement that applies return to an argument.
    retStmt :: GHC.ExprLStmt GHC.RdrName -> Maybe ParsedExpr
    retStmt (GHC.L _ (GHC.BodyStmt (GHC.L _ (GHC.HsApp (GHC.L _ fun) (GHC.L _ rhs))) _ _ _))
      | isHsVar "return" fun = Just rhs
    retStmt _ = Nothing

    -- An operator application whose left operand is return and whose
    -- operator is the dollar operator.
    dollarRet :: ParsedExpr -> Maybe ParsedExpr
    dollarRet (GHC.OpApp ret dollar _ expr)
      | isHsVar "return" (GHC.unLoc ret) && isHsVar "$" (GHC.unLoc dollar) = Just (GHC.unLoc expr)
    dollarRet _ = Nothing

-- | Builds the complete applicative expression. The statements are grouped
-- into clusters, each cluster becomes one expression, the expressions are
-- chained with the applicative apply operator and, when the return statement
-- yields a pure function, that function is prefixed with the infix fmap
-- operator.
constructAppChain :: ParsedExpr -> [GHC.ExprLStmt GHC.RdrName] -> RefactGhc ParsedLExpr
constructAppChain retRhs lst = do
  let clusters = clusterStmts lst
      boundVars = findBoundVars lst
  pars <- mapM buildSingleExpr clusters
  -- A lone cluster expression has removePars applied to it.
  pars2 <- case pars of
    [single] -> do
      stripped <- removePars single
      return [stripped]
    _ -> return pars
  effects <- buildChain pars2
  mPure <- processReturnStatement retRhs boundVars
  case mPure of
    Nothing -> return effects
    (Just pureFn) -> do
      setDP (DP (0,1)) pureFn
      lOp <- lInfixFmap
      addAnnVal lOp
      locate (GHC.OpApp pureFn lOp GHC.PlaceHolder effects)
  where
    -- Joins the expressions with the applicative apply operator, nesting to
    -- the right.
    buildChain :: [ParsedLExpr] -> RefactGhc ParsedLExpr
    buildChain [e] = return e
    buildChain (e:es) = do
      rhs <- buildChain es
      lOp <- lFApp
      addAnnVal lOp
      let opApp = (GHC.OpApp e lOp GHC.PlaceHolder rhs)
      locate opApp

-- | The expression carried by a body statement or a bind statement.
getStmtExpr :: GHC.ExprLStmt GHC.RdrName -> ParsedLExpr
getStmtExpr (GHC.L _ (GHC.BodyStmt body _ _ _)) = body
#if __GLASGOW_HASKELL__ <= 710
getStmtExpr (GHC.L _ (GHC.BindStmt _ body _ _)) = body
#else
getStmtExpr (GHC.L _ (GHC.BindStmt _ body _ _ _)) = body
#endif

-- | Turns one cluster of statements into a single expression. A one
-- statement cluster is just its expression. Otherwise the expressions before
-- the first bind statement are joined to it with the right sequencing
-- operator, the ones after it with the left sequencing operator, and the
-- result is wrapped in parentheses.
buildSingleExpr :: [GHC.ExprLStmt GHC.RdrName] -> RefactGhc ParsedLExpr
buildSingleExpr [st] = return $ getStmtExpr st
buildSingleExpr lst@(_:stmts) = do
  let (before,(bindSt:after)) = break isBindStmt lst
  rOp <- rApp
  lOp <- lApp
  mLeftOfBnds <- buildApps rOp (map getStmtExpr before)
  mRightOfBnds <- buildApps lOp (map getStmtExpr after)
  -- Every expression but the first gets the delta position (0,1).
  mapM_ (\ex -> (setDP (DP (0,1))) (getStmtExpr ex)) stmts
  lROp <- lRApp
  addAnnVal lROp
  lLOp <- lLApp
  addAnnVal lLOp
  newBndStmt <- mkBind (getStmtExpr bindSt)
  case (mLeftOfBnds,mRightOfBnds) of
    (Nothing,Nothing) -> error "buildSingleExpr was passed an empty list."
    ((Just leftOfBnds),Nothing) -> do
      app <- locate (GHC.OpApp leftOfBnds lROp GHC.PlaceHolder newBndStmt)
      wrapInPars app
    (Nothing, (Just rightOfBnds)) -> do
      app <- locate (GHC.OpApp newBndStmt lLOp GHC.PlaceHolder rightOfBnds)
      wrapInPars app
    ((Just leftOfBnds),(Just rightOfBnds)) -> do
      setDP (DP (0,1)) newBndStmt
      lOpApp <- locate (GHC.OpApp leftOfBnds lROp GHC.PlaceHolder newBndStmt)
      fullApp <- locate (GHC.OpApp lOpApp lLOp GHC.PlaceHolder rightOfBnds)
      wrapInPars fullApp

-- | Prepares the expression of a bind statement: a plain variable is
-- returned as is, anything else is wrapped in parentheses.
mkBind :: ParsedLExpr -> RefactGhc ParsedLExpr
mkBind e@(GHC.L _ (GHC.HsVar _)) = return e
mkBind expr = do
  zeroDP expr
  wrapInParsWithDPs (DP (0,0)) (DP (0,0)) expr

-- | Joins the expressions with the given operator, nesting to the right.
-- Returns Nothing for an empty list.
buildApps :: ParsedExpr -> [ParsedLExpr] -> RefactGhc (Maybe ParsedLExpr)
buildApps _ [] = return Nothing
buildApps _ [st] = return (Just st)
buildApps op (st:stmts) = do
  mRhs <- buildApps op stmts
  case mRhs of
    Nothing -> return (Just st)
    (Just rhs) -> do
      lOp <- locate op
      addAnnVal lOp
      lExpr <- locate (GHC.OpApp st lOp GHC.PlaceHolder rhs)
      return (Just lExpr)

-- | Splits the statements into consecutive groups that each contain exactly
-- one bind statement.
clusterStmts :: [GHC.ExprLStmt GHC.RdrName] -> [[GHC.ExprLStmt GHC.RdrName]]
clusterStmts lst = map (map (lst !!)) clusters
  where
    indices = findIndices isBindStmt lst
    clusters = cluster indices (length lst) 0

-- | Partitions the statement indices into one range per bind statement.
-- The arguments are the indices of the bind statements, the total number of
-- statements and the first index not yet assigned to a range. The boundary
-- between two neighbouring binds is placed halfway between them.
--
-- Raises an error when there are no bind statements at all, because no
-- applicative chain can be built from such a block.
cluster :: [Int] -> Int -> Int -> [[Int]]
cluster [] _ _ =
  error "GenApplicative Precondition: The do block must contain at least one bind statement."
cluster [_] l c = [[c..(l-1)]]
cluster (i1:i2:is) l c =
  let b = i1 + ((i2-i1) `div` 2)
  in [c .. b]:(cluster (i2:is) l (b+1))

-- | Checks if a name occurs in the given AST chunk.
nameOccurs :: Data a => GHC.RdrName -> a -> Bool
nameOccurs nm = SYB.everything (||) (False `SYB.mkQ` isName)
  where
    isName :: GHC.RdrName -> Bool
    isName mName = nm == mName

-- | True for bind statements, False for every other statement.
isBindStmt :: GHC.ExprLStmt GHC.RdrName -> Bool
#if __GLASGOW_HASKELL__ <= 710
isBindStmt (GHC.L _ (GHC.BindStmt _ _ _ _)) = True
#else
isBindStmt (GHC.L _ (GHC.BindStmt _ _ _ _ _)) = True
#endif
isBindStmt _ = False

-- | Located version of 'fApp'.
lFApp :: RefactGhc ParsedLExpr
lFApp = fApp >>= locate

-- | The applicative apply operator as a variable expression.
fApp :: RefactGhc ParsedExpr
fApp = hsVar "<*>"

-- | True when the expression is the unqualified applicative apply operator.
isFApp :: ParsedLExpr -> Bool
#if __GLASGOW_HASKELL__ <= 710
isFApp (GHC.L _ (GHC.HsVar rdrNm)) = (GHC.mkVarUnqual (fsLit "<*>")) == rdrNm
#else
isFApp (GHC.L _ (GHC.HsVar (GHC.L _ rdrNm))) = (GHC.mkVarUnqual (fsLit "<*>")) == rdrNm
#endif
isFApp _ = False

-- | Located version of 'lApp'.
lLApp :: RefactGhc ParsedLExpr
lLApp = lApp >>= locate

-- | The left sequencing operator as a variable expression.
lApp :: RefactGhc ParsedExpr
lApp = hsVar "<*"

-- | Located version of 'rApp'.
lRApp :: RefactGhc ParsedLExpr
lRApp = rApp >>= locate

-- | The right sequencing operator as a variable expression.
rApp :: RefactGhc ParsedExpr
rApp = hsVar "*>"

-- | Located version of 'infixFmap'.
lInfixFmap :: RefactGhc ParsedLExpr
lInfixFmap = infixFmap >>= locate

-- | The infix fmap operator as a variable expression.
infixFmap :: RefactGhc ParsedExpr
infixFmap = hsVar "<$>"

-- TODO: Move this to Utils/Variables.hs, but make it lhsVar
-- | Builds a variable expression for the given name. On compilers newer than
-- GHC 7.10 the name is located and given a simple annotation first.
hsVar :: String -> RefactGhc ParsedExpr
hsVar n = do
#if __GLASGOW_HASKELL__ <= 710
  return (GHC.HsVar (mkRdrName n))
#else
  lNm <- locate $ mkRdrName n
  liftT $ addSimpleAnnT lNm (DP (0,0)) [(G GHC.AnnVal,DP (0,0))]
  return (GHC.HsVar lNm)
#endif