module Language.Haskell.Refact.Refactoring.GenApplicative
(genApplicative, compGenApplicative) where
import Language.Haskell.Refact.API
import qualified GhcMod as GM
import qualified GhcMod.Types as GM
import qualified GHC as GHC
import qualified RdrName as GHC
import System.Directory
import FastString
import Data.Map as Map (union)
import Data.Generics as SYB
import GHC.SYB.Utils as SYB
import Data.List
import Control.Monad
import Language.Haskell.GHC.ExactPrint
import Language.Haskell.GHC.ExactPrint.Types
import Language.Haskell.GHC.ExactPrint.Print
import Language.Haskell.GHC.ExactPrint.Parsers
-- | Entry point of the refactoring: generalise the do block of the function
-- defined at @pos@ in @fileName@ to applicative style.
--
-- The file path is canonicalised before the refactoring session is run; the
-- result is whatever 'runRefacSession' returns for the session.
genApplicative :: RefactSettings -> GM.Options -> FilePath -> SimpPos -> IO [FilePath]
genApplicative settings opts fileName pos =
  canonicalizePath fileName >>= \absPath ->
    runRefacSession settings opts (compGenApplicative absPath pos)
-- | Run the refactoring on a single file inside an existing 'RefactGhc'
-- session.
--
-- Fails with an error if 'applyRefac' reports that the module was left
-- unmodified; otherwise returns the single refactoring result.
compGenApplicative :: FilePath -> SimpPos -> RefactGhc [ApplyRefacResult]
compGenApplicative fileName pos = do
  (result@((_, modState), _), ()) <-
    applyRefac (doGenApplicative fileName pos) (RSFile fileName)
  case modState of
    RefacUnmodifed -> error "Generalise to Applicative failed"
    RefacModified  -> return [result]
-- | Perform the refactoring on the parsed module held in the session.
--
-- Steps:
--
-- 1. Locate the binding at @pos@.
-- 2. Extract its do statements (without the final one) and the returned
--    expression.
-- 3. Check the preconditions.
-- 4. Build the applicative chain and install it as the function's new
--    right-hand side.
--
-- The file path argument is not used by this step.
doGenApplicative :: FilePath -> SimpPos -> RefactGhc ()
doGenApplicative _fileName pos = do
  parsed <- getRefactParsed
  let funBind = maybe (error "That location is not the start of a function") id
                      (getHsBind pos parsed)
      (retRhs, doStmts) = splitDoBlock funBind
  checkPreconditions retRhs doStmts (findBoundVars doStmts)
  replaceFunRhs pos =<< constructAppChain retRhs doStmts
  where
    -- Pair the returned expression with the statements preceding it; both
    -- must be present.
    splitDoBlock fb = case (getReturnRhs fb, getDoStmts fb) of
      (Just rets, Just stmts) -> (rets, stmts)
      _ -> error "The function needs to consist of set of do statements with a return value."
-- | Check that the do block can be rewritten in applicative style, stopping
-- the refactoring with 'error' otherwise.
--
-- The checks, in order:
--
-- * every statement is a bind (@p <- e@) or body (@e@) statement;
-- * no statement's right-hand side mentions a bound variable, i.e. the
--   effects are independent of each other;
-- * each bound variable appears in the return expression in the order the
--   variables are bound, one bind statement per occurrence.
checkPreconditions :: ParsedExpr -> [GHC.ExprLStmt GHC.RdrName] -> [GHC.RdrName] -> RefactGhc ()
checkPreconditions retRhs doStmts boundVars =
  case mapM stmtRhs doStmts of
    Nothing ->
      error "GenApplicative Precondition: The do block may only contain bind and body statements."
    Just rhss -> do
      when (any (lexprContainsVars boundVars) rhss) $
        error "GenApplicative Precondition: The function given uses a bound variable in a RHS expression."
      unless (checkOrdering (varOrdering boundVars retRhs) doStmts) $
        error "GenApplicative Precondition: Variables are not bound in the order that they appear in the return statement."
  where
    -- The expression on the right of a bind or body statement. Any other kind
    -- of statement (let, ...) gives Nothing.
    stmtRhs :: GHC.ExprLStmt GHC.RdrName -> Maybe ParsedLExpr
    stmtRhs (GHC.L _ (GHC.BodyStmt body _ _ _)) = Just body
#if __GLASGOW_HASKELL__ <= 710
    stmtRhs (GHC.L _ (GHC.BindStmt _ body _ _)) = Just body
#else
    stmtRhs (GHC.L _ (GHC.BindStmt _ body _ _ _)) = Just body
#endif
    stmtRhs _ = Nothing

    -- True when any of the given names occurs anywhere inside the expression.
    lexprContainsVars :: [GHC.RdrName] -> ParsedLExpr -> Bool
    lexprContainsVars vars = SYB.everything (||) (False `SYB.mkQ` (\nm -> elem nm vars))

    -- Every occurrence of one of the given names in the expression, in
    -- traversal order (duplicates kept).
    varOrdering :: [GHC.RdrName] -> ParsedExpr -> [GHC.RdrName]
    varOrdering vars = SYB.everything (++) ([] `SYB.mkQ` (\nm -> if elem nm vars then [nm] else []))

    -- Body statements are skipped. Each bind statement must bind the next
    -- variable of the return-order list.
    checkOrdering :: [GHC.RdrName] -> [GHC.ExprLStmt GHC.RdrName] -> Bool
    checkOrdering [] [] = True
    checkOrdering vars ((GHC.L _ (GHC.BodyStmt {})) : stmts) = checkOrdering vars stmts
#if __GLASGOW_HASKELL__ <= 710
    checkOrdering (var:vars) ((GHC.L _ (GHC.BindStmt pat _ _ _)) : stmts) =
#else
    checkOrdering (var:vars) ((GHC.L _ (GHC.BindStmt pat _ _ _ _)) : stmts) =
#endif
      gContains var pat && checkOrdering vars stmts
    -- The remaining cases previously crashed with a pattern-match failure:
    -- * a bind statement left over after all return variables were matched;
    -- * a return variable with no bind statement left for it;
    -- * an unsupported kind of statement.
    checkOrdering _ _ = False

    -- True when the item occurs anywhere inside the generic structure.
    gContains :: (Data t, Eq a, Data a) => a -> t -> Bool
    gContains item = SYB.everything (||) (False `SYB.mkQ` (\b -> item == b))
-- | Replace the right-hand side of the function whose name is at @pos@ with
-- @newRhs@, then store the new parsed source in the session.
--
-- Every 'GHC.FunBind' with that name is rewritten. Such a function must have a
-- single equation with a single right-hand side (its guards, if any, are
-- kept). Anything else stops the refactoring with an explicit error, instead
-- of the irrefutable-pattern failure raised before.
replaceFunRhs :: SimpPos -> ParsedLExpr -> RefactGhc ()
replaceFunRhs pos newRhs = do
  parsed <- getRefactParsed
  case locToRdrName pos parsed of
    Nothing -> error "replaceFunRhs: Position does not correspond to a binding."
    (Just (GHC.L _ rNm)) -> do
      newParsed <- everywhereMStaged SYB.Parser (SYB.mkM (worker rNm)) parsed
      -- NOTE(review): emptyAnns is passed here; presumably putRefactParsed
      -- keeps the existing annotations — confirm against its definition.
      putRefactParsed newParsed emptyAnns
      logParsedSource "GenApplicative.replaceFunRhs"
  where
    -- Rewrite the match group of a FunBind named rNm; leave other binds alone.
    worker :: GHC.RdrName -> ParsedBind -> RefactGhc (GHC.HsBind GHC.RdrName)
#if __GLASGOW_HASKELL__ <= 710
    worker rNm fBind@(GHC.FunBind (GHC.L _ fNm) _ mg _ _ _)
#else
    worker rNm fBind@(GHC.FunBind (GHC.L _ fNm) mg _ _ _)
#endif
      | fNm == rNm = do
          newMg <- replaceMG mg
          return $ fBind{GHC.fun_matches = newMg}
      | otherwise = return fBind
    worker _ bind = return bind

    -- Install newRhs in the single match of the group.
    replaceMG :: ParsedMatchGroup -> RefactGhc ParsedMatchGroup
    replaceMG mg = do
#if __GLASGOW_HASKELL__ <= 710
      let alts = GHC.mg_alts mg
#else
      let (GHC.L _ alts) = GHC.mg_alts mg
#endif
      newLMatch <- case alts of
        [GHC.L l match] ->
          return (GHC.L l (match{GHC.m_grhss = mkGrhss (GHC.m_grhss match) newRhs}))
        _ -> error "replaceFunRhs: The function must be defined by a single equation."
#if __GLASGOW_HASKELL__ <= 710
      return mg{GHC.mg_alts = [newLMatch]}
#else
      lMatchLst <- locate [newLMatch]
      return mg{GHC.mg_alts = lMatchLst}
#endif

    -- Swap the body of the single GRHS, keeping its location and guards.
    mkGrhss old newExpr = case GHC.grhssGRHSs old of
      [GHC.L l (GHC.GRHS lst _)] -> old{GHC.grhssGRHSs = [GHC.L l (GHC.GRHS lst newExpr)]}
      _ -> error "replaceFunRhs: The function must have exactly one right-hand side."
-- | Derive the pure function to map (with @<$>@) over the chain of effects,
-- from the expression passed to @return@.
--
-- * A bare bound variable gives 'Nothing': the effects already produce the
--   result.
-- * A tuple whose components are all bound variables gives the matching tuple
--   constructor: @(x, y)@ gives @(,)@. If that constructor fails to parse, the
--   failure is logged and 'Nothing' is returned.
-- * Otherwise parentheses around the whole expression are dropped, and the
--   trailing bound-variable arguments are stripped from the application:
--   @f a x y@ gives @f a@. If nothing remains, the result is 'Nothing'.
--
-- Any other use of a bound variable (e.g. @f x a@, @x + y@, @(x, 5)@) cannot be
-- expressed this way. It now stops the refactoring with an error; previously
-- it either crashed with a pattern-match failure or silently produced wrong
-- code.
processReturnStatement :: ParsedExpr -> [GHC.RdrName] -> RefactGhc (Maybe ParsedLExpr)
processReturnStatement retExpr boundVars
  | isJustBoundVar retExpr boundVars = return Nothing
  | otherwise =
      case retExpr of
        (GHC.ExplicitTuple lst _)
          | all isBoundTupArg lst -> tupleCon (length lst)
          | otherwise -> unsupported
        _ -> do
          lRet <- locate retExpr
          stripBoundVars (stripParens lRet)
  where
    unsupported :: RefactGhc (Maybe ParsedLExpr)
    unsupported = error "GenApplicative: The return expression must be a bound variable, a tuple of bound variables, or a function applied to the bound variables as its last arguments."

    -- Is the expression exactly one of the bound variables?
    isBoundVar :: ParsedLExpr -> Bool
    isBoundVar (GHC.L _ e) = isJustBoundVar e boundVars

    -- Does the expression mention any bound variable at all?
    mentionsBound :: ParsedLExpr -> Bool
    mentionsBound = SYB.everything (||) (False `SYB.mkQ` (\nm -> elem nm boundVars))

    -- A tuple component that is present and is a bound variable.
    isBoundTupArg (GHC.L _ (GHC.Present e)) = isBoundVar e
    isBoundTupArg _ = False

    -- Look through parentheses around the whole return expression, as in
    -- @return (f x y)@.
    stripParens :: ParsedLExpr -> ParsedLExpr
    stripParens (GHC.L _ (GHC.HsPar e)) = stripParens e
    stripParens e = e

    -- Peel bound-variable arguments off the end of an application; whatever
    -- is left must not mention any bound variable.
    stripBoundVars :: ParsedLExpr -> RefactGhc (Maybe ParsedLExpr)
    stripBoundVars le
      | isBoundVar le = return Nothing
    stripBoundVars (GHC.L _ (GHC.HsApp fun arg))
      | isBoundVar arg = stripBoundVars fun
    stripBoundVars le
      | mentionsBound le = unsupported
      | otherwise = return (Just le)

    -- Parse the tuple constructor of the given arity. An n-tuple constructor
    -- is written with n-1 commas: "(,)", "(,,)", ...
    tupleCon :: Int -> RefactGhc (Maybe ParsedLExpr)
    tupleCon arity = do
      dFlags <- GHC.getSessionDynFlags
      let constr = "(" ++ replicate (arity - 1) ',' ++ ")"
      case parseExpr dFlags "hare" constr of
        (Left (_, errMsg)) -> do
          logm ("processReturnStatement: error parsing tuple constructor: " ++ errMsg)
          return Nothing
        (Right (anns, expr)) -> do
          mergeRefactAnns anns
          return (Just expr)
-- | True exactly when the expression is a plain variable whose name is one of
-- the given names.
isJustBoundVar :: ParsedExpr -> [GHC.RdrName] -> Bool
isJustBoundVar expr names = case expr of
#if __GLASGOW_HASKELL__ <= 710
  GHC.HsVar nm -> nm `elem` names
#else
  GHC.HsVar (GHC.L _ nm) -> nm `elem` names
#endif
  _ -> False
-- | The statements of the first @do@ expression found in the binding, without
-- its last statement (the one producing the result).
--
-- The search uses 'SYB.something', so the first match in traversal order wins.
-- Gives 'Nothing' when the binding contains no @do@ expression.
getDoStmts :: GHC.HsBind GHC.RdrName -> Maybe [GHC.ExprLStmt GHC.RdrName]
getDoStmts = SYB.something (Nothing `SYB.mkQ` doBody)
  where
    doBody :: GHC.HsExpr GHC.RdrName -> Maybe [GHC.ExprLStmt GHC.RdrName]
    doBody expr = case expr of
#if __GLASGOW_HASKELL__ <= 710
      GHC.HsDo _ stmts _ -> Just (init stmts)
#else
      GHC.HsDo _ (GHC.L _ stmts) _ -> Just (init stmts)
#endif
      _ -> Nothing
-- | Names bound by variable patterns anywhere inside the statements, in
-- traversal order.
findBoundVars :: [GHC.ExprLStmt GHC.RdrName] -> [GHC.RdrName]
findBoundVars stmts = SYB.everything (++) (SYB.mkQ [] varPatName) stmts
  where
    varPatName :: GHC.Pat GHC.RdrName -> [GHC.RdrName]
    varPatName pat = case pat of
#if __GLASGOW_HASKELL__ <= 710
      GHC.VarPat rdr -> [rdr]
#else
      GHC.VarPat (GHC.L _ rdr) -> [rdr]
#endif
      _ -> []
-- | The expression a do block returns. It is the argument @e@ of one of:
--
-- * a body statement @return e@ or @pure e@;
-- * an expression @return $ e@ or @pure $ e@.
--
-- @pure@ is accepted as well as @return@ because the two are interchangeable
-- in the code this refactoring targets.
--
-- The binding is searched with 'SYB.something', so the first match in
-- traversal order is used. NOTE(review): that is not necessarily the final
-- statement of the do block.
getReturnRhs :: UnlocParsedHsBind -> Maybe ParsedExpr
getReturnRhs funBind = SYB.something (Nothing `SYB.mkQ` retStmt `SYB.extQ` dollarRet) funBind
  where
    -- The functions that wrap the final pure value.
    isReturnFn :: ParsedExpr -> Bool
    isReturnFn e = isHsVar "return" e || isHsVar "pure" e

    -- A statement of the form `return e` / `pure e`.
    retStmt :: GHC.ExprLStmt GHC.RdrName -> Maybe ParsedExpr
    retStmt (GHC.L _ (GHC.BodyStmt (GHC.L _ (GHC.HsApp (GHC.L _ fn) (GHC.L _ rhs))) _ _ _))
      | isReturnFn fn = Just rhs
    retStmt _ = Nothing

    -- An expression of the form `return $ e` / `pure $ e`.
    dollarRet :: ParsedExpr -> Maybe ParsedExpr
    dollarRet (GHC.OpApp (GHC.L _ fn) (GHC.L _ op) _ (GHC.L _ rhs))
      | isReturnFn fn && isHsVar "$" op = Just rhs
    dollarRet _ = Nothing
-- | Build the applicative expression that replaces the function's do block.
--
-- Steps:
--
-- 1. The statements are grouped by 'clusterStmts', and each group becomes one
--    expression via @buildSingleExpr@. A lone group then has 'removePars'
--    applied to it.
-- 2. The groups are joined with @<*>@ by @buildChain@.
-- 3. If 'processReturnStatement' yields a function @f@, the result is
--    @f <$> chain@; otherwise it is the chain itself.
constructAppChain :: ParsedExpr -> [GHC.ExprLStmt GHC.RdrName] -> RefactGhc ParsedLExpr
constructAppChain retRhs lst = do
  let clusters = clusterStmts lst
      boundVars = findBoundVars lst
  pars <- mapM buildSingleExpr clusters
  -- With a single group, the parentheses around it are not needed.
  pars2 <- if length pars == 1
    then do newP <- (removePars (head pars))
            return [newP]
    else return pars
  effects <- buildChain pars2
  mPure <- processReturnStatement retRhs boundVars
  case mPure of
    -- No function to map over the effects: the chain is the whole result.
    Nothing -> do
      return effects
    (Just pure) -> do
      -- Put one space before the function, then build `f <$> effects`.
      setDP (DP (0,1)) pure
      lOp <- lInfixFmap
      addAnnVal lOp
      locate (GHC.OpApp pure lOp GHC.PlaceHolder effects)
  where
    -- Join the expressions with `<*>`, building a right-nested OpApp tree
    -- without inserting parentheses.
    -- NOTE(review): there is no equation for [], so the list must be non-empty.
    buildChain :: [ParsedLExpr] -> RefactGhc ParsedLExpr
    buildChain [e] = return e
    buildChain (e:es) = do
      rhs <- buildChain es
      lOp <- lFApp
      addAnnVal lOp
      let opApp = (GHC.OpApp e lOp GHC.PlaceHolder rhs)
      locate opApp
    -- The expression on the right of a body or bind statement (partial: other
    -- kinds of statement are not matched).
    getStmtExpr :: GHC.ExprLStmt GHC.RdrName -> ParsedLExpr
    getStmtExpr (GHC.L _ (GHC.BodyStmt body _ _ _)) = body
#if __GLASGOW_HASKELL__ <= 710
    getStmtExpr (GHC.L _ (GHC.BindStmt _ body _ _)) = body
#else
    getStmtExpr (GHC.L _ (GHC.BindStmt _ body _ _ _)) = body
#endif
    -- Turn one group (one bind statement plus neighbouring body statements)
    -- into a single parenthesised expression:
    -- * statements before the bind are joined with `*>`;
    -- * statements after it are joined with `<*`.
    -- Every statement except the first is given one space of leading DP.
    -- The lazy `break` pattern assumes the group holds a bind statement; each
    -- group produced by clusterStmts contains its bind index.
    buildSingleExpr :: [GHC.ExprLStmt GHC.RdrName] -> RefactGhc ParsedLExpr
    buildSingleExpr [st] = return $ getStmtExpr st
    buildSingleExpr lst@(st:stmts) = do
      let (before,(bindSt:after)) = break isBindStmt lst
      rOp <- rApp
      lOp <- lApp
      mLeftOfBnds <- buildApps rOp (map getStmtExpr before)
      mRightOfBnds <- buildApps lOp (map getStmtExpr after)
      mapM_ (\ex -> (setDP (DP (0,1))) (getStmtExpr ex)) (tail lst)
      lROp <- lRApp
      addAnnVal lROp
      lLOp <- lLApp
      addAnnVal lLOp
      newBndStmt <- mkBind (getStmtExpr bindSt)
      case (mLeftOfBnds,mRightOfBnds) of
        -- Not expected: a group with at least two statements has something on
        -- one side of its bind.
        (Nothing,Nothing) -> error "buildSingleExpr was passed an empty list."
        ((Just leftOfBnds),Nothing) -> do
          app <- locate (GHC.OpApp leftOfBnds lROp GHC.PlaceHolder newBndStmt)
          wrapInPars app
        (Nothing, (Just rightOfBnds)) -> do
          app <- locate (GHC.OpApp newBndStmt lLOp GHC.PlaceHolder rightOfBnds)
          wrapInPars app
        ((Just leftOfBnds),(Just rightOfBnds)) -> do
          setDP (DP (0,1)) newBndStmt
          lOpApp <- locate (GHC.OpApp leftOfBnds lROp GHC.PlaceHolder newBndStmt)
          fullApp <- locate (GHC.OpApp lOpApp lLOp GHC.PlaceHolder rightOfBnds)
          wrapInPars fullApp
    -- A plain variable is used as-is. Any other expression has its DP zeroed
    -- and is wrapped in parentheses with zero DPs.
    mkBind :: ParsedLExpr -> RefactGhc ParsedLExpr
    mkBind e@(GHC.L _ (GHC.HsVar _)) = return e
    mkBind expr = do
      zeroDP expr
      wrapInParsWithDPs (DP (0,0)) (DP (0,0)) expr
    -- Join the expressions with the operator, right-nested, using a freshly
    -- located copy of the operator for each join. Nothing for an empty list.
    buildApps :: ParsedExpr -> [ParsedLExpr] -> RefactGhc (Maybe ParsedLExpr)
    buildApps op [] = return Nothing
    buildApps op [st] = return (Just st)
    buildApps op (st:stmts) = do
      mRhs <- buildApps op stmts
      case mRhs of
        Nothing -> return (Just st)
        (Just rhs) -> do
          lOp <- locate op
          addAnnVal lOp
          lExpr <- locate (GHC.OpApp st lOp GHC.PlaceHolder rhs)
          return (Just lExpr)
-- | Split the statements of a do block into contiguous groups, one per bind
-- statement. Body statements are assigned as follows:
--
-- * those before the first bind join the first group;
-- * those after the last bind join the last group;
-- * those between two binds are split at the midpoint, with the earlier bind
--   taking the extra one when their count is odd.
--
-- Stops with an error if there is no bind statement at all.
clusterStmts :: [GHC.ExprLStmt GHC.RdrName] -> [[GHC.ExprLStmt GHC.RdrName]]
clusterStmts lst =
  case findIndices isBindStmt lst of
    [] -> error "GenApplicative: The do block must contain at least one bind statement."
    bindIxs -> takeRanges (cluster bindIxs (length lst) 0) lst
  where
    -- The ranges from 'cluster' are contiguous and start at index 0, so each
    -- one can be peeled off the front of the remaining statements. This
    -- avoids repeated (!!) lookups.
    takeRanges :: [[Int]] -> [a] -> [[a]]
    takeRanges [] _ = []
    takeRanges (r:rs) xs =
      let (here, rest) = splitAt (length r) xs
      in here : takeRanges rs rest

-- | @cluster bindIxs len start@ splits the indices @[start .. len - 1]@ into
-- contiguous ranges, one per (ascending) bind index in @bindIxs@. Each range
-- ends halfway to the next bind index; the last range runs to @len - 1@.
cluster :: [Int] -> Int -> Int -> [[Int]]
cluster [] _ _ = []
cluster [_] l c = [[c .. l - 1]]
cluster (i1:i2:is) l c =
  let b = i1 + (i2 - i1) `div` 2
  in [c .. b] : cluster (i2:is) l (b + 1)
-- | True when the given name occurs anywhere inside the generic structure.
nameOccurs :: Data a => GHC.RdrName -> a -> Bool
nameOccurs nm = SYB.everything (||) (SYB.mkQ False (nm ==))
-- | True for a bind statement (@pat <- expr@), false for every other statement.
-- The empty record pattern matches the constructor regardless of its arity,
-- so no CPP is needed for the GHC 7.10 / 8.0 difference.
isBindStmt :: GHC.ExprLStmt GHC.RdrName -> Bool
isBindStmt stmt = case GHC.unLoc stmt of
  GHC.BindStmt {} -> True
  _ -> False
-- | A located @<*>@ operator expression.
lFApp :: RefactGhc ParsedLExpr
lFApp = locate =<< fApp

-- | The @<*>@ operator as an expression (see 'hsVar').
fApp :: RefactGhc ParsedExpr
fApp = hsVar "<*>"
-- | True when the expression is the unqualified variable @<*>@.
isFApp :: ParsedLExpr -> Bool
isFApp (GHC.L _ expr) = case expr of
#if __GLASGOW_HASKELL__ <= 710
  GHC.HsVar rdrNm -> GHC.mkVarUnqual (fsLit "<*>") == rdrNm
#else
  GHC.HsVar (GHC.L _ rdrNm) -> GHC.mkVarUnqual (fsLit "<*>") == rdrNm
#endif
  _ -> False
-- | A located @<*@ operator expression.
lLApp :: RefactGhc ParsedLExpr
lLApp = locate =<< lApp

-- | The @<*@ operator as an expression (see 'hsVar').
lApp :: RefactGhc ParsedExpr
lApp = hsVar "<*"
-- | A located @*>@ operator expression.
lRApp :: RefactGhc ParsedLExpr
lRApp = locate =<< rApp

-- | The @*>@ operator as an expression (see 'hsVar').
rApp :: RefactGhc ParsedExpr
rApp = hsVar "*>"
-- | A located @<$>@ operator expression.
lInfixFmap :: RefactGhc ParsedLExpr
lInfixFmap = locate =<< infixFmap

-- | The @<$>@ operator as an expression (see 'hsVar').
infixFmap :: RefactGhc ParsedExpr
infixFmap = hsVar "<$>"
-- | A variable expression for the given name.
--
-- On GHC > 7.10 the name itself is located and given an 'GHC.AnnVal'
-- annotation with zero delta positions.
hsVar :: String -> RefactGhc ParsedExpr
#if __GLASGOW_HASKELL__ <= 710
hsVar n = return (GHC.HsVar (mkRdrName n))
#else
hsVar n = do
  located <- locate (mkRdrName n)
  liftT (addSimpleAnnT located (DP (0,0)) [(G GHC.AnnVal, DP (0,0))])
  return (GHC.HsVar located)
#endif