{-# LANGUAGE PatternGuards, DeriveFunctor, TypeSynonymInstances #-} module Core.CaseTree(CaseDef(..), SC, SC'(..), CaseAlt, CaseAlt'(..), Phase(..), CaseTree, simpleCase, small, namesUsed, findCalls, findUsedArgs) where import Core.TT import Control.Monad.State import Data.Maybe import Data.List hiding (partition) import Debug.Trace data CaseDef = CaseDef [Name] !SC [Term] deriving Show data SC' t = Case Name [CaseAlt' t] -- ^ invariant: lowest tags first | ProjCase t [CaseAlt' t] -- ^ special case for projections | STerm !t | UnmatchedCase String -- ^ error message | ImpossibleCase -- ^ already checked to be impossible deriving (Eq, Ord, Functor) {-! deriving instance Binary SC !-} type SC = SC' Term data CaseAlt' t = ConCase Name Int [Name] !(SC' t) | FnCase Name [Name] !(SC' t) -- ^ reflection function | ConstCase Const !(SC' t) | SucCase Name !(SC' t) | DefaultCase !(SC' t) deriving (Show, Eq, Ord, Functor) {-! deriving instance Binary CaseAlt !-} type CaseAlt = CaseAlt' Term instance Show t => Show (SC' t) where show sc = show' 1 sc where show' i (Case n alts) = "case " ++ show n ++ " of\n" ++ indent i ++ showSep ("\n" ++ indent i) (map (showA i) alts) show' i (ProjCase tm alts) = "case " ++ show tm ++ " of " ++ showSep ("\n" ++ indent i) (map (showA i) alts) show' i (STerm tm) = show tm show' i (UnmatchedCase str) = "error " ++ show str show' i ImpossibleCase = "impossible" indent i = concat $ take i (repeat " ") showA i (ConCase n t args sc) = show n ++ "(" ++ showSep (", ") (map show args) ++ ") => " ++ show' (i+1) sc showA i (FnCase n args sc) = "FN " ++ show n ++ "(" ++ showSep (", ") (map show args) ++ ") => " ++ show' (i+1) sc showA i (ConstCase t sc) = show t ++ " => " ++ show' (i+1) sc showA i (SucCase n sc) = show n ++ "+1 => " ++ show' (i+1) sc showA i (DefaultCase sc) = "_ => " ++ show' (i+1) sc type CaseTree = SC type Clause = ([Pat], (Term, Term)) type CS = ([Term], Int) instance TermSize SC where termsize n (Case n' as) = termsize n as termsize n (ProjCase n' as) = termsize n as termsize n (STerm t) = termsize n t termsize n _ = 1 instance TermSize CaseAlt where termsize n (ConCase _ _ _ s) = termsize n s termsize n (FnCase _ _ s) = termsize n s termsize n (ConstCase _ s) = termsize n s termsize n (SucCase _ s) = termsize n s termsize n (DefaultCase s) = termsize n s -- simple terms can be inlined trivially - good for primitives in particular -- To avoid duplicating work, don't inline something which uses one -- of its arguments in more than one place small :: Name -> [Name] -> SC -> Bool small n args t = let as = findAllUsedArgs t args in length as == length (nub as) && termsize n t < 10 namesUsed :: SC -> [Name] namesUsed sc = nub $ nu' [] sc where nu' ps (Case n alts) = nub (concatMap (nua ps) alts) \\ [n] nu' ps (ProjCase t alts) = nub $ (nut ps t ++ (concatMap (nua ps) alts)) nu' ps (STerm t) = nub $ nut ps t nu' ps _ = [] nua ps (ConCase n i args sc) = nub (nu' (ps ++ args) sc) \\ args nua ps (FnCase n args sc) = nub (nu' (ps ++ args) sc) \\ args nua ps (ConstCase _ sc) = nu' ps sc nua ps (SucCase _ sc) = nu' ps sc nua ps (DefaultCase sc) = nu' ps sc nut ps (P _ n _) | n `elem` ps = [] | otherwise = [n] nut ps (App f a) = nut ps f ++ nut ps a nut ps (Proj t _) = nut ps t nut ps (Bind n (Let t v) sc) = nut ps v ++ nut (n:ps) sc nut ps (Bind n b sc) = nut (n:ps) sc nut ps _ = [] -- Return all called functions, and which arguments are used in each argument position -- for the call, in order to help reduce compilation time, and trace all unused -- arguments findCalls :: SC -> [Name] -> [(Name, [[Name]])] findCalls sc topargs = nub $ nu' topargs sc where nu' ps (Case n alts) = nub (concatMap (nua (n : ps)) alts) nu' ps (ProjCase t alts) = nub (nut ps t ++ concatMap (nua ps) alts) nu' ps (STerm t) = nub $ nut ps t nu' ps _ = [] nua ps (ConCase n i args sc) = nub (nu' (ps ++ args) sc) nua ps (FnCase n args sc) = nub (nu' (ps ++ args) sc) nua ps (ConstCase _ sc) = nu' ps sc nua ps (SucCase _ sc) = nu' ps sc nua ps (DefaultCase sc) = nu' ps sc nut ps (P Ref n _) | n `elem` ps = [] | otherwise = [(n, [])] -- tmp nut ps fn@(App f a) | (P Ref n _, args) <- unApply fn = if n `elem` ps then nut ps f ++ nut ps a else [(n, map argNames args)] ++ concatMap (nut ps) args | (P (TCon _ _) n _, _) <- unApply fn = [] | otherwise = nut ps f ++ nut ps a nut ps (Bind n (Let t v) sc) = nut ps v ++ nut (n:ps) sc nut ps (Proj t _) = nut ps t nut ps (Bind n b sc) = nut (n:ps) sc nut ps _ = [] argNames tm = let ns = directUse tm in filter (\x -> x `elem` ns) topargs -- Find names which are used directly (i.e. not in a function call) in a term directUse :: Eq n => TT n -> [n] directUse (P _ n _) = [n] directUse (Bind n (Let t v) sc) = nub $ directUse v ++ (directUse sc \\ [n]) ++ directUse t directUse (Bind n b sc) = nub $ directUse (binderTy b) ++ (directUse sc \\ [n]) directUse fn@(App f a) | (P Ref n _, args) <- unApply fn = [] -- need to know what n does with them | otherwise = nub $ directUse f ++ directUse a directUse (Proj x i) = nub $ directUse x directUse _ = [] -- Find all directly used arguments (i.e. used but not in function calls) findUsedArgs :: SC -> [Name] -> [Name] findUsedArgs sc topargs = nub (findAllUsedArgs sc topargs) findAllUsedArgs sc topargs = filter (\x -> x `elem` topargs) (nu' sc) where nu' (Case n alts) = n : concatMap nua alts nu' (ProjCase t alts) = directUse t ++ concatMap nua alts nu' (STerm t) = directUse t nu' _ = [] nua (ConCase n i args sc) = nu' sc nua (FnCase n args sc) = nu' sc nua (ConstCase _ sc) = nu' sc nua (SucCase _ sc) = nu' sc nua (DefaultCase sc) = nu' sc data Phase = CompileTime | RunTime deriving (Show, Eq) -- Generate a simple case tree -- Work Right to Left simpleCase :: Bool -> Bool -> Bool -> Phase -> FC -> [([Name], Term, Term)] -> TC CaseDef simpleCase tc cover reflect phase fc cs = sc' tc cover phase fc (filter (\(_, _, r) -> case r of Impossible -> False _ -> True) cs) where sc' tc cover phase fc [] = return $ CaseDef [] (UnmatchedCase "No pattern clauses") [] sc' tc cover phase fc cs = let proj = phase == RunTime pats = map (\ (avs, l, r) -> (avs, toPats reflect tc l, (l, r))) cs chkPats = mapM chkAccessible pats in case chkPats of OK pats -> let numargs = length (fst (head pats)) ns = take numargs args (ns', ps') = order ns pats (tree, st) = runState (match ns' ps' (defaultCase cover)) ([], numargs) t = CaseDef ns (prune proj (depatt ns' tree)) (fst st) in if proj then return (stripLambdas t) else return t Error err -> Error (At fc err) where args = map (\i -> MN i "e") [0..] defaultCase True = STerm Erased defaultCase False = UnmatchedCase "Error" chkAccessible (avs, l, c) | phase == RunTime || reflect = return (l, c) | otherwise = do mapM_ (acc l) avs return (l, c) acc [] n = Error (Inaccessible n) acc (PV x : xs) n | x == n = OK () acc (PCon _ _ ps : xs) n = acc (ps ++ xs) n acc (PSuc p : xs) n = acc (p : xs) n acc (_ : xs) n = acc xs n data Pat = PCon Name Int [Pat] | PConst Const | PV Name | PSuc Pat -- special case for n+1 on Integer | PReflected Name [Pat] | PAny deriving Show -- If there are repeated variables, take the *last* one (could be name shadowing -- in a where clause, so take the most recent). toPats :: Bool -> Bool -> Term -> [Pat] toPats reflect tc f = reverse (toPat reflect tc (getArgs f)) where getArgs (App f a) = a : getArgs f getArgs _ = [] toPat :: Bool -> Bool -> [Term] -> [Pat] toPat reflect tc tms = evalState (mapM (\x -> toPat' x []) tms) [] where toPat' (P (DCon t a) n _) args = do args' <- mapM (\x -> toPat' x []) args return $ PCon n t args' -- n + 1 toPat' (P _ (UN "prim__addBigInt") _) [p, Constant (BI 1)] = do p' <- toPat' p [] return $ PSuc p' -- Typecase toPat' (P (TCon t a) n _) args | tc = do args' <- mapM (\x -> toPat' x []) args return $ PCon n t args' toPat' (Constant (AType (ATInt ITNative))) [] | tc = return $ PCon (UN "Int") 1 [] toPat' (Constant (AType ATFloat)) [] | tc = return $ PCon (UN "Float") 2 [] toPat' (Constant (AType (ATInt ITChar))) [] | tc = return $ PCon (UN "Char") 3 [] toPat' (Constant StrType) [] | tc = return $ PCon (UN "String") 4 [] toPat' (Constant PtrType) [] | tc = return $ PCon (UN "Ptr") 5 [] toPat' (Constant (AType (ATInt ITBig))) [] | tc = return $ PCon (UN "Integer") 6 [] toPat' (Constant (AType (ATInt (ITFixed n)))) [] | tc = return $ PCon (UN (fixedN n)) (7 + fromEnum n) [] -- 7-10 inclusive toPat' (P Bound n _) [] = do ns <- get if n `elem` ns then return PAny else do put (n : ns) return (PV n) toPat' (App f a) args = toPat' f (a : args) toPat' (Constant x) [] = return $ PConst x toPat' (Bind n (Pi t) sc) [] | reflect && noOccurrence n sc = do t' <- toPat' t [] sc' <- toPat' sc [] return $ PReflected (UN "->") (t':sc':[]) toPat' (P _ n _) args | reflect = do args' <- mapM (\x -> toPat' x []) args return $ PReflected n args' toPat' t _ = return PAny fixedN IT8 = "Bits8" fixedN IT16 = "Bits16" fixedN IT32 = "Bits32" fixedN IT64 = "Bits64" data Partition = Cons [Clause] | Vars [Clause] deriving Show isVarPat (PV _ : ps , _) = True isVarPat (PAny : ps , _) = True isVarPat _ = False isConPat (PCon _ _ _ : ps, _) = True isConPat (PReflected _ _ : ps, _) = True isConPat (PSuc _ : ps, _) = True isConPat (PConst _ : ps, _) = True isConPat _ = False partition :: [Clause] -> [Partition] partition [] = [] partition ms@(m : _) | isVarPat m = let (vars, rest) = span isVarPat ms in Vars vars : partition rest | isConPat m = let (cons, rest) = span isConPat ms in Cons cons : partition rest partition xs = error $ "Partition " ++ show xs -- reorder the patterns so that the one with most distinct names -- comes next. Take rightmost first, otherwise (i.e. pick value rather -- than dependency) order :: [Name] -> [Clause] -> ([Name], [Clause]) order [] cs = ([], cs) order ns [] = (ns, []) order ns cs = let patnames = transpose (map (zip ns) (map fst cs)) pats' = transpose (sortBy moreDistinct (reverse patnames)) in (getNOrder pats', zipWith rebuild pats' cs) where getNOrder [] = error $ "Failed order on " ++ show (ns, cs) getNOrder (c : _) = map fst c rebuild patnames clause = (map snd patnames, snd clause) moreDistinct xs ys = compare (numNames [] (map snd ys)) (numNames [] (map snd xs)) numNames xs (PCon n _ _ : ps) | not (Left n `elem` xs) = numNames (Left n : xs) ps numNames xs (PConst c : ps) | not (Right c `elem` xs) = numNames (Right c : xs) ps numNames xs (_ : ps) = numNames xs ps numNames xs [] = length xs match :: [Name] -> [Clause] -> SC -- error case -> State CS SC match [] (([], ret) : xs) err = do (ts, v) <- get put (ts ++ (map (fst.snd) xs), v) case snd ret of Impossible -> return ImpossibleCase tm -> return $ STerm tm -- run out of arguments match vs cs err = do let ps = partition cs mixture vs ps err mixture :: [Name] -> [Partition] -> SC -> State CS SC mixture vs [] err = return err mixture vs (Cons ms : ps) err = do fallthrough <- mixture vs ps err conRule vs ms fallthrough mixture vs (Vars ms : ps) err = do fallthrough <- mixture vs ps err varRule vs ms fallthrough data ConType = CName Name Int -- named constructor | CFn Name -- reflected function name | CSuc -- n+1 | CConst Const -- constant, not implemented yet deriving (Show, Eq) data Group = ConGroup ConType -- Constructor [([Pat], Clause)] -- arguments and rest of alternative deriving Show conRule :: [Name] -> [Clause] -> SC -> State CS SC conRule (v:vs) cs err = do groups <- groupCons cs caseGroups (v:vs) groups err caseGroups :: [Name] -> [Group] -> SC -> State CS SC caseGroups (v:vs) gs err = do g <- altGroups gs return $ Case v (sort g) where altGroups [] = return [DefaultCase err] altGroups (ConGroup (CName n i) args : cs) = do g <- altGroup n i args rest <- altGroups cs return (g : rest) altGroups (ConGroup (CFn n) args : cs) = do g <- altFnGroup n args rest <- altGroups cs return (g : rest) altGroups (ConGroup CSuc args : cs) = do g <- altSucGroup args rest <- altGroups cs return (g : rest) altGroups (ConGroup (CConst c) args : cs) = do g <- altConstGroup c args rest <- altGroups cs return (g : rest) altGroup n i gs = do (newArgs, nextCs) <- argsToAlt gs matchCs <- match (newArgs ++ vs) nextCs err return $ ConCase n i newArgs matchCs altFnGroup n gs = do (newArgs, nextCs) <- argsToAlt gs matchCs <- match (newArgs ++ vs) nextCs err return $ FnCase n newArgs matchCs altSucGroup gs = do ([newArg], nextCs) <- argsToAlt gs matchCs <- match (newArg:vs) nextCs err return $ SucCase newArg matchCs altConstGroup n gs = do (_, nextCs) <- argsToAlt gs matchCs <- match vs nextCs err return $ ConstCase n matchCs argsToAlt :: [([Pat], Clause)] -> State CS ([Name], [Clause]) argsToAlt [] = return ([], []) argsToAlt rs@((r, m) : rest) = do newArgs <- getNewVars r return (newArgs, addRs rs) where getNewVars [] = return [] getNewVars ((PV n) : ns) = do v <- getVar "e" nsv <- getNewVars ns return (v : nsv) getNewVars (PAny : ns) = do v <- getVar "i" nsv <- getNewVars ns return (v : nsv) getNewVars (_ : ns) = do v <- getVar "e" nsv <- getNewVars ns return (v : nsv) addRs [] = [] addRs ((r, (ps, res)) : rs) = ((r++ps, res) : addRs rs) uniq i (UN n) = MN i n uniq i n = n getVar :: String -> State CS Name getVar b = do (t, v) <- get; put (t, v+1); return (MN v b) groupCons :: [Clause] -> State CS [Group] groupCons cs = gc [] cs where gc acc [] = return acc gc acc ((p : ps, res) : cs) = do acc' <- addGroup p ps res acc gc acc' cs addGroup p ps res acc = case p of PCon con i args -> return $ addg (CName con i) args (ps, res) acc PConst cval -> return $ addConG cval (ps, res) acc PSuc n -> return $ addg CSuc [n] (ps, res) acc PReflected fn args -> return $ addg (CFn fn) args (ps, res) acc pat -> fail $ show pat ++ " is not a constructor or constant (can't happen)" addg c conargs res [] = [ConGroup c [(conargs, res)]] addg c conargs res (g@(ConGroup c' cs):gs) | c == c' = ConGroup c (cs ++ [(conargs, res)]) : gs | otherwise = g : addg c conargs res gs addConG con res [] = [ConGroup (CConst con) [([], res)]] addConG con res (g@(ConGroup (CConst n) cs) : gs) | con == n = ConGroup (CConst n) (cs ++ [([], res)]) : gs -- | otherwise = g : addConG con res gs addConG con res (g : gs) = g : addConG con res gs varRule :: [Name] -> [Clause] -> SC -> State CS SC varRule (v : vs) alts err = do let alts' = map (repVar v) alts match vs alts' err where repVar v (PV p : ps , (lhs, res)) = (ps, (lhs, subst p (P Bound v Erased) res)) repVar v (PAny : ps , res) = (ps, res) -- fix: case e of S k -> f (S k) ==> case e of S k -> f e depatt :: [Name] -> SC -> SC depatt ns tm = dp [] tm where dp ms (STerm tm) = STerm (applyMaps ms tm) dp ms (Case x alts) = Case x (map (dpa ms x) alts) dp ms sc = sc dpa ms x (ConCase n i args sc) = ConCase n i args (dp ((x, (n, args)) : ms) sc) dpa ms x (FnCase n args sc) = FnCase n args (dp ((x, (n, args)) : ms) sc) dpa ms x (ConstCase c sc) = ConstCase c (dp ms sc) dpa ms x (SucCase n sc) = SucCase n (dp ms sc) dpa ms x (DefaultCase sc) = DefaultCase (dp ms sc) applyMaps ms f@(App _ _) | (P nt cn pty, args) <- unApply f = let args' = map (applyMaps ms) args in applyMap ms nt cn pty args' where applyMap [] nt cn pty args' = mkApp (P nt cn pty) args' applyMap ((x, (n, args)) : ms) nt cn pty args' | and ((length args == length args') : (n == cn) : zipWith same args args') = P Ref x Erased | otherwise = applyMap ms nt cn pty args' same n (P _ n' _) = n == n' same _ _ = False applyMaps ms (App f a) = App (applyMaps ms f) (applyMaps ms a) applyMaps ms t = t -- FIXME: Do this for SucCase too prune :: Bool -- ^ Convert single branches to projections (only useful at runtime) -> SC -> SC prune proj (Case n alts) = let alts' = filter notErased (map pruneAlt alts) in case alts' of [] -> ImpossibleCase as@[ConCase cn i args sc] -> if proj then mkProj n 0 args sc else Case n as as@[SucCase cn sc] -> if proj then mkProj n (-1) [cn] sc else Case n as as@[ConstCase _ sc] -> prune proj sc -- Bit of a hack here! The default case will always be 0, make sure -- it gets caught first. [s@(SucCase _ _), DefaultCase dc] -> Case n [ConstCase (BI 0) dc, s] as -> Case n as where pruneAlt (ConCase cn i ns sc) = ConCase cn i ns (prune proj sc) pruneAlt (FnCase cn ns sc) = FnCase cn ns (prune proj sc) pruneAlt (ConstCase c sc) = ConstCase c (prune proj sc) pruneAlt (SucCase n sc) = SucCase n (prune proj sc) pruneAlt (DefaultCase sc) = DefaultCase (prune proj sc) notErased (DefaultCase (STerm Erased)) = False notErased (DefaultCase ImpossibleCase) = False notErased _ = True mkProj n i [] sc = prune proj sc mkProj n i (x : xs) sc = mkProj n (i + 1) xs (projRep x n i sc) -- Change every 'n' in sc to 'n-1' -- mkProjS n cn sc = prune proj (fmap projn sc) where -- projn pn@(P _ n' _) -- | cn == n' = App (App (P Ref (UN "prim__subBigInt") Erased) -- (P Bound n Erased)) (Constant (BI 1)) -- projn t = t projRep :: Name -> Name -> Int -> SC -> SC projRep arg n i (Case x alts) | x == arg = ProjCase (Proj (P Bound n Erased) i) (map (projRepAlt arg n i) alts) | otherwise = Case x (map (projRepAlt arg n i) alts) projRep arg n i (ProjCase t alts) = ProjCase (projRepTm arg n i t) (map (projRepAlt arg n i) alts) projRep arg n i (STerm t) = STerm (projRepTm arg n i t) projRep arg n i c = c -- unmatched projRepAlt arg n i (ConCase cn t args rhs) = ConCase cn t args (projRep arg n i rhs) projRepAlt arg n i (FnCase cn args rhs) = FnCase cn args (projRep arg n i rhs) projRepAlt arg n i (ConstCase t rhs) = ConstCase t (projRep arg n i rhs) projRepAlt arg n i (SucCase sn rhs) = SucCase sn (projRep arg n i rhs) projRepAlt arg n i (DefaultCase rhs) = DefaultCase (projRep arg n i rhs) projRepTm arg n i t = subst arg (Proj (P Bound n Erased) i) t prune _ t = t stripLambdas :: CaseDef -> CaseDef stripLambdas (CaseDef ns (STerm (Bind x (Lam _) sc)) tm) = stripLambdas (CaseDef (ns ++ [x]) (STerm (instantiate (P Bound x Erased) sc)) tm) stripLambdas x = x