#if __GLASGOW_HASKELL__ < 711
#endif
module GHC.TypeLits.Normalise
( plugin )
where
import Control.Arrow (second)
import Data.IORef (IORef, newIORef, readIORef, modifyIORef, modifyIORef')
import Data.List (foldl', intersect)
import Data.Maybe (catMaybes, mapMaybe)
import GHC.TcPluginM.Extra (tracePlugin)
#if __GLASGOW_HASKELL__ < 711
import GHC.TcPluginM.Extra (evByFiat)
#endif
import Outputable (Outputable (..), (<+>), ($$), text)
import Plugins (Plugin (..), defaultPlugin)
import TcEvidence (EvTerm (..))
import TcPluginM (TcPluginM, tcPluginIO, tcPluginTrace, zonkCt)
import TcRnTypes (Ct, TcPlugin (..), TcPluginResult(..), ctEvidence, ctEvPred,
ctPred, isWanted, mkNonCanonical)
import Type (EqRel (NomEq), Kind, PredTree (EqPred), PredType, Type, TyVar,
classifyPredType, getEqPredTys, mkTyVarTy)
import TysWiredIn (typeNatKind)
#if __GLASGOW_HASKELL__ >= 711
import Coercion (CoercionHole, Role (..), mkForAllCos, mkHoleCo, mkInstCo,
mkNomReflCo, mkUnivCo)
import TcPluginM (newCoercionHole, newFlexiTyVar)
import TcRnTypes (CtEvidence (..), TcEvDest (..), ctLoc)
import TyCoRep (UnivCoProvenance (..))
import Type (mkPrimEqPred)
import TcType (typeKind)
#else
import TcType (mkEqPred, typeKind)
import GHC.TcPluginM.Extra (newWantedWithProvenance, failWithProvenace)
#endif
#if __GLASGOW_HASKELL__ >= 711
import TyCoRep (Type (..))
#else
import TypeRep (Type (..))
#endif
import TcTypeNats (typeNatAddTyCon, typeNatExpTyCon, typeNatMulTyCon,
typeNatSubTyCon)
import GHC.Extra.Instances ()
import GHC.TypeLits.Normalise.Unify
-- | The GHC plugin exported by this module. Command-line options are ignored;
-- the type-checker plugin is always 'normalisePlugin'.
plugin :: Plugin
plugin = defaultPlugin { tcPlugin = \_opts -> Just normalisePlugin }
-- | Type-checker plugin wrapped with 'tracePlugin'.
--
-- Its state is an 'IORef' holding a list of constraints. The list starts
-- empty and is handed to 'decideEqualSOP' on every solver call. Stopping
-- the plugin does nothing.
normalisePlugin :: TcPlugin
normalisePlugin =
  tracePlugin "ghc-typelits-natnormalise" $
    TcPlugin
      { tcPluginInit  = tcPluginIO (newIORef [])
      , tcPluginSolve = decideEqualSOP
      , tcPluginStop  = \_ -> return ()
      }
-- | Solver entry point, called by GHC with the plugin state, the givens,
-- the deriveds and the wanteds.
--
-- * Without wanteds (or without wanted Nat equalities) nothing is done.
-- * Otherwise the zonked givens and the wanteds that 'toNatEquality'
--   recognises are passed to 'simplifyNats'.
-- * On success, the solved wanted constraints are reported together with
--   any new wanted constraints their evidence depends on.
-- * On failure, the offending constraint is reported as a contradiction
--   (GHC >= 7.11) or via 'failWithProvenace' (older GHCs).
--
-- @discharged@ accumulates every new wanted emitted so far. It is consulted
-- (together with the current wanteds) by 'evItemToCt', so that an equality
-- already requested is not requested again.
decideEqualSOP :: IORef [Ct] -> [Ct] -> [Ct] -> [Ct]
               -> TcPluginM TcPluginResult
decideEqualSOP _ _givens _deriveds [] = return (TcPluginOk [] [])
decideEqualSOP discharged givens _deriveds wanteds = do
    let wanteds' = filter (isWanted . ctEvidence) wanteds
    let unit_wanteds = mapMaybe toNatEquality wanteds'
    case unit_wanteds of
      [] -> return (TcPluginOk [] [])
      _ -> do
        unit_givens <- mapMaybe toNatEquality <$> mapM zonkCt givens
        sr <- simplifyNats (unit_givens ++ unit_wanteds)
        tcPluginTrace "normalised" (ppr sr)
        case sr of
          Simplified _subst evs -> do
            -- Only evidence for wanted constraints is reported back to GHC.
            let solved = filter (isWanted . ctEvidence . (\(_,x,_) -> x)) evs
            dischargedWanteds <- tcPluginIO (readIORef discharged)
            let existingWanteds = wanteds' ++ dischargedWanteds
            (solved',newWanteds) <- (second concat . unzip . catMaybes) <$>
                                    mapM (evItemToCt existingWanteds) solved
            -- Strict update: this list lives across solver iterations, so
            -- avoid accumulating a chain of unevaluated (++) thunks.
            tcPluginIO (modifyIORef' discharged (++ newWanteds))
            return (TcPluginOk solved' newWanteds)
#if __GLASGOW_HASKELL__ >= 711
          Impossible eq -> return (TcPluginContradiction [fromNatEquality eq])
#else
          Impossible eq -> failWithProvenace (fromNatEquality eq)
#endif
-- | Turn one piece of solver evidence into a solved constraint, plus the
-- new wanted constraints that evidence depends on.
--
-- Only substitution items whose note refers to a /wanted/ constraint are
-- converted, each via 'substItemToCt'. If any conversion yields 'Nothing',
-- the whole evidence item is dropped and 'Nothing' is returned.
evItemToCt :: [Ct]
           -> (EvTerm,Ct,CoreUnify CoreNote)
           -> TcPluginM (Maybe ((EvTerm,Ct),[Ct]))
evItemToCt existingWanteds (ev,ct,subst) =
    case pending of
      [] -> return (Just (solvedPair, []))
      _  -> do
        mbCts <- mapM (substItemToCt existingWanteds) pending
        -- 'sequence' gives 'Nothing' if any single conversion failed.
        return (fmap ((,) solvedPair) (sequence mbCts))
  where
    solvedPair = (ev,ct)
#if __GLASGOW_HASKELL__ >= 711
    pending = filter (isWanted . ctEvidence . snd . siNote) subst
#else
    pending = filter (isWanted . ctEvidence . siNote) subst
#endif
-- | Create a new wanted constraint for a substitution item.
--
-- Returns 'Nothing' when the item's equality is already among
-- @existingWanteds@, in either orientation. Otherwise:
--
-- * On GHC >= 7.11, a 'CtWanted' is built whose destination is the
--   coercion hole stored in the item's note. Its location is taken from
--   the constraint in the note.
-- * On older GHCs, the wanted is created with 'newWantedWithProvenance'
--   from the evidence of the note's constraint.
substItemToCt :: [Ct]
              -> UnifyItem TyVar Type CoreNote
              -> TcPluginM (Maybe Ct)
substItemToCt existingWanteds si
  | any (`elem` wantedPreds) [predicate, predicateS]
  = return Nothing
  | otherwise
#if __GLASGOW_HASKELL__ >= 711
  = return (Just (mkNonCanonical (CtWanted predicate (HoleDest ev) (ctLoc ct))))
#else
  = Just . mkNonCanonical <$> newWantedWithProvenance (ctEvidence ct) predicate
#endif
  where
    predicate   = unifyItemToPredType si
    (lhs, rhs)  = getEqPredTys predicate
    wantedPreds = map ctPred existingWanteds
#if __GLASGOW_HASKELL__ >= 711
    -- The same equality with its sides swapped.
    predicateS     = mkPrimEqPred rhs lhs
    ((ev,_,_), ct) = siNote si
#else
    -- The same equality with its sides swapped.
    predicateS = mkEqPred rhs lhs
    ct         = siNote si
#endif
-- | The equality predicate expressed by a unification item.
--
-- * A 'SubstItem' @v := sop@ becomes @v ~ sop@.
-- * A 'UnifyItem' @l := r@ becomes @l ~ r@.
--
-- In both cases the SOP terms are turned back into types by 'reifySOP'.
-- GHC >= 7.11 builds a primitive equality ('mkPrimEqPred'); older GHCs
-- use 'mkEqPred'.
unifyItemToPredType :: UnifyItem TyVar Type a -> PredType
unifyItemToPredType ui = uncurry eqPred sides
  where
    sides = case ui of
      SubstItem {siVar = v, siSOP = s} -> (mkTyVarTy v, reifySOP s)
      UnifyItem {siLHS = l, siRHS = r} -> (reifySOP l, reifySOP r)
#if __GLASGOW_HASKELL__ >= 711
    eqPred = mkPrimEqPred
#else
    eqPred = mkEqPred
#endif
-- | A constraint paired with the normalised left- and right-hand sides of
-- the Nat equality it states.
type NatEquality = (Ct, CoreSOP, CoreSOP)

-- | The constraint a 'NatEquality' was built from.
fromNatEquality :: NatEquality -> Ct
fromNatEquality eq = case eq of
  (ct, _, _) -> ct
#if __GLASGOW_HASKELL__ >= 711
-- | Note attached to every substitution item. It pairs the coercion hole,
-- the fresh type variable and the equality predicate created for the item
-- with the constraint the item originated from.
type CoreNote = ((CoercionHole, TyVar, PredType), Ct)
#else
-- | Note attached to every substitution item: the constraint the item
-- originated from.
type CoreNote = Ct
#endif
-- | Outcome of 'simplifyNats'.
data SimplifyResult
  = Simplified (CoreUnify CoreNote) [(EvTerm,Ct,CoreUnify CoreNote)]
    -- ^ The accumulated substitution, together with the evidence (and the
    -- substitution items it relies on) for each decided constraint.
  | Impossible NatEquality
    -- ^ An equation that 'unifyNats' rejected.

instance Outputable SimplifyResult where
  ppr result = case result of
    Simplified subst evs -> text "Simplified" $$ ppr subst $$ ppr evs
    Impossible eq        -> text "Impossible" <+> ppr eq
-- | Decide a batch of Nat equalities (givens and wanteds together).
--
-- Equations are examined one at a time by the worker @simples@, which takes
-- four arguments:
--
--   * the substitution accumulated so far, applied to both sides of each
--     equation before it is handed to 'unifyNats';
--   * the evidence collected so far, newest first; 'Nothing' entries
--     (where 'evMagic' produced no evidence) are dropped at the end with
--     'catMaybes';
--   * equations deferred because 'unifyNats' returned @Draw []@; they are
--     re-queued whenever progress is made (a 'Win' or a non-empty 'Draw')
--     and are simply discarded once the queue runs empty;
--   * the queue of equations still to examine.
--
-- A 'Lose' immediately yields 'Impossible' for that equation.
simplifyNats :: [NatEquality]
             -> TcPluginM SimplifyResult
simplifyNats eqs =
    tcPluginTrace "simplifyNats" (ppr eqs) >> simples [] [] [] eqs
  where
    simples :: CoreUnify CoreNote -> [Maybe (EvTerm, Ct, CoreUnify CoreNote)] -> [NatEquality]
            -> [NatEquality] -> TcPluginM SimplifyResult
    simples subst evs _xs [] = return (Simplified subst (catMaybes evs))
    simples subst evs xs (eq@(ct,u,v):eqs') = do
      ur <- unifyNats ct (substsSOP subst u) (substsSOP subst v)
      tcPluginTrace "unifyNats result" (ppr ur)
      case ur of
#if __GLASGOW_HASKELL__ >= 711
        Win -> simples subst (((,,) <$> evMagic ct [] <*> pure ct <*> pure []):evs) []
                 (xs ++ eqs')
        Lose -> return (Impossible eq)
        Draw [] -> simples subst evs (eq:xs) eqs'
        Draw subst' -> do
          -- For every new substitution item: a fresh coercion hole, a fresh
          -- type variable of kind Nat, and the item's equality predicate.
          newEvs <- mapM (\si -> (,,) <$> newCoercionHole
                                      <*> newFlexiTyVar typeNatKind
                                      <*> pure (unifyItemToPredType si))
                         subst'
          -- Attach those triples to the items' notes (see 'CoreNote').
          let subst'' = zipWith (\si ev -> si {siNote = (ev,siNote si)})
                                subst' newEvs
          -- Extend the substitution; 'substsSubst' presumably applies the
          -- new items to the existing ones first -- confirm in the Unify
          -- module.
          simples (substsSubst subst'' subst ++ subst'')
                  (((,,) <$> evMagic ct newEvs <*> pure ct <*> pure subst''):evs)
                  [] (xs ++ eqs')
#else
        Win -> simples subst (((,,) <$> evMagic ct <*> pure ct <*> pure []):evs) []
                 (xs ++ eqs')
        Lose -> return (Impossible eq)
        Draw [] -> simples subst evs (eq:xs) eqs'
        Draw subst' -> do
          -- Extend the substitution and record evidence together with the
          -- substitution items it depends on.
          simples (substsSubst subst' subst ++ subst')
                  (((,,) <$> evMagic ct <*> pure ct <*> pure subst'):evs)
                  [] (xs ++ eqs')
#endif
-- | Recognise a constraint as an equality between natural-number types.
--
-- Only nominal equalities are considered. Two shapes are accepted:
--
--   * Both sides apply the same type constructor, and it is not one of the
--     arithmetic constructors 'typeNatAddTyCon', 'typeNatSubTyCon',
--     'typeNatMulTyCon' or 'typeNatExpTyCon'. The arguments must differ in
--     exactly one position, and both arguments there must have kind Nat;
--     that pair is normalised. If the arguments differ in zero or several
--     positions (or the differing pair is not Nat-kinded), the constraint
--     is rejected.
--   * Otherwise, both whole sides must have kind Nat and are normalised.
--
-- The original constraint is kept as the first component of the result.
toNatEquality :: Ct -> Maybe NatEquality
toNatEquality ct
  | EqPred NomEq lhs rhs <- classifyPredType (ctEvPred (ctEvidence ct))
  = compareSides lhs rhs
  | otherwise
  = Nothing
  where
    arithTyCons = [typeNatAddTyCon,typeNatSubTyCon,typeNatMulTyCon,typeNatExpTyCon]

    asNatEquality a b
      | isNatKind (typeKind a) && isNatKind (typeKind b)
      = Just (ct, normaliseNat a, normaliseNat b)
      | otherwise
      = Nothing

    compareSides (TyConApp tc1 args1) (TyConApp tc2 args2)
      | tc1 == tc2
      , null ([tc1,tc2] `intersect` arithTyCons)
      = case [ d | d@(a,b) <- zip args1 args2, a /= b ] of
          [(a,b)] -> asNatEquality a b
          _       -> Nothing
    compareSides a b = asNatEquality a b

-- | Whether the given kind is the kind of type-level naturals.
isNatKind :: Kind -> Bool
isNatKind k = k == typeNatKind
#if __GLASGOW_HASKELL__ >= 711
-- | Build evidence for a solved nominal equality @t1 ~ t2@.
--
-- The core of the evidence is an unchecked 'UnivCo' with 'PluginProv'
-- provenance. It is wrapped in 'mkForAllCos' over the supplied fresh type
-- variables, each paired with a reflexive coercion on the Nat kind. The
-- result is then combined with 'mkInstCo', in order, with a hole coercion
-- for each supplied coercion hole; the hole's two types come from the
-- paired predicate. Presumably this makes the evidence refer to the holes
-- of the newly emitted wanteds (review: confirm intent).
--
-- Returns 'Nothing' if the constraint is not a nominal equality.
evMagic :: Ct -> [(CoercionHole, TyVar, PredType)] -> Maybe EvTerm
evMagic ct evs = case classifyPredType $ ctEvPred $ ctEvidence ct of
    EqPred NomEq t1 t2 ->
      let ctEv = mkUnivCo (PluginProv "ghc-typelits-natnormalise") Nominal t1 t2
          (holes,tvs,preds) = unzip3 evs
          holeEvs = zipWith (\h p -> uncurry (mkHoleCo h Nominal) (getEqPredTys p))
                            holes preds
          natReflCo = mkNomReflCo typeNatKind
          forallEv = mkForAllCos (map (,natReflCo) tvs) ctEv
          -- Strict left fold: no chain of unevaluated mkInstCo thunks.
          finalEv = foldl' mkInstCo forallEv holeEvs
      in Just (EvCoercion finalEv)
    _ -> Nothing
#else
-- | Build evidence "by fiat" for a solved nominal equality @t1 ~ t2@.
-- Returns 'Nothing' if the constraint is not a nominal equality.
evMagic :: Ct -> Maybe EvTerm
evMagic ct = case classifyPredType $ ctEvPred $ ctEvidence ct of
    EqPred NomEq t1 t2 -> Just (evByFiat "ghc-typelits-natnormalise" t1 t2)
    _ -> Nothing
#endif