{-# LANGUAGE RecordWildCards #-}

{-# OPTIONS_GHC -fno-warn-unused-imports #-}

{-|
Copyright  :  (C) 2015, University of Twente
License    :  BSD2 (see the file LICENSE)
Maintainer :  Christiaan Baaij
-}
module GHC.TypeLits.Normalise.Unify
  ( -- * 'Nat' expressions \<-\> 'SOP' terms
    CoreSOP
  , normaliseNat
  , reifySOP
    -- * Substitution on 'SOP' terms
  , UnifyItem (..)
  , TyUnify
  , CoreUnify
  , substsSOP
  , substsSubst
    -- * Find unifiers
  , UnifyResult (..)
  , unifyNats
  , unifiers
    -- * Free variables in 'SOP' terms
  , fvSOP
  )
where

-- External
import Data.Function (on)
import Data.List     ((\\), intersect)

-- GHC API
import Outputable (Outputable (..), (<+>), ($$), text)
import TcPluginM  (TcPluginM, tcPluginTrace)
import TcRnMonad  (Ct, ctEvidence, isGiven)
import TcTypeNats (typeNatAddTyCon, typeNatExpTyCon, typeNatMulTyCon,
                   typeNatSubTyCon)
import Type       (TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy, tcView)
import TypeRep    (Type (..), TyLit (..))
import UniqSet    (UniqSet, unionManyUniqSets, emptyUniqSet, unionUniqSets,
                   unitUniqSet)

-- Internal
import GHC.Extra.Instances () -- Ord instance for Type
import GHC.TypeLits.Normalise.SOP

-- Used for haddock
import GHC.TypeLits (Nat)

-- | 'SOP' with 'TyVar' variables
type CoreSOP     = SOP TyVar Type
type CoreProduct = Product TyVar Type
type CoreSymbol  = Symbol TyVar Type

-- | Convert a type of /kind/ 'GHC.TypeLits.Nat' to an 'SOP' term, but
-- only when the type is constructed out of:
--
-- * literals
-- * type variables
-- * Applications of the arithmetic operators @(+,-,*,^)@
--
-- Any other type is kept as an opaque constant symbol (@C@).
normaliseNat :: Type -> CoreSOP
normaliseNat ty | Just ty1 <- tcView ty = normaliseNat ty1
normaliseNat (TyVarTy v)          = S [P [V v]]
normaliseNat (LitTy (NumTyLit i)) = S [P [I i]]
normaliseNat (TyConApp tc [x,y])
  | tc == typeNatAddTyCon = mergeSOPAdd (normaliseNat x) (normaliseNat y)
    -- x - y is represented as x + (-1 * y)
  | tc == typeNatSubTyCon = mergeSOPAdd (normaliseNat x)
                                        (mergeSOPMul (S [P [I (-1)]])
                                                     (normaliseNat y))
  | tc == typeNatMulTyCon = mergeSOPMul (normaliseNat x) (normaliseNat y)
  | tc == typeNatExpTyCon = normaliseExp (normaliseNat x) (normaliseNat y)
normaliseNat t = S [P [C t]]

-- | Convert a 'SOP' term back to a type of /kind/ 'GHC.TypeLits.Nat'
--
-- Products with a negative leading coefficient are emitted with the
-- subtraction operator and the absolute value of the coefficient, since
-- type-level 'Nat' literals cannot be negative. The empty sum is @0@.
reifySOP :: CoreSOP -> Type
reifySOP = combineP . map negateP . unS
  where
    -- 'Left' marks a product that must be subtracted; it carries the
    -- product with its coefficient made positive. A coefficient of -1 in
    -- front of other factors is dropped entirely (1 * p == p); a lone -1,
    -- or any other negative coefficient, is kept as its absolute value so
    -- neither the magnitude nor the product itself is lost.
    negateP :: CoreProduct -> Either CoreProduct CoreProduct
    negateP (P ((I i):ps@(_:_))) | i == (-1) = Left (P ps)
    negateP (P ((I i):ps))       | i < 0     = Left (P (I (abs i) : ps))
    negateP ps                               = Right ps

    combineP :: [Either CoreProduct CoreProduct] -> Type
    combineP []  = mkNumLitTy 0
    combineP [p] = either (\p' -> mkTyConApp typeNatSubTyCon
                                    [mkNumLitTy 0, reifyProduct p'])
                          reifyProduct p
    combineP (p:ps) =
      let es = combineP ps
      in  either (\x -> mkTyConApp typeNatSubTyCon [es, reifyProduct x])
                 (\x -> mkTyConApp typeNatAddTyCon [reifyProduct x, es])
                 p

-- | Convert a product back to a chain of type-level multiplications.
-- The empty product is @1@ (matching how 'unifiers'' pads emptied
-- products with @I 1@), instead of crashing in 'foldr1'.
reifyProduct :: CoreProduct -> Type
reifyProduct (P []) = mkNumLitTy 1
reifyProduct (P ss) =
  foldr1 (\t1 t2 -> mkTyConApp typeNatMulTyCon [t1,t2]) (map reifySymbol ss)

-- | Convert a single symbol back to a type.
reifySymbol :: CoreSymbol -> Type
reifySymbol (I i)   = mkNumLitTy i
reifySymbol (C c)   = c
reifySymbol (V v)   = mkTyVarTy v
reifySymbol (E s p) = mkTyConApp typeNatExpTyCon [reifySOP s,reifyProduct p]

-- | A substitution is essentially a list of (variable, 'SOP') pairs,
-- but we keep the original 'Ct' that led to the substitution being
-- made, for use when turning the substitution back into constraints.
type CoreUnify = TyUnify TyVar Type Ct

-- | A list of 'UnifyItem's, over variables @v@, constants @c@, with a note
-- @n@ attached to every item
type TyUnify v c n = [UnifyItem v c n]

-- | A single entry of a 'TyUnify' list
data UnifyItem v c n
  = -- | Replace the variable 'siVar' by the term 'siSOP'
    SubstItem { siVar  :: v
              , siSOP  :: SOP v c
              , siNote :: n
              }
  | -- | The terms 'siLHS' and 'siRHS' must be equal; 'substsSOP' does not
    -- apply this kind of item
    UnifyItem { siLHS  :: SOP v c
              , siRHS  :: SOP v c
              , siNote :: n
              }

instance (Outputable v, Outputable c) => Outputable (UnifyItem v c n) where
  ppr (SubstItem {..}) = ppr siVar <+> text " := " <+> ppr siSOP
  ppr (UnifyItem {..}) = ppr siLHS <+> text " :~ " <+> ppr siRHS

-- | Apply a substitution to a single normalised 'SOP' term. Items are
-- applied front to back; 'UnifyItem' entries are skipped.
substsSOP :: (Ord v, Ord c) => TyUnify v c n -> SOP v c -> SOP v c
substsSOP []                   u = u
substsSOP ((SubstItem {..}):s) u = substsSOP s (substSOP siVar siSOP u)
substsSOP ((UnifyItem {}):s)   u = substsSOP s u

-- | Replace every occurrence of the variable @tv@ by @e@ in a sum.
-- The empty sum (zero) contains no variables and is returned unchanged,
-- rather than crashing in 'foldr1'.
substSOP :: (Ord v, Ord c) => v -> SOP v c -> SOP v c -> SOP v c
substSOP _  _ (S []) = S []
substSOP tv e (S ps) = foldr1 mergeSOPAdd (map (substProduct tv e) ps)

-- | Replace every occurrence of the variable @tv@ by @e@ in a product.
-- The empty product contains no variables and is returned unchanged,
-- rather than crashing in 'foldr1'.
substProduct :: (Ord v, Ord c) => v -> SOP v c -> Product v c -> SOP v c
substProduct _  _ (P []) = S [P []]
substProduct tv e (P ss) = foldr1 mergeSOPMul (map (substSymbol tv e) ss)

-- | Replace the variable @tv@ by @e@ in a single symbol, re-normalising
-- exponentials whose base or exponent changed.
substSymbol :: (Ord v, Ord c) => v -> SOP v c -> Symbol v c -> SOP v c
substSymbol _ _ s@(I _) = S [P [s]]
substSymbol _ _ s@(C _) = S [P [s]]
substSymbol tv e (V tv')
  | tv == tv'           = e
  | otherwise           = S [P [V tv']]
substSymbol tv e (E s p) = normaliseExp (substSOP tv e s) (substProduct tv e p)

-- | Apply a substitution to a substitution: both the right-hand side of
-- every 'SubstItem' and both sides of every 'UnifyItem' are rewritten.
substsSubst :: (Ord v, Ord c) => TyUnify v c n -> TyUnify v c n -> TyUnify v c n
substsSubst s = map subt
  where
    subt si@(SubstItem {..}) = si {siSOP = substsSOP s siSOP}
    subt si@(UnifyItem {..}) = si {siLHS = substsSOP s siLHS, siRHS = substsSOP s siRHS}

-- | Result of comparing two 'SOP' terms, returning a potential substitution
-- list under which the two terms are equal.
data UnifyResult
  = Win            -- ^ Two terms are equal
  | Lose           -- ^ Two terms are /not/ equal
  | Draw CoreUnify -- ^ Two terms are only equal if the given substitution holds

instance Outputable UnifyResult where
  ppr result = case result of
    Win        -> text "Win"
    Draw subst -> text "Draw" <+> ppr subst
    Lose       -> text "Lose"

-- | Compare two 'SOP's @u@ and @v@.
--
-- When both mention exactly the same free variables ('fvSOP') and neither
-- contains an uninterpreted constant (a @C@ symbol), the comparison is
-- decided syntactically: we 'Win' if @u@ and @v@ are equal, and 'Lose'
-- otherwise.
--
-- In every other case we result in a 'Draw', where @u@ and @v@ are only
-- equal when the returned 'CoreUnify' holds.
--
-- The constraint and both terms are written to the plugin trace.
unifyNats :: Ct -> CoreSOP -> CoreSOP -> TcPluginM UnifyResult
unifyNats ct u v =
  tcPluginTrace "unifyNats" (ppr ct $$ ppr u $$ ppr v)
    >> return (unifyNats' ct u v)

-- | Pure worker for 'unifyNats'.
unifyNats' :: Ct -> CoreSOP -> CoreSOP -> UnifyResult
unifyNats' ct u v
  | decidable = if u == v then Win else Lose
  | otherwise = Draw (unifiers ct u v)
  where
    -- Syntactic equality is only decisive when no side has variables the
    -- other lacks, and no side hides structure behind an opaque constant.
    decidable = eqFV u v
             && not (containsConstants u)
             && not (containsConstants v)

-- | Find unifiers for two SOP terms
--
-- Can find the following unifiers:
--
-- @
-- t ~ a + b ==> [t := a + b]
-- a + b ~ t ==> [t := a + b]
-- (a + c) ~ (b + c) ==> \[a := b\]
-- (2*a) ~ (2*b) ==> [a := b]
-- (2 + a) ~ 5 ==> [a := 3]
-- (i * a) ~ j ==> [a := div j i], when (mod j i == 0)
-- @
--
-- However, given a wanted:
--
-- @
-- [W] t ~ a + b
-- @
--
-- this function returns @[]@, or otherwise we \"solve\" the constraint by
-- finding a unifier equal to the constraint.
--
-- In contrast, given a wanted:
--
-- @
-- [W] (a + c) ~ (b + c)
-- @
--
-- we do return the unifier:
--
-- @
-- [a := b]
-- @
unifiers :: Ct -> CoreSOP -> CoreSOP -> CoreUnify
unifiers ct (S [P [V x]]) s
  | isGiven (ctEvidence ct) = [SubstItem x s ct]
  | otherwise               = []
unifiers ct s (S [P [V x]])
  | isGiven (ctEvidence ct) = [SubstItem x s ct]
  | otherwise               = []
unifiers _ (S [P [C _]]) _ = []
unifiers _ _ (S [P [C _]]) = []
unifiers ct u v            = unifiers' ct u v

-- | Worker for 'unifiers'; clauses are tried top to bottom.
unifiers' :: Ct -> CoreSOP -> CoreSOP -> CoreUnify
-- x ~ 0, where the empty sum is 0 ==> [x := 0]
unifiers' ct (S [P [V x]]) (S [])        = [SubstItem x (S [P [I 0]]) ct]
unifiers' ct (S [])        (S [P [V x]]) = [SubstItem x (S [P [I 0]]) ct]
-- x ~ s ==> [x := s]
unifiers' ct (S [P [V x]]) s = [SubstItem x s ct]
unifiers' ct s (S [P [V x]]) = [SubstItem x s ct]
-- An opaque constant cannot be solved for: record the equation itself
unifiers' ct s1@(S [P [C _]]) s2 = [UnifyItem s1 s2 ct]
unifiers' ct s1 s2@(S [P [C _]]) = [UnifyItem s1 s2 ct]
-- (i * a) ~ j ==> [a := div j i]
-- Where 'i' and 'j' are integer literals, and i divides j exactly
unifiers' ct (S [P ((I i):ps)]) (S [P [I j]]) =
  case safeDiv j i of
    Just k  -> unifiers' ct (S [P ps]) (S [P [I k]])
    Nothing -> []
unifiers' ct (S [P [I j]]) (S [P ((I i):ps)]) =
  case safeDiv j i of
    Just k  -> unifiers' ct (S [P ps]) (S [P [I k]])
    Nothing -> []
-- (2*a) ~ (2*b) ==> [a := b]
unifiers' ct (S [P ps1]) (S [P ps2])
    | null psx  = []
    | otherwise = unifiers' ct (S [P ps1'']) (S [P ps2''])
  where
    -- Multiset intersection: a common factor is kept as often as it occurs
    -- in BOTH products. 'intersect' keeps every occurrence from ps1, so
    -- a*a*b ~ a*c would wrongly reduce to b ~ c.
    psx   = ps1 \\ (ps1 \\ ps2)
    ps1'  = ps1 \\ psx
    ps2'  = ps2 \\ psx
    -- A fully cancelled product is 1
    ps1'' = if null ps1' then [I 1] else ps1'
    ps2'' = if null ps2' then [I 1] else ps2'
-- (2 + a) ~ 5 ==> [a := 3]
-- When the leading constants are equal neither guard matches and we fall
-- through to the next clause, which cancels them as common terms.
unifiers' ct (S ((P [I i]):ps1)) (S ((P [I j]):ps2))
  | i < j = unifiers' ct (S ps1) (S ((P [I (j-i)]):ps2))
  | i > j = unifiers' ct (S ((P [I (i-j)]):ps1)) (S ps2)
-- (a + c) ~ (b + c) ==> [a := b]
unifiers' ct (S ps1) (S ps2)
    | null psx  = []
    | otherwise = unifiers' ct (S ps1'') (S ps2'')
  where
    -- Multiset intersection of the terms (see the product clause above)
    psx   = ps1 \\ (ps1 \\ ps2)
    ps1'  = ps1 \\ psx
    ps2'  = ps2 \\ psx
    -- A fully cancelled sum is 0
    ps1'' = if null ps1' then [P [I 0]] else ps1'
    ps2'' = if null ps2' then [P [I 0]] else ps2'

-- | Find the 'TyVar's in a 'CoreSOP'
fvSOP :: CoreSOP -> UniqSet TyVar
fvSOP = unionManyUniqSets . map fvProduct . unS

fvProduct :: CoreProduct -> UniqSet TyVar
fvProduct = unionManyUniqSets . map fvSymbol . unP

fvSymbol :: CoreSymbol -> UniqSet TyVar
fvSymbol (I _)   = emptyUniqSet
fvSymbol (C _)   = emptyUniqSet
fvSymbol (V v)   = unitUniqSet v
fvSymbol (E s p) = fvSOP s `unionUniqSets` fvProduct p

-- | Do the two terms mention exactly the same type variables?
eqFV :: CoreSOP -> CoreSOP -> Bool
eqFV = (==) `on` fvSOP

-- | Does any product of the term directly contain an opaque constant
-- (a @C@ symbol)? Constants nested inside an exponential are not inspected.
containsConstants :: CoreSOP -> Bool
containsConstants = any (any isConstant . unP) . unS
  where
    isConstant (C _) = True
    isConstant _     = False

-- | @safeDiv i j@ is @Just k@ when @j@ is non-zero and @i == k * j@;
-- 'Nothing' otherwise.
--
-- A zero divisor never yields a quotient: @0 * a ~ j@ has no solution when
-- @j /= 0@ and no unique one when @j == 0@, so no unifier may be derived
-- from it (previously this returned @Just 0@, producing @[a := 0]@).
safeDiv :: Integer -> Integer -> Maybe Integer
safeDiv _ 0 = Nothing
safeDiv i j = case divMod i j of
  (k, 0) -> Just k
  _      -> Nothing