{-|
Module      : Idris.Core.CaseTree
Description : Module to define and interact with case trees.
License     : BSD3
Maintainer  : The Idris Community.

Note: The case-tree elaborator only produces (Case n alts)-cases;
in other words, it never inspects anything other than variables.

ProjCase is a special, more powerful case construct that allows
inspection of compound terms. Occurrences of ProjCase arise no earlier
than in the function `prune`, as a means of optimising already built
case trees.

While the intermediate representation (which follows in the pipeline,
named LExp) allows casing on arbitrary terms, here we choose to maintain
the distinction in order to allow for better optimisation opportunities.
-}
{-# LANGUAGE DeriveFunctor, DeriveGeneric, FlexibleContexts,
             FlexibleInstances, PatternGuards, TypeSynonymInstances #-}

module Idris.Core.CaseTree (
    CaseDef(..), SC, SC'(..), CaseAlt, CaseAlt'(..), ErasureInfo
  , Phase(..), CaseTree, CaseType(..)
  , simpleCase, small, namesUsed, findCalls, findCalls', findUsedArgs
  , substSC, substAlt, mkForce
  ) where

import Idris.Core.TT

import Control.Monad.Reader
import Control.Monad.State
import Data.List hiding (partition)
import qualified Data.List (partition)
import qualified Data.Set as S
import GHC.Generics (Generic)

-- | A compiled definition: the argument names, the case tree over them,
-- and the list of terms accumulated in the builder state while the tree
-- was built ('simpleCase' fills this from the first component of 'CS').
data CaseDef = CaseDef [Name] !SC [Term]
    deriving Show

-- | A case tree, parameterised over the type of terms at its leaves.
data SC' t = Case CaseType Name [CaseAlt' t]  -- ^ invariant: lowest tags first
           | ProjCase t [CaseAlt' t] -- ^ special case for projections/thunk-forcing before inspection
           | STerm !t
           | UnmatchedCase String -- ^ error message
           | ImpossibleCase -- ^ already checked to be impossible
    deriving (Eq, Ord, Functor, Generic)
{-!
deriving instance Binary SC'
!-}

-- | Whether a 'Case' scrutinee is 'Updatable' (chosen when the first
-- constructor group is flagged unique) or 'Shared'.
data CaseType = Updatable | Shared
   deriving (Eq, Ord, Show, Generic)

type SC = SC' Term

-- | One alternative of a case. The derived 'Ord' places constructors in
-- declaration order, so sorting a list of alternatives puts any
-- 'DefaultCase' last.
data CaseAlt' t = ConCase Name Int [Name] !(SC' t)
                | FnCase Name [Name]      !(SC' t) -- ^ reflection function
                | ConstCase Const         !(SC' t)
                | SucCase Name            !(SC' t)
                | DefaultCase             !(SC' t)
    deriving (Show, Eq, Ord, Functor, Generic)
{-!
deriving instance Binary CaseAlt'
!-}

type CaseAlt = CaseAlt' Term

-- | Pretty-printer for case trees. Nested alternatives are indented by
-- one copy of the indent string per level of depth.
instance Show t => Show (SC' t) where
    show sc = show' 1 sc
      where
        show' i (Case up n alts) = "case" ++ u ++ show n ++ " of\n" ++ indent i ++
                                    showSep ("\n" ++ indent i) (map (showA i) alts)
            where u = case up of
                           Updatable -> "! "
                           Shared -> " "
        show' i (ProjCase tm alts) = "case " ++ show tm ++ " of " ++
                                      showSep ("\n" ++ indent i) (map (showA i) alts)
        show' i (STerm tm) = show tm
        show' i (UnmatchedCase str) = "error " ++ show str
        show' i ImpossibleCase = "impossible"

        indent i = concat $ take i (repeat " ")

        showA i (ConCase n t args sc)
           = show n ++ "(" ++ showSep (", ") (map show args) ++ ") => "
                ++ show' (i+1) sc
        showA i (FnCase n args sc)
           = "FN " ++ show n ++ "(" ++ showSep (", ") (map show args) ++ ") => "
                ++ show' (i+1) sc
        showA i (ConstCase t sc)
           = show t ++ " => " ++ show' (i+1) sc
        showA i (SucCase n sc)
           = show n ++ "+1 => " ++ show' (i+1) sc
        showA i (DefaultCase sc)
           = "_ => " ++ show' (i+1) sc


type CaseTree = SC
-- | A row of the clause matrix: the remaining patterns, and the clause's
-- (left-hand side, right-hand side) terms.
type Clause   = ([Pat], (Term, Term))
-- | Case-builder state: accumulated terms, a counter used to generate
-- fresh machine names, and recorded (variable, type) pairs.
type CS = ([Term], Int, [(Name, Type)])

-- | Size of a case tree: only the alternatives and 'STerm' leaves
-- contribute; every other leaf counts as 1.
instance TermSize SC where
    termsize n (Case _ n' as) = termsize n as
    termsize n (ProjCase n' as) = termsize n as
    termsize n (STerm t) = termsize n t
    termsize n _ = 1

instance TermSize CaseAlt where
    termsize n (ConCase _ _ _ s) = termsize n s
    termsize n (FnCase _ _ s) = termsize n s
    termsize n (ConstCase _ s) = termsize n s
    termsize n (SucCase _ s) = termsize n s
    termsize n (DefaultCase s) = termsize n s

-- simple terms can be inlined trivially - good for primitives in particular
-- To avoid duplicating work, don't inline something which uses one
-- of its arguments in more than one place

small :: Name -> [Name] -> SC -> Bool
small n args t = let as = findAllUsedArgs t args in
                     length as == length (nub as) &&
                     termsize n t < 20

-- | Names referenced by a case tree that it does not bind itself.
-- Excluded are: case scrutinees, constructor/function pattern variables,
-- and names bound by binders inside leaf terms.
namesUsed :: SC -> [Name]
namesUsed sc = nub $ nu' [] sc where
    nu' ps (Case _ n alts) = nub (concatMap (nua ps) alts) \\ [n]
    nu' ps (ProjCase t alts) = nub $ nut ps t ++ concatMap (nua ps) alts
    nu' ps (STerm t)     = nub $ nut ps t
    nu' ps _ = []

    nua ps (ConCase n i args sc) = nub (nu' (ps ++ args) sc) \\ args
    nua ps (FnCase n args sc)    = nub (nu' (ps ++ args) sc) \\ args
    nua ps (ConstCase _ sc)      = nu' ps sc
    nua ps (SucCase _ sc)        = nu' ps sc
    nua ps (DefaultCase sc)      = nu' ps sc

    -- ps: names bound so far, which are not reported
    nut ps (P _ n _) | n `elem` ps = []
                     | otherwise = [n]
    nut ps (App _ f a) = nut ps f ++ nut ps a
    nut ps (Proj t _) = nut ps t
    nut ps (Bind n (Let _ t v) sc) = nut ps v ++ nut (n:ps) sc
    nut ps (Bind n b sc) = nut (n:ps) sc
    nut ps _ = []

-- | Return all called functions, and which arguments are used
-- in each argument position for the call, in order to help reduce
-- compilation time, and trace all unused arguments
findCalls :: SC -> [Name] -> [(Name, [[Name]])]
findCalls = findCalls' False

-- | Like 'findCalls'. When @ignoreasserts@ is True, an application headed
-- by @assert_total@ contributes nothing (neither the call nor anything
-- inside it). For each call, the list at position k holds the top-level
-- arguments (from @topargs@) that occur directly in the k-th argument.
findCalls' :: Bool -> SC -> [Name] -> [(Name, [[Name]])]
findCalls' ignoreasserts sc topargs = S.toList $ nu' topargs sc where
    -- ps: names that are bound locally and therefore not calls
    nu' ps (Case _ n alts) = S.unions $ map (nua (n : ps)) alts
    nu' ps (ProjCase t alts) = S.unions $ nut ps t : map (nua ps) alts
    nu' ps (STerm t)     = nut ps t
    nu' ps _ = S.empty

    nua ps (ConCase n i args sc) = nu' (ps ++ args) sc
    nua ps (FnCase n args sc)    = nu' (ps ++ args) sc
    nua ps (ConstCase _ sc)      = nu' ps sc
    nua ps (SucCase _ sc)        = nu' ps sc
    nua ps (DefaultCase sc)      = nu' ps sc

    nut ps (P _ n _) | n `elem` ps = S.empty
                     | otherwise = S.singleton (n, [])
    nut ps fn@(App _ f a)
        | (P _ n _, args) <- unApply fn
             = if ignoreasserts && n == sUN "assert_total"
                  then S.empty
                  else if n `elem` ps
                          then S.union (nut ps f) (nut ps a)
                          else S.insert (n, map argNames args)
                                   (S.unions $ map (nut ps) args)
        -- NOTE(review): this guard can never fire. The guard above already
        -- matches every application headed by any 'P' (TCon included), so
        -- type-constructor applications are recorded as calls. Confirm the
        -- intended behaviour before reordering these guards.
        | (P (TCon _ _) n _, _) <- unApply fn = S.empty
        | otherwise = S.union (nut ps f) (nut ps a)
    nut ps (Bind n (Let _ t v) sc) = S.union (nut ps v) (nut (n:ps) sc)
    nut ps (Proj t _) = nut ps t
    nut ps (Bind n b sc) = nut (n:ps) sc
    nut ps _ = S.empty

    -- top-level arguments used directly (not via a call) in this term
    argNames tm = let ns = directUse tm in
                      filter (\x -> x `elem` ns) topargs

-- Find names which are used directly (i.e. not in a function call) in a term

directUse :: TT Name -> [Name]
directUse (P _ n _) = [n]
directUse (Bind n (Let _ t v) sc) = nub $ directUse v ++ (directUse sc \\ [n])
                                        ++ directUse t
directUse (Bind n b sc) = nub $ directUse (binderTy b) ++ (directUse sc \\ [n])
directUse fn@(App _ f a)
    | (P Ref (UN pfk) _, [App _ e w]) <- unApply fn,
         pfk == txt "prim_fork"
             = directUse e ++ directUse w -- HACK so that fork works
    | (P Ref (UN fce) _, [_, _, a]) <- unApply fn,
         fce == txt "Force"
             = directUse a -- forcing a value counts as a use
    | (P Ref n _, args) <- unApply fn = [] -- need to know what n does with them
    | (P (TCon _ _) n _, args) <- unApply fn = [] -- type constructors not used at runtime
    | otherwise = nub $ directUse f ++ directUse a
directUse (Proj x i) = nub $ directUse x
directUse _ = []

-- Find all directly used arguments (i.e. used but not in function calls)
findUsedArgs :: SC -> [Name] -> [Name]
findUsedArgs sc topargs = nub (findAllUsedArgs sc topargs)

-- Like 'findUsedArgs' but keeps duplicates: each direct use (including
-- being a case scrutinee) is listed once. 'small' relies on the duplicates
-- to detect arguments that are used more than once.
findAllUsedArgs sc topargs = filter (\x -> x `elem` topargs) (nu' sc) where
    nu' (Case _ n alts) = n : concatMap nua alts
    nu' (ProjCase t alts) = directUse t ++ concatMap nua alts
    nu' (STerm t)     = directUse t
    nu' _             = []

    nua (ConCase n i args sc) = nu' sc
    nua (FnCase n args sc)    = nu' sc
    nua (ConstCase _ sc)      = nu' sc
    nua (SucCase _ sc)        = nu' sc
    nua (DefaultCase sc)      = nu' sc

-- Return whether name is used anywhere in a case tree
-- (as a scrutinee, or free in a ProjCase scrutinee or a leaf term;
-- shadowing by pattern variables is not taken into account)
isUsed :: SC -> Name -> Bool
isUsed sc n = used sc where

  used (Case _ n' alts) = n == n' || any usedA alts
  used (ProjCase t alts) = n `elem` freeNames t || any usedA alts
  used (STerm t) = n `elem` freeNames t
  used _ = False

  usedA (ConCase _ _ args sc) = used sc
  usedA (FnCase _ args sc) = used sc
  usedA (ConstCase _ sc) = used sc
  usedA (SucCase _ sc) = used sc
  usedA (DefaultCase sc) = used sc

type ErasureInfo = Name -> [Int]  -- name to list of inaccessible arguments; empty list if name not found
-- | The case-builder monad: reads the 'ErasureInfo', threads 'CS'.
type CaseBuilder a = ReaderT ErasureInfo (State CS) a

-- | Run a builder with the given erasure info; the result still expects
-- the initial 'CS' state.
runCaseBuilder :: ErasureInfo -> CaseBuilder a -> (CS -> (a, CS))
runCaseBuilder ei bld = runState $ runReaderT bld ei

data Phase = CoverageCheck [Int] -- list of positions explicitly given
           | CompileTime
           | RunTime
    deriving (Show, Eq)

-- Generate a simple case tree
-- Work Right to Left
--
-- Parameters, in order:
--   tc      - passed through to 'toPats' (not inspected there at present)
--   defcase - tree used when no clause matches
--   reflect - enables 'PReflected' patterns and disables the
--             accessibility check
--   phase   - controls column ordering, accessibility checking
--             (CompileTime only) and, for RunTime, projections and
--             'stripLambdas'
--   fc      - source location attached to errors
--   inacc   - inaccessible argument positions (used by 'order')
--   argtys  - per argument: (type, whether canonical); only the flag is read
--   cs      - clauses as (pattern variables, lhs, rhs); clauses whose rhs
--             is 'Impossible' are dropped before compilation
--   erInfo  - erasure information for constructors

simpleCase :: Bool -> SC -> Bool ->
              Phase -> FC ->
              -- Following two can be empty lists when Phase = CoverageCheck
              [Int] -> -- Inaccessible argument positions
              [(Type, Bool)] -> -- (Argument type, whether it's canonical)
              [([Name], Term, Term)] ->
              ErasureInfo ->
              TC CaseDef
simpleCase tc defcase reflect phase fc inacc argtys cs erInfo
    = sc' tc defcase phase fc (filter (\(_, _, r) ->
                                          case r of
                                            Impossible -> False
                                            _ -> True) cs)
  where
    sc' tc defcase phase fc []
        = return $ CaseDef [] (UnmatchedCase (show fc ++ ":No pattern clauses")) []
    sc' tc defcase phase fc cs
        = let proj       = phase == RunTime
              pats       = map (\ (avs, l, r) ->
                                     (avs, toPats reflect tc l, (l, r))) cs
              chkPats    = mapM chkAccessible pats in
              case chkPats of
                  OK pats ->
                      let numargs    = length (fst (head pats))
                          ns         = take numargs args
                          (ns', ps') = order phase [(n, i `elem` inacc) | (i,n) <- zip [0..] ns] pats (map snd argtys)
                          (tree, st) = runCaseBuilder erInfo
                                           (match ns' ps' defcase)
                                           ([], numargs, [])
                          sc         = removeUnreachable (prune proj (depatt ns' tree))
                          t          = CaseDef ns sc (fstT st) in
                          if proj then return (stripLambdas t)
                                  else return t
                  Error err -> Error (At fc err)
      where -- argument names e0, e1, ... (machine names)
            args = map (\i -> sMN i "e") [0..]
            fstT (x, _, _) = x

            -- Check that all pattern variables are reachable by a case split
            -- Otherwise, they won't make sense on the RHS.
            chkAccessible (avs, l, c)
                 | phase /= CompileTime || reflect = return (l, c)
                 | otherwise = do mapM_ (acc l) avs
                                  return (l, c)

            -- succeeds iff n occurs as a PV, looking through PCon and PSuc
            -- (but not through PInferred or PReflected)
            acc [] n = Error (Inaccessible n)
            acc (PV x t : xs) n | x == n = OK ()
            acc (PCon _ _ _ ps : xs) n = acc (ps ++ xs) n
            acc (PSuc p : xs) n = acc (p : xs) n
            acc (_ : xs) n = acc xs n

-- | Patterns in the clause matrix.
data Pat = PCon Bool Name Int [Pat]
         | PConst Const
         | PInferred Pat
         | PV Name Type
         | PSuc Pat -- special case for n+1 on Integer
         | PReflected Name [Pat]
         | PAny
         | PTyPat -- typecase, not allowed, inspect last
    deriving Show

-- If there are repeated variables, take the *last* one (could be name shadowing
-- in a where clause, so take the most recent).
-- NOTE(review): the code below only converts the arguments of the LHS
-- application into patterns, in left-to-right argument order; the
-- "take the last one" policy is not visible here.

toPats :: Bool -> Bool -> Term -> [Pat]
toPats reflect tc f = reverse (toPat reflect tc (getArgs f)) where
   getArgs (App _ f a) = a : getArgs f
   getArgs _ = []

-- | Convert argument terms into patterns. In toPat', @args@ holds the
-- arguments already peeled off an application spine.
toPat :: Bool -> Bool -> [Term] -> [Pat]
toPat reflect tc = map $ toPat' []
  where
    -- Delay with three arguments: only the delayed value is a real pattern
    toPat' [_,_,arg] (P (DCon t a uniq) nm@(UN n) _)
        | n == txt "Delay"
        = PCon uniq nm t [PAny, PAny, toPat' [] arg]

    -- Ownership.Read: every constructor pattern beneath it is non-unique
    toPat' args (P (DCon t a uniq) nm@(NS (UN n) [own]) _)
        | n == txt "Read" && own == txt "Ownership"
        = PCon False nm t (map shareCons (map (toPat' []) args))
      where shareCons (PCon _ n i ps) = PCon False n i (map shareCons ps)
            shareCons p = p

    toPat' args (P (DCon t a uniq) n _)
        = PCon uniq n t $ map (toPat' []) args

    -- n + 1
    toPat' [p, Constant (BI 1)] (P _ (UN pabi) _)
        | pabi == txt "prim__addBigInt"
        = PSuc $ toPat' [] p

    toPat' []   (P Bound n ty) = PV n ty
    toPat' args (App _ f a)    = toPat' (a : args) f
    toPat' args (Inferred tm)  = PInferred (toPat' args tm)
    toPat' [] (Constant x) | isTypeConst x = PTyPat
                           | otherwise     = PConst x

    -- reflected non-dependent function type
    toPat' [] (Bind n (Pi _ _ t _) sc)
        | reflect && noOccurrence n sc
        = PReflected (sUN "->") [toPat' [] t, toPat' [] sc]

    toPat' args (P _ n _)
        | reflect
        = PReflected n $ map (toPat' []) args

    -- anything else is a wildcard
    toPat' _ t = PAny

-- | A maximal run of clauses whose first pattern is constructor-like
-- ('Cons') or variable-like ('Vars').
data Partition = Cons [Clause]
               | Vars [Clause]
    deriving Show

isVarPat (PV _ _ : ps , _) = True
isVarPat (PAny   : ps , _) = True
isVarPat (PTyPat : ps , _) = True
isVarPat _                 = False

isConPat (PCon _ _ _ _ : ps, _) = True
isConPat (PReflected _ _ : ps, _)   = True
isConPat (PSuc _   : ps, _) = True
isConPat (PConst _   : ps, _) = True
isConPat _                    = False

-- | Split the clause matrix into maximal runs of 'Vars' / 'Cons' clauses,
-- keeping their order. Any other first pattern is a fatal error.
partition :: [Clause] -> [Partition]
partition [] = []
partition ms@(m : _)
    | isVarPat m = let (vars, rest) = span isVarPat ms in
                       Vars vars : partition rest
    | isConPat m = let (cons, rest) = span isConPat ms in
                       Cons cons : partition rest
partition xs = error $ "Partition " ++ show xs

-- reorder the patterns so that the one with most distinct names
-- comes next. Take rightmost first, otherwise (i.e. pick value rather
-- than dependency)
--
-- The first argument means [(Name, IsInaccessible)].
-- The last argument gives, per column, whether the argument type is
-- canonical.

order :: Phase -> [(Name, Bool)] -> [Clause] -> [Bool] -> ([Name], [Clause])
-- do nothing at compile time: FIXME (EB): Put this in after checking
-- implications for Franck's reflection work... see issue 3233
-- order CompileTime ns cs _ = (map fst ns, cs)
order _ []  cs cans = ([], cs)
order _ ns' [] cans = (map fst ns', [])
order (CoverageCheck pos) ns' cs cans
    = let ns_out = pick 0 [] (map fst ns')
          cs_out = map pickClause cs in
          (ns_out, cs_out)
  where
    pickClause (pats, def) = (pick 0 [] pats, def)

    -- Order the list so that things in a position in 'pos' are in the first
    -- part, then all the other things later. Otherwise preserve order.
    pick i skipped [] = reverse skipped
    pick i skipped (x : xs)
         | i `elem` pos = x : pick (i + 1) skipped xs
         | otherwise = pick (i + 1) (x : skipped) xs

order phase ns' cs cans
    = let patnames = transpose (map (zip ns') (map (zip cans) (map fst cs)))
          -- only sort the arguments where there is no clash in
          -- constructor tags between families, the argument type is canonical,
          -- and no constructor/constant
          -- clash, because otherwise we can't reliably make the
          -- case distinction on evaluation
          (patnames_ord, patnames_rest)
                = Data.List.partition (noClash . map snd) patnames
          -- CoverageCheck is handled by the previous equation
          patnames_ord' = case phase of
                             CompileTime -> patnames_ord
                             -- reversing tends to make better case trees
                             -- and helps erasure
                             RunTime -> reverse patnames_ord
          pats' = transpose (sortBy moreDistinct patnames_ord'
                                     ++ patnames_rest) in
          (getNOrder pats', zipWith rebuild pats' cs)
  where
    getNOrder [] = error $ "Failed order on " ++ show (map fst ns', cs)
    getNOrder (c : _) = map (fst . fst) c

    rebuild patnames clause = (map (snd . snd) patnames, snd clause)

    noClash [] = True
    noClash ((can, p) : ps) = can && not (any (clashPat p) (map snd ps))
                                  && noClash ps

    clashPat (PCon _ _ _ _) (PConst _) = True
    clashPat (PConst _) (PCon _ _ _ _) = True
    clashPat (PCon _ _ _ _) (PSuc _) = True
    clashPat (PSuc _) (PCon _ _ _ _) = True
    clashPat (PCon _ n i _) (PCon _ n' i' _) | i == i' = n /= n'
    clashPat _ _ = False

    -- this compares (+isInaccessible, -numberOfCases)
    moreDistinct xs ys
        = compare (snd . fst . head $ xs, numNames [] (map snd ys))
                  (snd . fst . head $ ys, numNames [] (map snd xs))

    -- number of distinct constructor names and constants in a column
    numNames xs ((_, PCon _ n _ _) : ps)
        | not (Left n `elem` xs) = numNames (Left n : xs) ps
    numNames xs ((_, PConst c) : ps)
        | not (Right c `elem` xs) = numNames (Right c : xs) ps
    numNames xs (_ : ps) = numNames xs ps
    numNames xs [] = length xs

-- Reorder the patterns in the clause so that the PInferred patterns come
-- last. Also strip 'PInferred' from the top level patterns so that we can
-- go ahead and match.
orderByInf :: [Name] -> [Clause] -> ([Name], [Clause])
orderByInf vs cs = (moveInferredLast alwaysInf vs, map reorderClause cs)
  where
    -- Column positions holding a PInferred pattern in every clause
    -- (no positions at all when there are no clauses).
    alwaysInf :: [Int]
    alwaysInf = case cs of
                  [] -> []
                  _  -> foldr1 intersect (map (inferredPositions . fst) cs)

    inferredPositions :: [Pat] -> [Int]
    inferredPositions pats = [pos | (pos, PInferred _) <- zip [0..] pats]

    -- Stable split: elements at the given positions move to the back,
    -- all other elements keep their relative order at the front.
    moveInferredLast :: [Int] -> [a] -> [a]
    moveInferredLast positions xs =
        let (late, early) = Data.List.partition ((`elem` positions) . fst)
                                                (zip [0..] xs)
        in map snd early ++ map snd late

    -- Move the always-inferred columns to the back, then unwrap one
    -- top-level PInferred from every pattern of the row.
    reorderClause (pats, rhs) =
        (map unwrapInferred (moveInferredLast alwaysInf pats), rhs)

    unwrapInferred (PInferred p) = p
    unwrapInferred p             = p

-- | Compile a clause matrix into a case tree. The names are the variables
-- for the remaining pattern columns; the 'SC' argument is the tree used
-- when no clause applies.
match :: [Name]
      -> [Clause]
      -> SC              -- ^ error case
      -> CaseBuilder SC
match [] (([], ret) : xs) err = do
    -- No columns left, so the first clause wins. The left-hand sides of
    -- the clauses after it (which it makes unreachable) are appended to
    -- the builder state.
    (ts, v, ntys) <- get
    put (ts ++ map (fst . snd) xs, v, ntys)
    return $ case snd ret of
        Impossible -> ImpossibleCase
        tm         -> STerm tm
match vs cs err =
    let (vs', deInfCs) = orderByInf vs cs
    in mixture vs' (partition deInfCs) err

-- | Compile the partitions from the last one backwards: the tree built for
-- each partition becomes the fallthrough of the partition before it.
mixture :: [Name] -> [Partition] -> SC -> CaseBuilder SC
mixture vs parts err = foldr step (return err) parts
  where
    step (Cons ms) later = later >>= conRule vs ms
    step (Vars ms) later = later >>= varRule vs ms

-- Return the list of inaccessible arguments of a data constructor.
inaccessibleArgs :: Name -> CaseBuilder [Int]
inaccessibleArgs n = do
    getInaccessiblePositions <- ask -- this function is the only thing in the environment
    return $ getInaccessiblePositions n

-- | The head of a constructor-like pattern; clauses are grouped by it.
data ConType = CName Name Int -- named constructor
             | CFn Name -- reflected function name
             | CSuc -- n+1
             | CConst Const -- constant, not implemented yet
   deriving (Show, Eq)

-- | All clauses whose first pattern has the same 'ConType'.
data Group = ConGroup Bool -- Uniqueness flag
                      ConType -- Constructor
                      [([Pat], Clause)] -- arguments and rest of alternative
   deriving Show

-- | Constructor rule: group the clauses by the head of their first pattern
-- and build one 'Case' on the first variable.
conRule :: [Name] -> [Clause] -> SC -> CaseBuilder SC
conRule (v:vs) cs err = do groups <- groupCons cs
                           caseGroups (v:vs) groups err

-- | Build a 'Case' on the first variable with one alternative per group,
-- plus a 'DefaultCase' holding the fallthrough tree. The alternatives are
-- sorted with the derived 'Ord' of 'CaseAlt''.
caseGroups :: [Name] -> [Group] -> SC -> CaseBuilder SC
caseGroups (v:vs) gs err = do g <- altGroups gs
                              return $ Case (getShared gs) v (sort g)
  where
    -- only the first group's uniqueness flag is consulted
    getShared (ConGroup True _ _ : _) = Updatable
    getShared _ = Shared

    altGroups [] = return [DefaultCase err]
    altGroups (ConGroup _ (CName n i) args : cs)
        = (:) <$> altGroup n i args <*> altGroups cs
    altGroups (ConGroup _ (CFn n) args : cs)
        = (:) <$> altFnGroup n args <*> altGroups cs
    altGroups (ConGroup _ CSuc args : cs)
        = (:) <$> altSucGroup args <*> altGroups cs
    altGroups (ConGroup _ (CConst c) args : cs)
        = (:) <$> altConstGroup c args <*> altGroups cs

    altGroup n i args = do
        inacc <- inaccessibleArgs n
        (newVars, accVars, inaccVars, nextCs) <- argsToAlt inacc args
        matchCs <- match (accVars ++ vs ++ inaccVars) nextCs err
        return $ ConCase n i newVars matchCs

    -- With no inaccessible positions, argsToAlt always returns an empty
    -- inaccessible list, so it is not pattern matched below. Matching on
    -- '[]' (or on a one-element list) inside 'do' makes the bind failable.
    -- Under MonadFail desugaring that needs a MonadFail instance, which a
    -- ReaderT/State stack over Identity does not have.
    altFnGroup n args = do
        (newVars, _, _, nextCs) <- argsToAlt [] args
        matchCs <- match (newVars ++ vs) nextCs err
        return $ FnCase n newVars matchCs

    altSucGroup args = do
        (newVars, _, _, nextCs) <- argsToAlt [] args
        -- a successor pattern (PSuc) always has exactly one argument pattern
        let newVar = case newVars of
                         [x] -> x
                         _   -> error "caseGroups: successor pattern must bind exactly one variable"
        matchCs <- match (newVar:vs) nextCs err
        return $ SucCase newVar matchCs

    altConstGroup n args = do
        (_, _, _, nextCs) <- argsToAlt [] args
        matchCs <- match vs nextCs err
        return $ ConstCase n matchCs

-- Returns:
-- * names of all variables arising from match
-- * names of accessible variables (subset of all variables)
-- * names of inaccessible variables (subset of all variables)
-- * clauses corresponding to (accVars ++ origVars ++ inaccVars)
--
-- Fresh names are generated from the argument patterns of the first row
-- only. The rows of a group share a constructor, so presumably they all
-- have the same number of argument patterns.
argsToAlt :: [Int] -> [([Pat], Clause)] -> CaseBuilder ([Name], [Name], [Name], [Clause])
argsToAlt _ [] = return ([], [], [], [])
argsToAlt inacc rs@((r, _) : _) = do
    newVars <- getNewVars r
    let (accVars, inaccVars) = partitionAcc newVars
    return (newVars, accVars, inaccVars, addRs rs)
  where
    -- Create names for new variables arising from the given patterns.
    getNewVars :: [Pat] -> CaseBuilder [Name]
    getNewVars [] = return []
    getNewVars ((PV n t) : ns) = do
        v <- getVar "e"
        nsv <- getNewVars ns

        -- Record the type of the variable.
        --
        -- It seems that the ordering is not important
        -- and we can put (v,t) always in front of "ntys"
        -- (the varName-type pairs seem to represent a mapping).
        --
        -- The code that reads this is currently
        -- commented out, anyway.
        (cs, i, ntys) <- get
        put (cs, i, (v, t) : ntys)

        return (v : nsv)
    getNewVars (PAny   : ns) = (:) <$> getVar "i" <*> getNewVars ns
    getNewVars (PTyPat : ns) = (:) <$> getVar "t" <*> getNewVars ns
    getNewVars (_      : ns) = (:) <$> getVar "e" <*> getNewVars ns

    -- Partition a list of things into (accessible, inaccessible) things,
    -- according to the list of inaccessible indices.
    partitionAcc xs =
        ( [x | (i,x) <- zip [0..] xs, i `notElem` inacc]
        , [x | (i,x) <- zip [0..] xs, i `elem`    inacc]
        )

    -- Put each row's argument patterns in front of its remaining patterns:
    -- accessible ones first, inaccessible ones after the original columns.
    addRs [] = []
    addRs ((argPs, (ps, res)) : rest) = (accPs ++ ps ++ inaccPs, res) : addRs rest
      where
        (accPs, inaccPs) = partitionAcc argPs

-- | Generate a fresh machine name with the given base, bumping the counter.
getVar :: String -> CaseBuilder Name
getVar b = do (t, v, ntys) <- get; put (t, v+1, ntys); return (sMN v b)

-- | Group clauses by the head of their first pattern, keeping groups in
-- order of first appearance and clauses in order within each group.
groupCons :: [Clause] -> CaseBuilder [Group]
groupCons cs = gc [] cs
  where
    gc acc [] = return acc
    gc acc ((p : ps, res) : cs) =
        do acc' <- addGroup p ps res acc
           gc acc' cs

    -- Only reached for Cons partitions, whose first patterns are the four
    -- kinds handled here. 'error' replaces 'fail' so no MonadFail instance
    -- is required.
    addGroup p ps res acc = case p of
        PCon uniq con i args -> return $ addg uniq (CName con i) args (ps, res) acc
        PConst cval -> return $ addConG cval (ps, res) acc
        PSuc n -> return $ addg False CSuc [n] (ps, res) acc
        PReflected fn args -> return $ addg False (CFn fn) args (ps, res) acc
        pat -> error $ show pat ++ " is not a constructor or constant (can't happen)"

    addg uniq c conargs res [] = [ConGroup uniq c [(conargs, res)]]
    addg uniq c conargs res (g@(ConGroup _ c' cs):gs)
        | c == c' = ConGroup uniq c (cs ++ [(conargs, res)]) : gs
        | otherwise = g : addg uniq c conargs res gs

    addConG con res [] = [ConGroup False (CConst con) [([], res)]]
    addConG con res (g@(ConGroup False (CConst n) cs) : gs)
        | con == n = ConGroup False (CConst n) (cs ++ [([], res)]) : gs
--         | otherwise = g : addConG con res gs
    addConG con res (g : gs) = g : addConG con res gs

-- | Variable rule: the first column holds only variable-like patterns.
-- Each PV is substituted by the scrutinee variable in the right-hand side,
-- and the variable's type is recorded in the builder state.
varRule :: [Name] -> [Clause] -> SC -> CaseBuilder SC
varRule (v : vs) alts err =
    do alts' <- mapM (repVar v) alts
       match vs alts' err
  where
    repVar v (PV p ty : ps , (lhs, res))
           = do (cs, i, ntys) <- get
                put (cs, i, (v, ty) : ntys)
                return (ps, (lhs, subst p (P Bound v ty) res))
    repVar v (PAny : ps , res) = return (ps, res)
    repVar v (PTyPat : ps , res) = return (ps, res)

-- fix: case e of S k -> f (S k)  ==> case e of S k -> f e
depatt :: [Name] -> SC -> SC
depatt ns tm = dp [] tm
  where
    -- ms: scrutinee variable -> (constructor, bound argument names)
    dp ms (STerm tm) = STerm (applyMaps ms tm)
    dp ms (Case up x alts) = Case up x (map (dpa ms x) alts)
    dp ms sc = sc

    dpa ms x (ConCase n i args sc) = ConCase n i args (dp ((x, (n, args)) : ms) sc)
    dpa ms x (FnCase n args sc) = FnCase n args (dp ((x, (n, args)) : ms) sc)
    dpa ms x (ConstCase c sc) = ConstCase c (dp ms sc)
    dpa ms x (SucCase n sc) = SucCase n (dp ms sc)
    dpa ms x (DefaultCase sc) = DefaultCase (dp ms sc)

    applyMaps ms f@(App _ _ _)
       | (P nt cn pty, args) <- unApply f
            = let args' = map (applyMaps ms) args in
                  applyMap ms nt cn pty args'
        where
          applyMap [] nt cn pty args' = mkApp (P nt cn pty) args'
          applyMap ((x, (n, args)) : ms) nt cn pty args'
            | and ((length args == length args') :
                     (n == cn) : zipWith same args args') = P Ref x Erased
            | otherwise = applyMap ms nt cn pty args'

          same n (P _ n' _) = n == n'
          same _ _ = False

    applyMaps ms (App s f a) = App s (applyMaps ms f) (applyMaps ms a)
    applyMaps ms t = t

-- FIXME: Do this for SucCase too
-- Issue #1719 on the issue tracker: https://github.com/idris-lang/Idris-dev/issues/1719
--
-- Every child in alts' has already been pruned by pruneAlt, and pruning an
-- already-pruned tree yields the same tree. The branches below therefore
-- reuse the pruned child rather than pruning it a second time. Pruning it
-- again would redo the work at every level, which is exponential in the
-- depth of nested single-alternative cases.
prune :: Bool -- ^ Convert single branches to projections (only useful at runtime)
      -> SC -> SC
prune proj (Case up n alts) = case alts' of
    [] -> ImpossibleCase

    -- Projection transformations prevent us from seeing some uses of ctor fields
    -- because they delete information about which ctor is being used.
    -- Consider:
    --   f (X x) = ... x ...
    -- vs.
    --   f x = ... x!0 ...
    --
    -- Hence, we disable this step.
    -- TODO: re-enable this in toIR
    --
    -- as@[ConCase cn i args sc]
    --     | proj -> mkProj n 0 args (prune proj sc)
    --   mkProj n i xs sc = foldr (\x -> projRep x n i) sc xs

    -- If none of the args are used in the sc, however, we can just replace it
    -- with sc
    [ConCase cn i args sc]
        | proj -> if any (isUsed sc) args
                     then Case up n [ConCase cn i args sc]
                     else sc

    [SucCase cn sc]
        | proj
        -> projRep cn n (-1) sc

    [ConstCase _ sc]
        -> sc

    -- Bit of a hack here! The default case will always be 0, make sure
    -- it gets caught first.
    [s@(SucCase _ _), DefaultCase dc]
        -> Case up n [ConstCase (BI 0) dc, s]

    as  -> Case up n as
  where
    alts' = filter (not . erased) $ map pruneAlt alts

    pruneAlt (ConCase cn i ns sc) = ConCase cn i ns (prune proj sc)
    pruneAlt (FnCase cn ns sc) = FnCase cn ns (prune proj sc)
    pruneAlt (ConstCase c sc) = ConstCase c (prune proj sc)
    pruneAlt (SucCase n sc) = SucCase n (prune proj sc)
    pruneAlt (DefaultCase sc) = DefaultCase (prune proj sc)

    erased (DefaultCase (STerm Erased)) = True
    erased (DefaultCase ImpossibleCase) = True
    erased _ = False

    -- Replace the variable 'arg' by projection i of n: cases on 'arg'
    -- become ProjCases, and occurrences in terms are substituted.
    projRep :: Name -> Name -> Int -> SC -> SC
    projRep arg n i (Case up x alts) | x == arg
        = ProjCase (Proj (P Bound n Erased) i) $
                   map (projRepAlt arg n i) alts
    projRep arg n i (Case up x alts)
        = Case up x (map (projRepAlt arg n i) alts)
    projRep arg n i (ProjCase t alts)
        = ProjCase (projRepTm arg n i t) $ map (projRepAlt arg n i) alts
    projRep arg n i (STerm t) = STerm (projRepTm arg n i t)
    projRep arg n i c = c

    projRepAlt arg n i (ConCase cn t args rhs)
        = ConCase cn t args (projRep arg n i rhs)
    projRepAlt arg n i (FnCase cn args rhs)
        = FnCase cn args (projRep arg n i rhs)
    projRepAlt arg n i (ConstCase t rhs)
        = ConstCase t (projRep arg n i rhs)
    projRepAlt arg n i (SucCase sn rhs)
        = SucCase sn (projRep arg n i rhs)
    projRepAlt arg n i (DefaultCase rhs)
        = DefaultCase (projRep arg n i rhs)

    projRepTm arg n i t = subst arg (Proj (P Bound n Erased) i) t

prune _ t = t

-- Remove any branches we can't reach because of variables we've already
-- tested
removeUnreachable :: SC -> SC
removeUnreachable sc = ru [] sc
  where
    -- keep a mapping from variable names, to the constructor tag we've
    -- already checked it as in this branch
    ru :: [(Name, Int)] -> SC -> SC
    ru checked (Case t n alts)
       = let alts' = map (ruAlt checked n) (dropImpossible (lookup n checked) alts) in
             Case t n alts'
    ru checked t = t

    dropImpossible Nothing alts = alts
    dropImpossible (Just t) [] = []
    dropImpossible (Just t) (ConCase con tag args sc : rest)
         | t == tag = [ConCase con tag args sc] -- must be this case
         | otherwise = dropImpossible (Just t) rest -- can't be this case
    dropImpossible (Just t) (c : rest) = c : dropImpossible (Just t) rest

    ruAlt :: [(Name, Int)] -> Name -> CaseAlt -> CaseAlt
    ruAlt checked var (ConCase con tag args sc)
        = let checked' = dropChecked args (updateChecked var tag checked)
              sc' = ru checked' sc in
              ConCase con tag args sc'
    -- Unknown what the case is here, so remove what we knew about var
    ruAlt checked var (FnCase n args sc)
        = let checked' = dropChecked [var] checked
              sc' = ru checked' sc in
              FnCase n args sc'
    ruAlt checked var (ConstCase c sc)
        = let checked' = dropChecked [var] checked
              sc' = ru checked' sc in
              ConstCase c sc'
    ruAlt checked var (SucCase n sc)
        = let checked' = dropChecked [var] checked
              sc' = ru checked' sc in
              SucCase n sc'
    ruAlt checked var (DefaultCase sc)
        = let checked' = dropChecked [var] checked
              sc' = ru checked' sc in
              DefaultCase sc'

    updateChecked :: Name -> Int -> [(Name, Int)] -> [(Name, Int)]
    updateChecked n i checked
        = (n, i) : filter (\x -> fst x /= n) checked

    dropChecked :: [Name] -> [(Name, Int)] -> [(Name, Int)]
    dropChecked ns checked = filter (\x -> fst x `notElem` ns) checked

-- | Turn leading lambdas of a leaf-only definition into extra arguments.
stripLambdas :: CaseDef -> CaseDef
stripLambdas (CaseDef ns (STerm (Bind x (Lam _ _) sc)) tm)
    = stripLambdas (CaseDef (ns ++ [x]) (STerm (instantiate (P Bound x Erased) sc)) tm)
stripLambdas x = x

-- | Rename variable @n@ to @repl@ throughout a case tree: in scrutinees,
-- leaf terms and ProjCase scrutinee terms.
substSC :: Name -> Name -> SC -> SC
substSC n repl (Case up n' alts)
  | n == n'   = Case up repl (map (substAlt n repl) alts)
  | otherwise = Case up n'   (map (substAlt n repl) alts)
-- ProjCase is produced by 'prune'; rename inside its scrutinee term too
-- rather than rejecting it.
substSC n repl (ProjCase t alts)
  = ProjCase (subst n (P Bound repl Erased) t) (map (substAlt n repl) alts)
substSC n repl (STerm t) = STerm $ subst n (P Bound repl Erased) t
substSC n repl (UnmatchedCase errmsg) = UnmatchedCase errmsg
substSC n repl ImpossibleCase = ImpossibleCase

-- | 'substSC' for a single alternative.
substAlt :: Name -> Name -> CaseAlt -> CaseAlt
substAlt n repl (ConCase cn a ns sc) = ConCase cn a ns (substSC n repl sc)
substAlt n repl (FnCase fn ns sc) = FnCase fn ns (substSC n repl sc)
substAlt n repl (ConstCase c sc) = ConstCase c (substSC n repl sc)
-- The uses of n inside sc are renamed to repl, so a SucCase binding n must
-- be renamed to repl as well (as the scrutinee is in substSC). The old
-- code left the binder as n, which made both guard branches identical.
substAlt n repl (SucCase n' sc)
  | n == n'   = SucCase repl (substSC n repl sc)
  | otherwise = SucCase n'   (substSC n repl sc)
substAlt n repl (DefaultCase sc) = DefaultCase (substSC n repl sc)

-- mkForce n' n t updates the tree t under the assumption that
-- n' = force n (so basically updating n to n')
-- Only case scrutinees are renamed; leaf terms and ProjCase scrutinee
-- terms are left unchanged.
mkForce :: Name -> Name -> SC -> SC
mkForce = mkForceSC
  where
    mkForceSC n arg (Case up x alts) | x == arg
        = Case up n $ map (mkForceAlt n arg) alts

    mkForceSC n arg (Case up x alts)
        = Case up x (map (mkForceAlt n arg) alts)

    mkForceSC n arg (ProjCase t alts)
        = ProjCase t $ map (mkForceAlt n arg) alts

    mkForceSC n arg c = c

    mkForceAlt n arg (ConCase cn t args rhs)
        = ConCase cn t args (mkForceSC n arg rhs)
    mkForceAlt n arg (FnCase cn args rhs)
        = FnCase cn args (mkForceSC n arg rhs)
    mkForceAlt n arg (ConstCase t rhs)
        = ConstCase t (mkForceSC n arg rhs)
    mkForceAlt n arg (SucCase sn rhs)
        = SucCase sn (mkForceSC n arg rhs)
    mkForceAlt n arg (DefaultCase rhs)
        = DefaultCase (mkForceSC n arg rhs)