{-# LANGUAGE ViewPatterns
           , ScopedTypeVariables
           , RankNTypes
           , FlexibleContexts
           , TypeApplications
           , ConstraintKinds
           , TypeFamilies
           #-}
-- | The "extract binding" refactoring: the selected expression is replaced by
-- a call to a new local binding, and that binding is added to the local
-- bindings (@where@ clause) of the enclosing declaration.
module Language.Haskell.Tools.Refactor.ExtractBinding where

import qualified GHC
import qualified Var as GHC
import qualified OccName as GHC hiding (varName)
import SrcLoc
import Unique
import Data.Char
import Data.Maybe
import Data.Generics.Uniplate.Data
import Control.Reference hiding (element)
import Control.Monad.State
import Language.Haskell.Tools.AST
import Language.Haskell.Tools.AnnTrf.SourceTemplate
import Language.Haskell.Tools.AST.Gen
import Language.Haskell.Tools.Refactor.RefactorBase
import Language.Haskell.Tools.AnnTrf.SourceTemplateHelpers
-- NOTE(review): no trace calls are visible in this module; possibly a leftover debugging import.
import Debug.Trace

-- | An AST node at the source-template stage, the stage every function in this
-- module works on.
type Ann' e dom = Ann e dom SrcTemplateStage

-- | An optional AST node at the source-template stage.
type AnnMaybe' e dom = AnnMaybe e dom SrcTemplateStage

-- | Semantic information the refactoring needs from the domain: for names, the
-- resolved GHC name and whether the occurrence is a definition; for
-- expressions, the names that are in scope there.
type ExtractBindingDomain dom = ( Domain dom, HasNameInfo (SemanticInfo' dom SameInfoNameCls), HasDefiningInfo (SemanticInfo' dom SameInfoNameCls)
                                , HasScopeInfo (SemanticInfo' dom SameInfoExprCls) )

-- | Entry point: extracts the expression at the given source range into a new
-- binding called @name@.
--
-- The same range is used to select both the enclosing value bindings and the
-- candidate expressions (via 'nodesContaining' — presumably the nodes whose
-- range contains @sp@). Fails with 'refactError' when 'isValidBindingName'
-- rejects @name@.
extractBinding' :: ExtractBindingDomain dom => RealSrcSpan -> String -> Ann' Module dom -> RefactoredModule dom
extractBinding' sp name mod
  = if isValidBindingName name
      then extractBinding (nodesContaining sp) (nodesContaining sp) name mod
      else refactError "The given name is not a valid for the extracted binding"

-- | Extracts an expression into a new local binding.
--
-- @selectDecl@ picks the value bindings of the module and @selectExpr@ picks the
-- candidate expressions inside them. Once one candidate has been replaced by
-- 'extractThatBind', the rest are left untouched. The generated binding is then
-- inserted by 'addLocalBinding' into the first selected declaration only.
--
-- Fails with 'refactError' if the new name clashes with a name defined in the
-- selected declarations, or if no expression was extracted.
extractBinding :: forall dom . ExtractBindingDomain dom
               => Simple Traversal (Ann' Module dom) (Ann' ValueBind dom)
               -> Simple Traversal (Ann' ValueBind dom) (Ann' Expr dom)
               -> String -> Ann' Module dom -> RefactoredModule dom
extractBinding selectDecl selectExpr name mod
  = let -- the new name must not coincide with a name defined inside the selected declarations
        conflicting = any (isConflicting name) (mod ^? selectDecl & biplateRef :: [Ann' SimpleName dom])
        -- source range of the first selected expression
        exprRange = getRange $ head (mod ^? selectDecl & selectExpr & annotation & sourceInfo)
        -- the last selected declaration; its right-hand side containing the
        -- selection is the context used to decide which names become parameters
        decl = last (mod ^? selectDecl)
        -- start column of this range sets the indentation of a newly created where clause
        declRange = getRange $ last (mod ^? selectDecl & annotation & sourceInfo)
        -- NOTE(review): 'head'/'last' are partial. They appear to be forced only
        -- when a candidate expression exists (otherwise 'extractThatBind' is never
        -- called and the state stays 'Nothing') — confirm before relying on it.
        -- NOTE(review): the binding is inserted into the *first* selected
        -- declaration, but context and indentation come from the *last* one.
        -- These agree only when a single declaration is selected — confirm intent.
     in if conflicting
          then refactError "The given name causes name conflict."
          else do
            -- replace the selected expression; the state collects the generated binding
            (res, st) <- runStateT (selectDecl&selectExpr !~ extractThatBind name (head $ decl ^? actualContainingExpr exprRange) $ mod) Nothing
            case st of
              -- insert the binding; the Bool state marks that insertion already happened
              Just def -> return $ evalState (selectDecl&element !~ addLocalBinding declRange exprRange def $ res) False
              Nothing -> refactError "There is no applicable expression to extract."

-- | True when @used@ is marked as a defining occurrence ('semanticsDefining')
-- and its resolved name's occurrence string equals @name@.
isConflicting :: ExtractBindingDomain dom => String -> Ann' SimpleName dom -> Bool
isConflicting name used
  = semanticsDefining (used ^. semantics)
      && (GHC.occNameString . GHC.getOccName <$> semanticsName (used ^. semantics)) == Just name

-- | Replaces the selected expression with a call and generates the called binding.
--
-- @cont@ is the expression that contains the selection. It is used to decide
-- which names must be passed as parameters (see 'getExternalBinds'). The state
-- holds the generated binding. Once it is 'Just', every further candidate is
-- returned unchanged, so at most one expression is extracted.
extractThatBind :: ExtractBindingDomain dom
                => String -> Ann' Expr dom -> Ann' Expr dom
                -> StateT (Maybe (Ann' ValueBind dom)) (Refactor dom) (Ann' Expr dom)
extractThatBind name cont e = do
    ret <- get
    if (isJust ret)
      then return e
      else case (e ^. element) of
        -- with parameters, keep the parentheses around the generated call;
        -- without, drop them and extract only the inner expression
        Paren {} | hasParameter -> element & exprInner !~ doExtract name cont $ e
                 | otherwise -> doExtract name cont (fromJust $ e ^? element & exprInner)
        Var {} -> lift $ refactError "The selected expression is too simple to be extracted."
        -- the generated call is an application, which binds looser than the
        -- expression it replaces, so it gets its own parentheses
        el | isParenLikeExpr el && hasParameter -> mkParen <$> doExtract name cont e
        el -> doExtract name cont e
  where
    -- whether the extracted binding will take parameters
    hasParameter = not (null (getExternalBinds cont e))

-- | Inserts the generated binding @local@ into @bind@, unless an earlier call in
-- the same stateful traversal already did (tracked by the 'Bool' state).
--
-- For a 'SimpleBind' the binding goes into its local bindings. For a 'FunBind' it
-- goes into the local bindings of the matches selected by @isInside@ on the
-- expression range (presumably the match containing the extracted expression).
-- @declRange@ only affects the indentation of a newly created where clause.
--
-- NOTE(review): 'doAddBinding' has no equation for a 'FunBind' whose
-- @exprRange@ is not a 'RealSrcSpan'; such input fails at runtime.
addLocalBinding :: SrcSpan -> SrcSpan -> Ann' ValueBind dom -> ValueBind dom SrcTemplateStage -> State Bool (ValueBind dom SrcTemplateStage)
addLocalBinding declRange exprRange local bind = do
    done <- get
    if not done
      then do
        put True
        return $ doAddBinding declRange exprRange local bind
      else return bind
  where
    doAddBinding declRng _ local sb@(SimpleBind {}) = valBindLocals .- insertLocalBind declRng local $ sb
    doAddBinding declRng (RealSrcSpan rng) local fb@(FunBind {}) = funBindMatches & annList & filtered (isInside rng) & element & matchBinds .- insertLocalBind declRng local $ fb

-- | Adds @toInsert@ to an optional block of local bindings.
--
-- When there are no local bindings yet and @declRng@ is a real span, a new block
-- is created, indented two columns past the start column of @declRng@.
-- Otherwise the binding is added to the existing block via 'insertWhere'.
--
-- NOTE(review): with no existing block and a non-real @declRng@, the second
-- branch applies 'annJust' to an absent value — presumably a no-op that
-- silently drops the binding; confirm.
insertLocalBind :: SrcSpan -> Ann' ValueBind dom -> AnnMaybe' LocalBinds dom -> AnnMaybe' LocalBinds dom
insertLocalBind declRng toInsert locals
  | isAnnNothing locals
  , RealSrcSpan rng <- declRng
  = -- creates the new where clause indented 2 spaces from the declaration
    mkLocalBinds (srcLocCol (realSrcSpanStart rng) + 2) [mkLocalValBind toInsert]
  | otherwise
  = annJust & element & localBinds .- insertWhere (mkLocalValBind toInsert) (const True) isNothing $ locals

-- | All expressions that are bound stronger than function application.
-- When such an expression is replaced by a call that takes arguments
-- (@name a b@), the call is wrapped in parentheses so that it still parses in
-- the original position, e.g. as a function argument. Tuples (boxed, unboxed
-- and sections) are included: their parentheses belong to the tuple node
-- itself, so no separate 'Paren' node protects an argument position.
isParenLikeExpr :: Expr dom st -> Bool
isParenLikeExpr (If {}) = True
isParenLikeExpr (Paren {}) = True
isParenLikeExpr (Tuple {}) = True
isParenLikeExpr (UnboxedTuple {}) = True
isParenLikeExpr (TupleSection {}) = True
isParenLikeExpr (UnboxedTupSec {}) = True
isParenLikeExpr (List {}) = True
isParenLikeExpr (ParArray {}) = True
isParenLikeExpr (LeftSection {}) = True
isParenLikeExpr (RightSection {}) = True
isParenLikeExpr (RecCon {}) = True
isParenLikeExpr (RecUpdate {}) = True
isParenLikeExpr (Enum {}) = True
isParenLikeExpr (ParArrayEnum {}) = True
isParenLikeExpr (ListComp {}) = True
isParenLikeExpr (ParArrayComp {}) = True
isParenLikeExpr (BracketExpr {}) = True
isParenLikeExpr (Splice {}) = True
isParenLikeExpr (QuasiQuoteExpr {}) = True
isParenLikeExpr _ = False

-- | Records the generated binding in the state and returns the call that
-- replaces @e@.
--
-- A lambda becomes a function binding. Its patterns are the external names
-- (see 'getExternalBinds') followed by the lambda's own patterns, and the
-- lambda's body is the right-hand side. Any other expression becomes the
-- right-hand side as a whole, with the external names as parameters.
doExtract :: ExtractBindingDomain dom
          => String -> Ann' Expr dom -> Ann' Expr dom
          -> StateT (Maybe (Ann' ValueBind dom)) (Refactor dom) (Ann' Expr dom)
doExtract name cont e@((^. element) -> lam@(Lambda {}))
  = do let params = getExternalBinds cont e
       put (Just (generateBind name (map mkVarPat params ++ (lam ^? exprBindings&annList)) (fromJust $ lam ^? exprInner)))
       return (generateCall name params)
doExtract name cont e
  = do let params = getExternalBinds cont e
       put (Just (generateBind name (map mkVarPat params) e))
       return (generateCall name params)

-- | Gets the values that have to be passed to the extracted definition.
--
-- These are the names referenced in @expr@ that are in scope at @expr@ but not
-- at @cont@, i.e. names bound somewhere between the context and the selection.
-- Each resolved name appears once, at its first occurrence. Operators are
-- passed in parentheses.
getExternalBinds :: ExtractBindingDomain dom => Ann' Expr dom -> Ann' Expr dom -> [Ann' Name dom]
getExternalBinds cont expr
  = map (exprToName . fst) $ keepFirsts $ filter (isApplicableName . snd) namedExprs
  where
    -- sub-expressions of @expr@ that refer to a name, paired with the resolved name
    namedExprs = [ (e, nm) | e <- expr ^? uniplateRef, Just nm <- [getExprNameInfo e] ]

    -- the name is visible at the selection but not at the context
    isApplicableName :: GHC.Name -> Bool
    isApplicableName nm = inScopeForOriginal nm && notInScopeForExtracted nm

    getExprNameInfo :: ExtractBindingDomain dom => Ann' Expr dom -> Maybe GHC.Name
    getExprNameInfo ex = semanticsName =<< (listToMaybe $ ex ^? element & (exprName&element&simpleName &+& exprOperator&element&operatorName) & semantics)

    -- Creates the parameter value to pass the name (operators are passed in
    -- parentheses). Only called on expressions for which 'getExprNameInfo'
    -- found a name, which comes from one of the two traversals below.
    exprToName :: Ann' Expr dom -> Ann' Name dom
    exprToName e
      | Just n <- e ^? element & exprName = n
      | Just op <- e ^? element & exprOperator & element & operatorName = mkParenName op
      | otherwise = error "ExtractBinding.getExternalBinds: the expression does not refer to a name"

    notInScopeForExtracted :: GHC.Name -> Bool
    notInScopeForExtracted n = notElem @[] n (semanticsScope (cont ^. semantics) ^? traversal & traversal)

    inScopeForOriginal :: GHC.Name -> Bool
    inScopeForOriginal n = elem @[] n (semanticsScope (expr ^. semantics) ^? traversal & traversal)

    -- Deduplicate on the resolved name, which is what a parameter stands for.
    -- Comparing whole expressions would keep e.g. a variable @f@ and an infix
    -- use of @`f`@ as two parameters for the same binder.
    keepFirsts (entry@(_, nm) : rest) = entry : keepFirsts (filter ((/= nm) . snd) rest)
    keepFirsts [] = []

-- | Selects, inside a value binding, the right-hand-side expression whose range
-- contains the given span. For function bindings this is the matching equation;
-- for guarded right-hand sides it is the guarded alternative.
-- Only defined for 'RealSrcSpan's; any other span fails with a pattern-match error.
actualContainingExpr :: SourceInfo st => SrcSpan -> Simple Traversal (Ann ValueBind dom st) (Ann Expr dom st)
actualContainingExpr (RealSrcSpan rng) = element & accessRhs & element & accessExpr
  where accessRhs :: SourceInfo st => Simple Traversal (ValueBind dom st) (Ann Rhs dom st)
        accessRhs = valBindRhs &+& funBindMatches & annList & filtered (isInside rng) & element & matchRhs
        accessExpr :: SourceInfo st => Simple Traversal (Rhs dom st) (Ann Expr dom st)
        accessExpr = rhsExpr &+& rhsGuards & annList & filtered (isInside rng) & element & guardExpr

-- | Generates the expression that calls the local binding: @name arg1 ... argN@.
generateCall :: String -> [Ann' Name dom] -> Ann' Expr dom
generateCall name args = foldl (\e a -> mkApp e (mkVar a)) (mkVar $ mkNormalName $ mkSimpleName name) args

-- | Generates the local binding for the selected expression: a simple pattern
-- binding when there are no arguments, otherwise a single-equation function
-- binding.
generateBind :: String -> [Ann' Pattern dom] -> Ann' Expr dom -> Ann' ValueBind dom
generateBind name [] e = mkSimpleBind (mkVarPat $ mkNormalName $ mkSimpleName name) (mkUnguardedRhs e) Nothing
generateBind name args e = mkFunctionBind [mkMatch (mkNormalMatchLhs (mkNormalName $ mkSimpleName name) args) (mkUnguardedRhs e) Nothing]

-- | Checks whether the string can be used as the name of the new binding
-- (validated as a 'Variable' by 'nameValid').
isValidBindingName :: String -> Bool
isValidBindingName = nameValid Variable