-- | Conversion between GHC 'Type's and sum-of-products (SOP) terms, plus
-- substitution on, and unification of, those SOP terms.
module GHC.TypeLits.Normalise.Unify
(
-- * 'Type' \<-\> 'SOP' conversion
CoreSOP
, normaliseNat
, reifySOP
-- * Substitutions on 'SOP' terms
, UnifyItem (..)
, TyUnify
, CoreUnify
, substsSOP
, substsSubst
-- * Finding unifiers
, UnifyResult (..)
, unifyNats
, unifiers
-- * Free variables of 'SOP' terms
, fvSOP
)
where
import Data.Function (on)
import Data.List ((\\), intersect)
import Outputable (Outputable (..), (<+>), ($$), text)
import TcPluginM (TcPluginM, tcPluginTrace)
import TcRnMonad (Ct, ctEvidence, isGiven)
import TcTypeNats (typeNatAddTyCon, typeNatExpTyCon, typeNatMulTyCon,
typeNatSubTyCon)
import Type (TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy, tcView)
import TypeRep (Type (..), TyLit (..))
import UniqSet (UniqSet, unionManyUniqSets, emptyUniqSet, unionUniqSets,
unitUniqSet)
import GHC.Extra.Instances ()
import GHC.TypeLits.Normalise.SOP
import GHC.TypeLits (Nat)
-- | 'SOP' term whose variables are GHC 'TyVar's and whose opaque constants
-- are GHC 'Type's.
type CoreSOP = SOP TyVar Type
-- | 'Product' specialised to the same variable and constant types as 'CoreSOP'.
type CoreProduct = Product TyVar Type
-- | 'Symbol' specialised to the same variable and constant types as 'CoreSOP'.
type CoreSymbol = Symbol TyVar Type
-- | Convert a GHC 'Type' into a sum-of-products term.
--
-- Type synonyms are looked through with 'tcView' first. A type variable
-- becomes a 'V' symbol and a numeric literal an 'I' symbol. A binary
-- application of the type-level addition, subtraction, multiplication or
-- exponentiation constructor is translated recursively and combined with
-- 'mergeSOPAdd', 'mergeSOPMul' or 'normaliseExp'. Every other type is kept
-- as an opaque constant 'C'.
normaliseNat :: Type -> CoreSOP
normaliseNat ty | Just ty1 <- tcView ty = normaliseNat ty1
normaliseNat (TyVarTy v)          = S [P [V v]]
normaliseNat (LitTy (NumTyLit i)) = S [P [I i]]
normaliseNat (TyConApp tc [x,y])
  | tc == typeNatAddTyCon = mergeSOPAdd (normaliseNat x) (normaliseNat y)
    -- Subtraction is encoded as @x + (-1) * y@. The coefficient has to be
    -- negative: multiplying by a positive one would turn @x - y@ into @x + y@.
  | tc == typeNatSubTyCon = mergeSOPAdd (normaliseNat x)
                                        (mergeSOPMul (S [P [I (-1)]])
                                                     (normaliseNat y))
  | tc == typeNatMulTyCon = mergeSOPMul (normaliseNat x) (normaliseNat y)
  | tc == typeNatExpTyCon = normaliseExp (normaliseNat x) (normaliseNat y)
normaliseNat t = S [P [C t]]
-- | Convert a sum-of-products term back into a GHC 'Type'.
--
-- Each product is reified with 'reifyProduct'. A product whose first symbol
-- is a negative integer literal is emitted as a subtraction of its positive
-- counterpart; all other products are added. The empty sum reifies to the
-- literal @0@.
reifySOP :: CoreSOP -> Type
reifySOP = combineP . map negateP . unS
  where
    -- Classify a product: 'Left' carries the positive counterpart of a
    -- product with a negative leading coefficient (to be subtracted),
    -- 'Right' carries a product that is added as-is.
    negateP :: CoreProduct -> Either CoreProduct CoreProduct
    negateP (P (I i : ps))
        -- @(-1) * ps@ is simply @ps@ once it sits behind a minus sign.
      | i == (-1), not (null ps) = Left (P ps)
        -- Keep the magnitude of the coefficient; dropping it would reify
        -- @(-2) * x@ as @0 - x@ and leave a lone negative literal as an
        -- empty product.
      | i < 0                    = Left (P (I (negate i) : ps))
    negateP p                    = Right p

    -- Fold the classified products into nested additions / subtractions.
    combineP :: [Either CoreProduct CoreProduct] -> Type
    combineP []     = mkNumLitTy 0
    combineP [p]    = either (\p' -> mkTyConApp typeNatSubTyCon
                                                [mkNumLitTy 0, reifyProduct p'])
                             reifyProduct
                             p
    combineP (p:ps) = let es = combineP ps
                      in  either (\x -> mkTyConApp typeNatSubTyCon
                                                   [es, reifyProduct x])
                                 (\x -> mkTyConApp typeNatAddTyCon
                                                   [reifyProduct x, es])
                                 p
-- | Convert a product of symbols into a right-nested chain of type-level
-- multiplications.
--
-- The empty product is reified as the literal @1@ (the multiplicative
-- identity) instead of failing inside 'foldr1'.
reifyProduct :: CoreProduct -> Type
reifyProduct (P [])   = mkNumLitTy 1
reifyProduct (P syms) = foldr1 mul (map reifySymbol syms)
  where
    mul t1 t2 = mkTyConApp typeNatMulTyCon [t1, t2]
-- | Convert a single symbol back into a GHC 'Type': integers become numeric
-- literals, constants are returned unchanged, variables become type-variable
-- types, and exponentiations become applications of the type-level
-- exponentiation constructor.
reifySymbol :: CoreSymbol -> Type
reifySymbol sym = case sym of
  I n       -> mkNumLitTy n
  C ty      -> ty
  V tv      -> mkTyVarTy tv
  E base ex -> mkTyConApp typeNatExpTyCon [reifySOP base, reifyProduct ex]
-- | 'TyUnify' over GHC 'TyVar's and 'Type's, where every item is annotated
-- with a 'Ct'.
type CoreUnify = TyUnify TyVar Type Ct
-- | A list of 'UnifyItem's with variables @v@, constants @c@ and notes @n@.
type TyUnify v c n = [UnifyItem v c n]
-- | One element of a 'TyUnify' list.
--
-- A 'SubstItem' pairs the variable 'siVar' with the term 'siSOP' that
-- 'substsSOP' substitutes for it. A 'UnifyItem' records two terms, 'siLHS'
-- and 'siRHS'; 'substsSOP' skips these. Both constructors carry an
-- annotation 'siNote'.
data UnifyItem v c n = SubstItem { siVar :: v
, siSOP :: SOP v c
, siNote :: n
}
| UnifyItem { siLHS :: SOP v c
, siRHS :: SOP v c
, siNote :: n
}
-- | Pretty-print a substitution as @var := term@ and a unification item as
-- @lhs :~ rhs@; the note is not printed.
instance (Outputable v, Outputable c) => Outputable (UnifyItem v c n) where
  ppr item = case item of
    SubstItem { siVar = var, siSOP = rhs } ->
      ppr var <+> text " := " <+> ppr rhs
    UnifyItem { siLHS = lhs, siRHS = rhs } ->
      ppr lhs <+> text " :~ " <+> ppr rhs
-- | Apply every 'SubstItem' of the list, front to back, to the given term.
-- 'UnifyItem' entries are skipped.
substsSOP :: (Ord v, Ord c) => TyUnify v c n -> SOP v c -> SOP v c
substsSOP items term = case items of
  []                                           -> term
  SubstItem { siVar = var, siSOP = rhs } : rest -> substsSOP rest (substSOP var rhs term)
  UnifyItem {} : rest                          -> substsSOP rest term
-- | @substSOP tv e u@ replaces every occurrence of the variable @tv@ in the
-- term @u@ by the term @e@: each product is rewritten with 'substProduct'
-- and the results are summed with 'mergeSOPAdd'.
--
-- The empty sum contains no variables and is returned unchanged; folding it
-- with 'foldr1' would raise an exception.
substSOP :: (Ord v, Ord c) => v -> SOP v c -> SOP v c -> SOP v c
substSOP _  _ (S [])    = S []
substSOP tv e (S prods) = foldr1 mergeSOPAdd (map (substProduct tv e) prods)
-- | Substitute a term for a variable in every symbol of a product and
-- multiply the resulting terms together with 'mergeSOPMul'.
substProduct :: (Ord v, Ord c) => v -> SOP v c -> Product v c -> SOP v c
substProduct tv e prod = foldr1 mergeSOPMul factors
  where
    factors = [ substSymbol tv e sym | sym <- unP prod ]
-- | Substitute a term for a variable inside a single symbol.
--
-- Integers and constants are wrapped unchanged, a matching variable is
-- replaced by the term, and for an exponentiation both the base and the
-- exponent are rewritten before being recombined with 'normaliseExp'.
substSymbol :: (Ord v, Ord c) => v -> SOP v c -> Symbol v c -> SOP v c
substSymbol tv e sym = case sym of
  I _ -> unchanged
  C _ -> unchanged
  V tv'
    | tv == tv' -> e
    | otherwise -> S [P [V tv']]
  E base ex -> normaliseExp (substSOP tv e base) (substProduct tv e ex)
  where
    unchanged = S [P [sym]]
-- | Apply the first list of items, via 'substsSOP', to every term stored in
-- the second list: the right-hand side of each 'SubstItem' and both sides of
-- each 'UnifyItem'.
substsSubst :: (Ord v, Ord c) => TyUnify v c n -> TyUnify v c n -> TyUnify v c n
substsSubst s items = [ rewrite item | item <- items ]
  where
    apply = substsSOP s

    rewrite item = case item of
      SubstItem { siSOP = rhs } ->
        item { siSOP = apply rhs }
      UnifyItem { siLHS = lhs, siRHS = rhs } ->
        item { siLHS = apply lhs, siRHS = apply rhs }
-- | Outcome of comparing two terms in @unifyNats'@.
--
-- 'Win' and 'Lose' are returned when the two terms have the same free
-- variables and contain no constants: 'Win' if they are equal, 'Lose'
-- otherwise. In every other case the result is 'Draw', carrying the list
-- produced by 'unifiers'.
data UnifyResult
= Win
| Lose
| Draw CoreUnify
-- | Print the constructor name, followed by the carried items for 'Draw'.
instance Outputable UnifyResult where
  ppr result = case result of
    Win        -> text "Win"
    Draw subst -> text "Draw" <+> ppr subst
    Lose       -> text "Lose"
-- | Trace the constraint and both terms under the label @unifyNats@, then
-- return the result of the pure worker @unifyNats'@.
unifyNats :: Ct -> CoreSOP -> CoreSOP -> TcPluginM UnifyResult
unifyNats ct u v =
  tcPluginTrace "unifyNats" (ppr ct $$ ppr u $$ ppr v)
    >> return (unifyNats' ct u v)
-- | Pure worker for 'unifyNats'.
--
-- When both terms have the same free variables and neither contains a
-- constant, the answer is decided by plain equality ('Win' or 'Lose').
-- Otherwise the result is a 'Draw' holding the output of 'unifiers'.
unifyNats' :: Ct -> CoreSOP -> CoreSOP -> UnifyResult
unifyNats' ct u v
  | decidable = if u == v then Win else Lose
  | otherwise = Draw (unifiers ct u v)
  where
    decidable =
      eqFV u v && not (containsConstants u || containsConstants v)
-- | Compute the unification items for two terms, annotated with @ct@.
--
-- If either side is a lone variable, that variable is substituted by the
-- other side, but only when @ct@ is a given constraint; otherwise nothing is
-- produced. If either side is a lone constant the result is empty. All
-- remaining cases are delegated to @unifiers'@.
unifiers :: Ct -> CoreSOP -> CoreSOP -> CoreUnify
unifiers ct lhs rhs = case (lhs, rhs) of
  (S [P [V x]], s) -> substIfGiven x s
  (s, S [P [V x]]) -> substIfGiven x s
  (S [P [C _]], _) -> []
  (_, S [P [C _]]) -> []
  _                -> unifiers' ct lhs rhs
  where
    substIfGiven x s
      | isGiven (ctEvidence ct) = [SubstItem x s ct]
      | otherwise               = []
-- | Worker for 'unifiers': derive substitution / unification items from two
-- terms, tagging every item with @ct@.
--
-- The equations are tried from top to bottom:
--
-- * a lone variable against the empty sum is substituted by @0@;
--
-- * a lone variable against any other term is substituted by that term;
--
-- * a lone constant on either side yields a 'UnifyItem' for the two terms;
--
-- * a single product with a leading integer against the literal @0@ drops
--   the integer and recurses on the remaining symbols against @0@;
--
-- * two single products drop the symbols they share and recurse on what is
--   left (a side that becomes empty is replaced by @[I 1]@); without shared
--   symbols the result is empty;
--
-- * two sums that both start with an integer-literal product subtract the
--   smaller literal from both sides and recurse;
--
-- * two arbitrary sums drop the products they share and recurse on what is
--   left (a side that becomes empty is replaced by @[P [I 0]]@); without
--   shared products the result is empty.
unifiers' :: Ct -> CoreSOP -> CoreSOP -> CoreUnify
unifiers' ct (S [P [V x]]) (S [])        = [SubstItem x (S [P [I 0]]) ct]
unifiers' ct (S [])        (S [P [V x]]) = [SubstItem x (S [P [I 0]]) ct]

unifiers' ct (S [P [V x]]) s             = [SubstItem x s ct]
unifiers' ct s             (S [P [V x]]) = [SubstItem x s ct]

unifiers' ct s1@(S [P [C _]]) s2               = [UnifyItem s1 s2 ct]
unifiers' ct s1               s2@(S [P [C _]]) = [UnifyItem s1 s2 ct]

unifiers' ct (S [P ((I _):ps)]) (S [P [I 0]]) = unifiers' ct (S [P ps]) (S [P [I 0]])
unifiers' ct (S [P [I 0]]) (S [P ((I _):ps)]) = unifiers' ct (S [P ps]) (S [P [I 0]])

unifiers' ct (S [P ps1]) (S [P ps2])
    | null psx  = []
    | otherwise = unifiers' ct (S [P ps1'']) (S [P ps2''])
  where
    ps1'  = ps1 \\ psx
    ps2'  = ps2 \\ psx
    ps1'' | null ps1' = [I 1]
          | otherwise = ps1'
    ps2'' | null ps2' = [I 1]
          | otherwise = ps2'
    psx   = intersect ps1 ps2

-- Both sums start with an integer literal: cancel the smaller literal on
-- both sides. When the literals are equal neither guard fires and the
-- general equation below removes them as a shared product.
unifiers' ct (S ((P [I i]):ps1)) (S ((P [I j]):ps2))
    | i < j     = unifiers' ct (S ps1) (S ((P [I (j - i)]):ps2))
    | i > j     = unifiers' ct (S ((P [I (i - j)]):ps1)) (S ps2)

unifiers' ct (S ps1) (S ps2)
    | null psx  = []
    | otherwise = unifiers' ct (S ps1'') (S ps2'')
  where
    ps1'  = ps1 \\ psx
    ps2'  = ps2 \\ psx
    ps1'' | null ps1' = [P [I 0]]
          | otherwise = ps1'
    ps2'' | null ps2' = [P [I 0]]
          | otherwise = ps2'
    psx   = intersect ps1 ps2
-- | Collect the type variables of a term: the union of the variable sets of
-- all of its products.
fvSOP :: CoreSOP -> UniqSet TyVar
fvSOP sop = unionManyUniqSets [ fvProduct prod | prod <- unS sop ]
-- | Collect the type variables of a product: the union of the variable sets
-- of all of its symbols.
fvProduct :: CoreProduct -> UniqSet TyVar
fvProduct prod = unionManyUniqSets [ fvSymbol sym | sym <- unP prod ]
-- | Collect the type variables of a single symbol. Integers and constants
-- contribute none, a variable contributes itself, and an exponentiation
-- contributes the variables of both its base and its exponent.
fvSymbol :: CoreSymbol -> UniqSet TyVar
fvSymbol sym = case sym of
  I _       -> emptyUniqSet
  C _       -> emptyUniqSet
  V tv      -> unitUniqSet tv
  E base ex -> unionUniqSets (fvSOP base) (fvProduct ex)
-- | Do the two terms have exactly the same set of type variables
-- (as computed by 'fvSOP')?
eqFV :: CoreSOP -> CoreSOP -> Bool
eqFV lhs rhs = fvSOP lhs == fvSOP rhs
-- | Does any product of the term have a constant ('C') among its top-level
-- symbols? Symbols nested inside an exponentiation are not inspected.
containsConstants :: CoreSOP -> Bool
containsConstants sop = any productHasConstant (unS sop)
  where
    productHasConstant prod = any isConstant (unP prod)

    isConstant (C _) = True
    isConstant _     = False