-- | Cross-reference queries over the global definition context.
--
-- * 'whoCalls' lists the definitions that reference a given name.
-- * 'callsWho' lists the names that a given definition references.
module Idris.WhoCalls (whoCalls, callsWho) where

import Idris.AbsSyntax
import Idris.Core.CaseTree
import Idris.Core.Evaluate
import Idris.Core.TT

import Data.List (nub)

-- | Global names referenced by a term, in left-to-right traversal order,
-- possibly with duplicates.
--
-- 'Bound' references are skipped. Only the name of a 'P' node is collected;
-- the term stored inside the 'P' node is not traversed.
names :: Term -> [Name]
names (P Bound _ _) = []
names (P _ n _)     = [n]
names (Bind _ b sc) = namesBinder b ++ names sc
names (App t1 t2)   = names t1 ++ names t2
names (Proj t _)    = names t
names _             = []

-- | Names referenced by a binder.
--
-- For 'Let' and 'NLet' both the type and the bound value are inspected.
-- For every other binder only its 'binderTy' is inspected.
namesBinder :: Binder Term -> [Name]
namesBinder (Let ty val)  = names ty ++ names val
namesBinder (NLet ty val) = names ty ++ names val
namesBinder b             = names (binderTy b)

-- | Names referenced by a case tree: projected scrutinee terms, every
-- alternative, and leaf terms. Any other form of tree contributes nothing.
namesSC :: SC -> [Name]
namesSC (Case _ alts)     = concatMap namesCaseAlt alts
namesSC (ProjCase t alts) = names t ++ concatMap namesCaseAlt alts
namesSC (STerm t)         = names t
namesSC _                 = []

-- | Names referenced by one case alternative. For constructor and function
-- patterns, the matched name itself counts as a reference.
namesCaseAlt :: CaseAlt -> [Name]
namesCaseAlt (ConCase c _ _ sc) = c : namesSC sc
namesCaseAlt (FnCase f _ sc)    = f : namesSC sc
namesCaseAlt (ConstCase _ sc)   = namesSC sc
namesCaseAlt (SucCase _ sc)     = namesSC sc
namesCaseAlt (DefaultCase sc)   = namesSC sc

-- | Names referenced by a definition's type and, where it has one, its body.
-- For a 'CaseOp' the body is the compile-time case tree.
namesDef :: Def -> [Name]
namesDef (Function ty tm)         = names ty ++ names tm
namesDef (TyDecl _ ty)            = names ty
namesDef (Operator ty _ _)        = names ty
namesDef (CaseOp _ ty _ _ _ defs) =
  names ty ++ namesSC (snd (cases_compiletime defs))

-- | Does the definition reference the given name?
--
-- This is derived from 'namesDef' instead of being a second hand-written
-- traversal, so the two can never disagree. 'namesDef' builds its result
-- lazily, so 'elem' still stops at the first match.
occursDef :: Name -> Def -> Bool
occursDef n = elem n . namesDef

-- | Names of every definition in the context, other than @n@ itself, whose
-- type or body references @n@. Results are in 'ctxtAlist' order.
findOccurs :: Name -> Idris [Name]
findOccurs n = do
  ctxt <- getContext
  -- A definition calls a function if the function appears in the type or
  -- the right-hand side of the definition. Self-references are excluded.
  -- TODO(review): no datatype-specific relation (e.g. a type and its
  -- constructors) is modelled here; only types and bodies are inspected.
  let callers = [ n' | (n', def) <- ctxtAlist ctxt, n /= n', occursDef n def ]
  return callers

-- | For every name that 'lookupNames' finds for @n@, pair it with the
-- distinct definitions that reference it (see 'findOccurs').
whoCalls :: Name -> Idris [(Name, [Name])]
whoCalls n = do
  ctxt <- getContext
  mapM callersOf (lookupNames n ctxt)
  where
    -- Deduplicate the callers of a single resolved name.
    callersOf target = do
      found <- findOccurs target
      return (target, nub found)

-- | For every definition that 'lookupNameDef' finds for @n@, pair its name
-- with the distinct names referenced by its type and body.
--
-- Unlike 'whoCalls', this does not drop self-references: a definition that
-- mentions its own name lists itself.
callsWho :: Name -> Idris [(Name, [Name])]
callsWho n = do
  ctxt <- getContext
  let defs = lookupNameDef n ctxt
  return [ (defName, nub (namesDef def)) | (defName, def) <- defs ]