{-# LANGUAGE PatternGuards #-}

module Core.Unify(match_unify, unify, Fails) where

import Core.TT
import Core.Evaluate

import Control.Monad
import Control.Monad.State
import Data.List
import Debug.Trace

-- Unification is applied inside the theorem prover. We're looking for holes
-- which can be filled in, by matching one term's normal form against another.
-- Returns a list of hole names paired with the term which solves them, and
-- a list of things which need to be injective.

-- terms which need to be injective, with the things we're trying to unify
-- at the time
-- NOTE(review): Injs is not referenced anywhere else in this module.

type Injs = [(TT Name, TT Name, TT Name)]

-- | Failures recorded while unifying: two terms, the environment they were
-- unified in, and the CantUnify error built for them.
type Fails = [(TT Name, TT Name, Env, Err)]

-- | Unifier state: a score (incremented via sc, and embedded in the errors
-- that are built) plus the failures recorded so far. New failures are consed
-- onto the front, so the list is newest-first.
data UInfo = UI Int Fails
     deriving Show

-- NOTE(review): UResult is not referenced anywhere else in this module.
data UResult a = UOK a
               | UPartOK a
               | UFail Err

-- Solve metavariables by matching terms against each other
-- Not really unification, of course!
--
-- Up to three passes are made, each from a fresh state (score 0, no
-- recorded failures):
--
--   1. the terms as given; its result is used only if it finishes with no
--      recorded failures;
--   2. otherwise both terms are normalised and matched again; any successful
--      result is used, and the failures it recorded are ignored;
--   3. if pass 2 fails, the normalised topx is matched against the original
--      topy; if that fails too, pass 2's error is reported.
--
-- Returns each solved hole paired with its solution, minus trivial
-- "x = x" bindings (see notTrivial). A name counts as a hole if holeIn
-- reports it for env or it is listed in holes.
-- NOTE(review): the dont argument is not used by match_unify.
match_unify :: Context -> Env -> TT Name -> TT Name -> [Name] -> [Name] ->
               TC [(Name, TT Name)]
match_unify ctxt env topx topy dont holes =
    case runStateT (un [] topx topy) (UI 0 []) of
        OK (v, UI _ []) -> return (filter notTrivial v)
        res ->
            let topxn = normalise ctxt env topx
                topyn = normalise ctxt env topy
            in case runStateT (un [] topxn topyn) (UI 0 []) of
                   OK (v, UI _ fails) -> return (filter notTrivial v)
                   Error e ->
                       -- just normalise the term we're matching against
                       case runStateT (un [] topxn topy) (UI 0 []) of
                           OK (v, UI _ fails) -> return (filter notTrivial v)
                           _ -> tfail e
  where
    -- Structural matcher. The list argument pairs up the binder names of
    -- the two sides for each enclosing Bind, innermost first.

    -- A hole on either side is solved by the other side, subject to the
    -- occurs and scope checks in checkCycle.
    un names (P _ x _) tm
        | holeIn env x || x `elem` holes
            = do sc 1; checkCycle names (x, tm)
    un names tm (P _ y _)
        | holeIn env y || y `elem` holes
            = do sc 1; checkCycle names (y, tm)
    -- A de Bruijn variable matches a name if either side bound that name at
    -- the corresponding position.
    -- NOTE(review): (!!) is partial; this assumes i < length bnames.
    un bnames (V i) (P _ x _)
        | fst (bnames!!i) == x || snd (bnames!!i) == x = do sc 1; return []
    un bnames (P _ x _) (V i)
        | fst (bnames!!i) == x || snd (bnames!!i) == x = do sc 1; return []
    -- Binders: match the binder annotations, then the scopes with the two
    -- binder names paired up, and merge both sets of solutions.
    un bnames (Bind x bx sx) (Bind y by sy)
        = do h1 <- uB bnames bx by
             h2 <- un ((x, y) : bnames) sx sy
             combine bnames h1 h2
    -- Applications: function against function, argument against argument.
    un names (App fx ax) (App fy ay)
        = do hf <- un names fx fy
             ha <- un names ax ay
             combine names hf ha
    -- Anything else must be convertible, otherwise matching aborts.
    -- NOTE(review): both branches of the if abort; an aborted TC
    -- computation carries no StateT state, so the put appears to have no
    -- observable effect -- confirm.
    un names x y
        | OK True <- convEq' ctxt x y = do sc 1; return []
        | otherwise = do
            UI s f <- get
            let r = recoverable x y
            let err = CantUnify r
                        topx topy (CantUnify r x y (Msg "") [] s) (errEnv env) s
            if (not r) then lift $ tfail err
              else do
                put (UI s ((x, y, env, err) : f))
                lift $ tfail err

    -- Match binder annotations: both fields of a Let, the single field of
    -- a Lam or a Pi. Any other pairing records a failure and carries on
    -- with no solutions.
    uB bnames (Let tx vx) (Let ty vy)
        = do h1 <- un bnames tx ty
             h2 <- un bnames vx vy
             combine bnames h1 h2
    uB bnames (Lam tx) (Lam ty) = un bnames tx ty
    uB bnames (Pi tx) (Pi ty) = un bnames tx ty
    uB bnames x y = do
        UI s f <- get
        let r = recoverable (binderTy x) (binderTy y)
        let err = CantUnify r topx topy
                    (CantUnify r (binderTy x) (binderTy y) (Msg "") [] s)
                    (errEnv env) s
        put (UI s ((binderTy x, binderTy y, env, err) : f))
        return []

    -- TODO: there's an annoying amount of repetition between this and the
    -- main unification function. Consider lifting it out.

    -- Add i to the score.
    sc i = do UI s f <- get
              put (UI (s+i) f)

    -- Record a failure for x and y, then abort.
    -- NOTE(review): not called from match_unify.
    unifyFail x y = do
        UI s f <- get
        let r = recoverable x y
        let err = CantUnify r
                    topx topy (CantUnify r x y (Msg "") [] s) (errEnv env) s
        put (UI s ((x, y, env, err) : f))
        lift $ tfail err

    -- Merge the solutions in the third argument into the first. A name not
    -- yet solved is appended. For a name solved on both sides, the two
    -- solutions are matched against each other. Whatever that match solves,
    -- apart from the name itself, is merged in as well.
    combine bnames as [] = return as
    combine bnames as ((n, t) : bs)
        = case lookup n as of
            Nothing -> combine bnames (as ++ [(n,t)]) bs
            Just t' -> do
                ns <- un bnames t t'
                -- make sure there's no mapping from n in ns
                let ns' = filter (\ (x, _) -> x/=n) ns
                sc 1
                combine bnames as (ns' ++ bs)

    -- Accept the solution (x, tm). A bare name is accepted as is. Otherwise
    -- x must not occur free in tm (else InfiniteUnify), and tm must pass
    -- checkScope.
    checkCycle ns p@(x, P _ _ _) = return [p]
    checkCycle ns (x, tm)
        | not (x `elem` freeNames tm) = checkScope ns (x, tm)
        | otherwise = lift $ tfail (InfiniteUnify x tm (errEnv env))

    -- Fail with UnifyScope if boundVs finds a de Bruijn index in tm beyond
    -- x's position in env (as computed by envPos). Otherwise return the
    -- single binding (x, tm).
    -- NOTE(review): ns!!i is partial -- assumes the index reported by
    -- boundVs is always < length ns.
    checkScope ns (x, tm) =
        case boundVs (envPos x 0 env) tm of
            [] -> return [(x, tm)]
            (i:_) -> lift $ tfail (UnifyScope x (fst (ns!!i))
                                     (inst ns tm) (errEnv env))
      where
        -- Instantiate tm with the recorded binder names, one substV per
        -- name; only used to build the error message.
        inst [] tm = tm
        inst ((n, _) : ns) tm = inst ns (substV (P Bound n Erased) tm)

-- | False for a binding of a name to itself (x = P x), True otherwise.
notTrivial (x, P _ x' _) = x /= x'
notTrivial _ = True

-- Substitute each Let-bound variable of env into tm, processing env in
-- reverse order.
-- NOTE(review): not referenced elsewhere in this module. uB in this file
-- names Let's fields as (Let tx vx), i.e. the first field as the type, so
-- this may substitute the binder's type rather than its value -- confirm.
expandLets env (x, tm) = (x, doSubst (reverse env) tm)
  where
    doSubst [] tm = tm
    doSubst ((n, Let v t) : env) tm
        = doSubst env (subst n v tm)
    doSubst (_ : env) tm
        = doSubst env tm

-- | Unify two terms, solving holes.
--
-- The first pass runs on the terms as given. If it finishes with no
-- recorded failures, its solutions are returned with an empty failure list.
-- Otherwise, including when the first pass aborts, both terms are
-- normalised and unified again from a fresh state. That pass's solutions
-- are returned together with the recoverable failures it recorded, in the
-- order they were recorded; if it aborts, its error is raised.
-- Trivial "x = x" solutions are dropped.
--
-- holes lists extra names to treat as holes (besides those holeIn reports
-- for env). Names in dont are excluded by metavar, which only feeds the
-- application heuristics in unApp.
unify :: Context -> Env -> TT Name -> TT Name -> [Name] -> [Name] ->
         TC ([(Name, TT Name)], Fails)
unify ctxt env topx topy dont holes =
    --  trace ("Unifying " ++ show (topx, topy)) $
    -- don't bother if topx and topy are different at the head
    case runStateT (un False [] topx topy) (UI 0 []) of
        OK (v, UI _ []) -> return (filter notTrivial v, [])
        res ->
            let topxn = normalise ctxt env topx
                topyn = normalise ctxt env topy
            in
               -- trace ("Unifying " ++ show (topx, topy) ++ "\n\n==>\n" ++ show (topxn, topyn) ++ "\n\n" ++ show res ++ "\n\n") $
               case runStateT (un False [] topxn topyn) (UI 0 []) of
                   OK (v, UI _ fails) ->
                       return (filter notTrivial v, reverse fails)
                   -- Error e@(CantUnify False _ _ _ _ _) -> tfail e
                   Error e -> tfail e
  where
    -- NOTE(review): headDiff is only referenced from commented-out code.
    headDiff (P (DCon _ _) x _) (P (DCon _ _) y _) = x /= y
    headDiff (P (TCon _ _) x _) (P (TCon _ _) y _) = x /= y
    headDiff _ _ = False

    -- Data constructors, type constructors, and applications headed by one.
    injective (P (DCon _ _) _ _) = True
    injective (P (TCon _ _) _ _) = True
    -- injective (App f (P _ _ _))  = injective f
    -- injective (App f (Constant _))  = injective f
    injective (App f a)          = injective f -- && injective a
    injective _                  = False

    -- True unless the term is a plain name.
    notP (P _ _ _) = False
    notP _ = True

    -- Add i to the score.
    sc i = do UI s f <- get
              put (UI (s+i) f)

    -- Has any failure been recorded so far in this pass?
    errors = do UI s f <- get
                return (not (null f))

    -- Try u1. If it recorded no new failures (failure list length
    -- unchanged), keep its result. Otherwise restore the failure list from
    -- before u1 and run u2 instead.
    -- NOTE(review): the second pattern rebinds s, so the score reached by u1
    -- is kept rather than restored -- confirm intended. A hard failure
    -- (tfail) inside u1 is not caught here, so u2 is not tried then.
    uplus u1 u2 = do UI s f <- get
                     r <- u1
                     UI s f' <- get
                     if (length f == length f')
                        then return r
                        else do put (UI s f); u2

    un :: Bool -> [(Name, Name)] -> TT Name -> TT Name ->
          StateT UInfo
          TC [(Name, TT Name)]
    un = un'
    -- un fn names x y
    --     = let (xf, _) = unApply x
    --           (yf, _) = unApply y in
    --           if headDiff xf yf then unifyFail x y else
    --               uplus (un' fn names x y)
    --                     (un' fn names (hnf ctxt env x) (hnf ctxt env y))

    -- Core unifier. fn is True when unifying the function part of an
    -- application; in that position, solving a hole with anything other
    -- than a plain name is deferred as a temporary failure. bnames pairs up
    -- the binder names of the two sides for each enclosing Bind, innermost
    -- first. Equation order matters: later equations only see what earlier
    -- ones (and their guards) let through.
    un' :: Bool -> [(Name, Name)] -> TT Name -> TT Name ->
           StateT UInfo
           TC [(Name, TT Name)]
    un' fn names x y | x == y = return [] -- shortcut
    -- Hard failures between rigid heads:
    --   * distinct data constructors;
    --   * distinct type constructors;
    --   * a data constructor against a type constructor;
    --   * a constant against a type constructor.
    un' fn names topx@(P (DCon _ _) x _) topy@(P (DCon _ _) y _)
                | x /= y = unifyFail topx topy
    un' fn names topx@(P (TCon _ _) x _) topy@(P (TCon _ _) y _)
                | x /= y = unifyFail topx topy
    un' fn names topx@(P (DCon _ _) x _) topy@(P (TCon _ _) y _)
                = unifyFail topx topy
    un' fn names topx@(P (TCon _ _) x _) topy@(P (DCon _ _) y _)
                = unifyFail topx topy
    un' fn names topx@(Constant _) topy@(P (TCon _ _) y _)
                = unifyFail topx topy
    un' fn names topx@(P (TCon _ _) x _) topy@(Constant _)
                = unifyFail topx topy
    -- Two names. Equal names, or the pair (x, y) recorded in bnames,
    -- succeed. An injective name against a non-hole is deferred as a
    -- temporary failure. Otherwise fall through to the equations below.
    un' fn bnames tx@(P _ x _) ty@(P _ y _)
        | (x,y) `elem` bnames || x == y = do sc 1; return []
        | injective tx && not (holeIn env y || y `elem` holes)
             = unifyTmpFail tx ty
        | injective ty && not (holeIn env x || x `elem` holes)
             = unifyTmpFail tx ty
    -- Hole on the left: bind it via checkCycle, unless fn is set and tm is
    -- not a plain name, in which case defer. A non-hole, non-injective name
    -- against an injective term fails hard.
    un' fn bnames xtm@(P _ x _) tm
        | holeIn env x || x `elem` holes = do
            UI s f <- get
            -- injectivity check
            if (notP tm && fn)
            --  trace (show (x, tm, normalise ctxt env tm)) $
            --    put (UI s ((tm, topx, topy) : i) f)
                then unifyTmpFail xtm tm
                else do
                    sc 1
                    checkCycle bnames (x, tm)
        | not (injective xtm) && injective tm = unifyFail xtm tm
    -- Mirror image: hole on the right.
    un' fn bnames tm ytm@(P _ y _)
        | holeIn env y || y `elem` holes = do
            UI s f <- get
            -- injectivity check
            if (notP tm && fn)
            --  trace (show (y, tm, normalise ctxt env tm)) $
            --    put (UI s ((tm, topx, topy) : i) f)
                then unifyTmpFail tm ytm
                else do
                    sc 1
                    checkCycle bnames (y, tm)
        | not (injective ytm) && injective tm = unifyFail ytm tm
    -- A de Bruijn variable matches a name if either side bound that name at
    -- the corresponding position.
    -- NOTE(review): (!!) is partial; this assumes i < length bnames.
    un' fn bnames (V i) (P _ x _)
        | fst (bnames!!i) == x || snd (bnames!!i) == x = do sc 1; return []
    un' fn bnames (P _ x _) (V i)
        | fst (bnames!!i) == x || snd (bnames!!i) == x = do sc 1; return []

    un' fn bnames appx@(App _ _) appy@(App _ _)
        = unApp fn bnames appx appy
    --  = uplus (unApp fn bnames appx appy)
    --          (unifyTmpFail appx appy) -- take the whole lot

    -- Eta: when one side has the form \n. y n, unify the other side with y.
    -- NOTE(review): nothing checks that n does not occur in y (compare the
    -- noOccurrence guard on the Pi equation below). In the V 0 forms, any
    -- other de Bruijn indices in y are passed on unchanged. Confirm this
    -- is safe.
    un' fn bnames x (Bind n (Lam t) (App y (P Bound n' _)))
        | n == n' = un' False bnames x y
    un' fn bnames (Bind n (Lam t) (App x (P Bound n' _))) y
        | n == n' = un' False bnames x y
    un' fn bnames x (Bind n (Lam t) (App y (V 0)))
        = un' False bnames x y
    un' fn bnames (Bind n (Lam t) (App x (V 0))) y
        = un' False bnames x y
    -- un' fn bnames (Bind x (PVar _) sx) (Bind y (PVar _) sy)
    --     = un' False ((x,y):bnames) sx sy
    -- un' fn bnames (Bind x (PVTy _) sx) (Bind y (PVTy _) sy)
    --     = un' False ((x,y):bnames) sx sy

    -- f D unifies with t -> D. This is dubious, but it helps with type
    -- class resolution for type classes over functions.

    un' fn bnames (App f x) (Bind n (Pi t) y)
      | noOccurrence n y && x == y
        = un' False bnames f (Bind (MN 0 "uv") (Lam (TType (UVar 0)))
                                   (Bind n (Pi t) (V 1)))

    -- Two binders: unify the annotations, then the scopes with the binder
    -- names paired.
    un' fn bnames (Bind x bx sx) (Bind y by sy)
        = do h1 <- uB bnames bx by
             h2 <- un' False ((x,y):bnames) sx sy
             combine bnames h1 h2
    -- Fallback: convertible terms succeed. Otherwise a clash that
    -- recoverable rejects aborts. A recoverable one is recorded and
    -- unification continues with no solutions.
    un' fn bnames x y
        | OK True <- convEq' ctxt x y = do sc 1; return []
        | otherwise = do
            UI s f <- get
            let r = recoverable x y
            let err = CantUnify r
                        topx topy (CantUnify r x y (Msg "") [] s) (errEnv env) s
            if (not r) then lift $ tfail err
              else do
                put (UI s ((x, y, env, err) : f))
                return [] -- lift $ tfail err

    -- Unify two applications.
    --
    -- When the function parts are injective, or one side is rigid and the
    -- other a metavariable application, the heads are checked for a hard
    -- clash. Then two orders are tried via uplus:
    --   * functions first (with fn = True), substituting and normalising
    --     their solutions into the arguments;
    --   * arguments first, doing the same into the functions.
    --
    -- Otherwise both spines are compared. They must have equal arity, and
    -- one of these must hold:
    --   * the heads are equal and env-bound;
    --   * the argument lists are identical;
    --   * the spines are pairwise sameStruct.
    -- If so, the heads are unified, and then, if no failure has been
    -- recorded so far, the arguments left to right. Anything else is
    -- recorded as a temporary failure of the whole application.
    unApp fn bnames appx@(App fx ax) appy@(App fy ay)
         | (injective fx && injective fy)
        || (injective fx && rigid appx && metavarApp appy)
        || (injective fy && rigid appy && metavarApp appx)
        || (injective fx && metavarApp fy && ax == ay)
        || (injective fy && metavarApp fx && ax == ay)
         = do
            let (headx, _) = unApply fx
            let (heady, _) = unApply fy
            -- fail quickly if the heads are disjoint
            checkHeads headx heady
            -- if True then -- (injective fx || injective fy || fx == fy) then
            -- if (injective fx && metavarApp appy) ||
            --    (injective fy && metavarApp appx) ||
            --    (injective fx && injective fy) || fx == fy
            uplus
              (do
                 hf <- un' True bnames fx fy
                 let ax' = hnormalise hf ctxt env (substNames hf ax)
                 let ay' = hnormalise hf ctxt env (substNames hf ay)
                 ha <- un' False bnames ax' ay'
                 sc 1
                 combine bnames hf ha)
              (do
                 ha <- un' False bnames ax ay
                 let fx' = hnormalise ha ctxt env (substNames ha fx)
                 let fy' = hnormalise ha ctxt env (substNames ha fy)
                 hf <- un' False bnames fx' fy'
                 sc 1
                 combine bnames hf ha)
         | otherwise = -- trace (show (appx, appy, injective fx, metavarApp appy, sameArgStruct appx appy)) $
            do
              let (headx, argsx) = unApply appx
              let (heady, argsy) = unApply appy
              -- traceWhen (headx == heady) (show (appx, appy)) $
              uplus
                (if (length argsx == length argsy &&
                     ((headx == heady && inenv headx) ||
                      (argsx == argsy) ||
                      (and (zipWith sameStruct (headx:argsx) (heady:argsy)))))
                   then
                     -- (notFn headx && notFn heady))) then
                     do
                       uf <- un' True bnames headx heady
                       failed <- errors
                       if (not failed) then unArgs uf argsx argsy
                         else return []
                   else -- trace ("TMPFAIL " ++ show (appx, appy, injective appx, injective appy)) $
                     unifyTmpFail appx appy)
                (unifyTmpFail appx appy) -- whole application fails
      where
        -- Normalise only when some solutions were found (and so possibly
        -- substituted in).
        hnormalise [] _ _ t = t
        hnormalise ns ctxt env t = normalise ctxt env t

        -- Hard failure for distinct constructors, distinct type
        -- constructors, or a constructor against a type constructor.
        checkHeads (P (DCon _ _) x _) (P (DCon _ _) y _)
            | x /= y = unifyFail appx appy
        checkHeads (P (TCon _ _) x _) (P (TCon _ _) y _)
            | x /= y = unifyFail appx appy
        checkHeads (P (DCon _ _) x _) (P (TCon _ _) y _)
            = unifyFail appx appy
        checkHeads (P (TCon _ _) x _) (P (DCon _ _) y _)
            = unifyFail appx appy
        checkHeads _ _ = return []

        -- Unify argument lists left to right. Before each pair, substitute
        -- the solutions found so far into it (normalising if there are any).
        -- Equal lengths are checked by the caller.
        unArgs as [] [] = return as
        unArgs as (x : xs) (y : ys) = do
            let x' = hnormalise as ctxt env (substNames as x)
            let y' = hnormalise as ctxt env (substNames as y)
            as' <- un' False bnames x' y'
            vs <- combine bnames as as'
            unArgs vs xs ys

        -- A metavariable applied to distinct arguments, each of which is
        -- itself a metavarApp.
        metavarApp tm =
            let (f, args) = unApply tm
            in metavar f && all (\x -> metavarApp x) args
                   && nub args == args
        -- NOTE(review): metavarArgs, metavarApp', sameArgStruct and notFn
        -- are not used outside commented-out or trace code.
        metavarArgs tm =
            let (f, args) = unApply tm
            in all (\x -> metavar x || inenv x) args
                   && nub args == args
        metavarApp' tm =
            let (f, args) = unApply tm
            in all (\x -> pat x || metavar x) (f : args)
                   && nub args == args

        sameArgStruct appx appy =
            let (_, ax) = unApply appx
                (_, ay) = unApply appy
            in length ax == length ay && and (zipWith sameStruct ax ay)

        -- Rough shape comparison used to decide whether spine-wise
        -- unification is worth attempting.
        sameStruct fapp@(App f x) gapp@(App g y) =
            let (f',a') = unApply fapp
                (g',b') = unApply gapp
            in (f' == g' && length a' == length b' &&
                (injective f' || injective g')) ||
               (sameStruct f g && sameStruct x y)
        sameStruct (P _ x _) (P _ y _) = True
        sameStruct (V i) (V j) = i == j
        sameStruct (Constant x) (Constant y) = True
        sameStruct (P _ _ _) (Constant y) = True
        sameStruct (Constant x) (P _ _ _) = True
        sameStruct (Bind n t sc) (P _ _ _) = True
        sameStruct (P _ _ _) (Bind n t sc) = True
        sameStruct (Bind n t sc) (Bind n' t' sc') = sameStruct sc sc'
        sameStruct _ _ = False

        -- Constructors and constants are rigid; a Ref name is rigid when
        -- bound in env; anything else is rigid unless it is a metavariable.
        rigid (P (DCon _ _) _ _) = True
        rigid (P (TCon _ _) _ _) = True
        rigid t@(P Ref _ _)      = inenv t
        rigid (Constant _)       = True
        rigid (App f a)          = rigid f && rigid a
        rigid t                  = not (metavar t)

        -- A hole name that is not listed in dont.
        metavar t = case t of
            P _ x _ -> (x `elem` holes || holeIn env x) &&
                           not (x `elem` dont)
            _ -> False
        pat t = case t of
            P _ x _ -> x `elem` holes || patIn env x
            _ -> False
        inenv t = case t of
            P _ x _ -> x `elem` (map fst env)
            _ -> False

        notFn t = injective t || metavar t || inenv t

    -- Record a failure, then continue with no solutions. The recorded pair
    -- is the top-level topx/topy; x and y only appear in the nested
    -- CantUnify.
    unifyTmpFail x y = do
        UI s f <- get
        let r = recoverable x y
        let err = CantUnify r
                    topx topy (CantUnify r x y (Msg "") [] s) (errEnv env) s
        put (UI s ((topx, topy, env, err) : f))
        return []

    -- shortcut failure, if we *know* nothing can fix it
    unifyFail x y = do
        UI s f <- get
        let r = recoverable x y
        let err = CantUnify r
                    topx topy (CantUnify r x y (Msg "") [] s) (errEnv env) s
        put (UI s ((topx, topy, env, err) : f))
        lift $ tfail err

    -- Unify binder annotations of the same kind: both fields of Let/Guess,
    -- the single field of Lam/Pi/Hole/PVar. Mismatched kinds record a
    -- failure and continue with no solutions.
    uB bnames (Let tx vx) (Let ty vy)
        = do h1 <- un' False bnames tx ty
             h2 <- un' False bnames vx vy
             sc 1
             combine bnames h1 h2
    uB bnames (Guess tx vx) (Guess ty vy)
        = do h1 <- un' False bnames tx ty
             h2 <- un' False bnames vx vy
             sc 1
             combine bnames h1 h2
    uB bnames (Lam tx) (Lam ty) = do sc 1; un' False bnames tx ty
    uB bnames (Pi tx) (Pi ty) = do sc 1; un' False bnames tx ty
    uB bnames (Hole tx) (Hole ty) = un' False bnames tx ty
    uB bnames (PVar tx) (PVar ty) = un' False bnames tx ty
    uB bnames x y = do
        UI s f <- get
        let r = recoverable (binderTy x) (binderTy y)
        let err = CantUnify r topx topy
                    (CantUnify r (binderTy x) (binderTy y) (Msg "") [] s)
                    (errEnv env) s
        put (UI s ((binderTy x, binderTy y, env, err) : f))
        return [] -- lift $ tfail err

    -- Accept the solution (x, tm). A bare name is accepted as is. Otherwise
    -- x must not occur free in tm (else InfiniteUnify), and tm must pass
    -- checkScope.
    checkCycle ns p@(x, P _ _ _) = return [p]
    checkCycle ns (x, tm)
        | not (x `elem` freeNames tm) = checkScope ns (x, tm)
        | otherwise = lift $ tfail (InfiniteUnify x tm (errEnv env))

    -- Fail with UnifyScope if boundVs finds a de Bruijn index in tm beyond
    -- x's position in env. Otherwise return the single binding (x, tm).
    -- NOTE(review): ns!!i is partial -- assumes the index reported by
    -- boundVs is always < length ns.
    checkScope ns (x, tm) =
        case boundVs (envPos x 0 env) tm of
            [] -> return [(x, tm)]
            (i:_) -> lift $ tfail (UnifyScope x (fst (ns!!i))
                                     (inst ns tm) (errEnv env))
      where
        -- Instantiate tm with the recorded binder names, one substV per
        -- name; only used to build the error message.
        inst [] tm = tm
        inst ((n, _) : ns) tm = inst ns (substV (P Bound n Erased) tm)

    -- Fold combine over a list of solution sets.
    -- NOTE(review): not referenced elsewhere in this module.
    combineArgs bnames args = ca [] args
      where
        ca acc [] = return acc
        ca acc (x : xs) = do
            x' <- combine bnames acc x
            ca x' xs

    -- Merge the solutions in the third argument into the first. A name not
    -- yet solved is appended. For a name solved on both sides, the two
    -- solutions are unified. Whatever that unification solves, apart from
    -- the name itself, is merged in as well.
    combine bnames as [] = return as
    combine bnames as ((n, t) : bs)
        = case lookup n as of
            Nothing -> combine bnames (as ++ [(n,t)]) bs
            Just t' -> do
                ns <- un' False bnames t t'
                -- make sure there's no mapping from n in ns
                let ns' = filter (\ (x, _) -> x/=n) ns
                sc 1
                combine bnames as (ns' ++ bs)

-- | De Bruijn indices j occurring in a term with j > i, where i is
-- incremented under each binder scope.
-- NOTE(review): returned indices are not adjusted for the binders they were
-- found under, and a binder's own annotation is not inspected.
boundVs :: Int -> Term -> [Int]
boundVs i (V j) | j <= i = []
                | otherwise = [j]
boundVs i (Bind n b sc) = boundVs (i + 1) sc
boundVs i (App f x) =
    let fs = boundVs i f
        xs = boundVs i x
    in nub (fs ++ xs)
boundVs i _ = []

-- | Position of x in the environment list, counting from i; 0 if x is not
-- present.
-- NOTE(review): "not present" is indistinguishable from "first entry" when
-- called with i = 0.
envPos x i [] = 0
envPos x i ((y, _) : ys)
    | x == y = i
    | otherwise = envPos x (i + 1) ys

-- If there are any clashes of constructors, deem it unrecoverable, otherwise some
-- more work may help.
-- FIXME: Depending on how overloading gets used, this may cause problems. Better
-- rethink overloading properly...
-- | Could a failed unification of these two terms still be fixed by more work?
--
-- The answer is 'False' only for a definite clash of rigid heads:
--
--   * two data constructors with different names, or two type constructors
--     with different names;
--   * a constant, a data constructor or a type constructor against a head of
--     one of the other two kinds.
--
-- Some pairs are decided by comparing only the function part(s) of the
-- application(s):
--
--   * a constant against an application (either way round);
--   * a name on the left against an application;
--   * two applications.
--
-- Every other pair is treated as recoverable.
recoverable (P (DCon _ _) x _) (P (DCon _ _) y _) = x == y
recoverable (P (TCon _ _) x _) (P (TCon _ _) y _) = x == y
recoverable lhs rhs
    | Just kl <- rigidKind lhs
    , Just kr <- rigidKind rhs
    , kl /= kr
    = False
  where
    -- Tag each kind of rigid head so that mismatched kinds can be spotted;
    -- same-kind pairs fall through to the equations below.
    rigidKind (Constant _)       = Just (0 :: Int)
    rigidKind (P (DCon _ _) _ _) = Just 1
    rigidKind (P (TCon _ _) _ _) = Just 2
    rigidKind _                  = Nothing
recoverable c@(Constant _) (App f _) = recoverable c f
recoverable (App f _) c@(Constant _) = recoverable f c
recoverable p@(P _ _ _) (App f _)    = recoverable p f
-- An application on the left against a plain name is not looked through,
-- and the arguments of two applications are not compared.
recoverable (App f _) (App g _)      = recoverable f g
recoverable _ _                      = True

-- | Pair each environment name with 'binderTy' of its binder, as used when
-- building error reports.
errEnv env = [ (n, binderTy b) | (n, b) <- env ]

-- | Is @n@ bound in the environment by a 'Hole' or 'Guess' binder?
holeIn :: Env -> Name -> Bool
holeIn env n = maybe False isHole (lookup n env)
  where
    isHole (Hole _)    = True
    isHole (Guess _ _) = True
    isHole _           = False

-- | Is @n@ bound in the environment by a 'PVar' or 'PVTy' binder?
patIn :: Env -> Name -> Bool
patIn env n = maybe False isPat (lookup n env)
  where
    isPat (PVar _) = True
    isPat (PVTy _) = True
    isPat _        = False