{-# LANGUAGE TupleSections #-} {-# LANGUAGE CPP #-} {-# OPTIONS_HADDOCK show-extensions #-} {-| Copyright : (C) 2015, University of Twente License : BSD2 (see the file LICENSE) Maintainer : Christiaan Baaij A type checker plugin for GHC that can solve /equalities/ of types of kind 'GHC.TypeLits.Nat', where these types are either: * Type-level naturals * Type variables * Applications of the arithmetic expressions @(+,-,*,^)@. It solves these equalities by normalising them to /sort-of/ 'GHC.TypeLits.Normalise.SOP.SOP' (Sum-of-Products) form, and then perform a simple syntactic equality. For example, this solver can prove the equality between: @ (x + 2)^(y + 2) @ and @ 4*x*(2 + x)^y + 4*(2 + x)^y + (2 + x)^y*x^2 @ Because the latter is actually the 'GHC.TypeLits.Normalise.SOP.SOP' normal form of the former. To use the plugin, add @ {\-\# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise \#-\} @ To the header of your file. -} module GHC.TypeLits.Normalise ( plugin ) where -- external import Data.Maybe (catMaybes, mapMaybe) -- GHC API import Coercion (Role (Nominal), mkUnivCo) import FastString (fsLit) import Outputable (Outputable (..), (<+>), ($$), text) import Plugins (Plugin (..), defaultPlugin) import TcEvidence (EvTerm (EvCoercion), TcCoercion (..)) import TcPluginM (TcPluginM, tcPluginTrace, unsafeTcPluginTcM, zonkCt) #if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 711 import qualified Inst #else import qualified TcMType #endif import TcRnTypes (Ct, CtLoc, CtOrigin, TcPlugin(..), TcPluginResult(..), ctEvidence, ctEvPred, ctLoc, ctLocOrigin, isGiven, isWanted, mkNonCanonical) import TcSMonad (runTcS,newGivenEvVar) import TcType (mkEqPred, typeKind) import Type (EqRel (NomEq), Kind, PredTree (EqPred), PredType, Type, TyVar, classifyPredType, mkTyVarTy) import TysWiredIn (typeNatKind) -- internal import GHC.TypeLits.Normalise.Unify -- workaround for https://ghc.haskell.org/trac/ghc/ticket/10301 import Control.Monad (unless) import Data.IORef (readIORef) import StaticFlags (initStaticOpts, v_opt_C_ready) import TcPluginM (tcPluginIO) -- | To use the plugin, add -- -- @ -- {\-\# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise \#-\} -- @ -- -- To the header of your file. plugin :: Plugin plugin = defaultPlugin { tcPlugin = const $ Just normalisePlugin } normalisePlugin :: TcPlugin normalisePlugin = TcPlugin { tcPluginInit = return () , tcPluginSolve = decideEqualSOP , tcPluginStop = const (return ()) } decideEqualSOP :: () -> [Ct] -> [Ct] -> [Ct] -> TcPluginM TcPluginResult decideEqualSOP _ _givens _deriveds [] = return (TcPluginOk [] []) decideEqualSOP _ givens _deriveds wanteds = do -- workaround for https://ghc.haskell.org/trac/ghc/ticket/10301 initializeStaticFlags let unit_wanteds = mapMaybe toNatEquality wanteds case unit_wanteds of [] -> return (TcPluginOk [] []) _ -> do unit_givens <- mapMaybe toNatEquality <$> mapM zonkCt givens sr <- simplifyNats (unit_givens ++ unit_wanteds) tcPluginTrace "normalised" (ppr sr) case sr of Simplified subst evs -> TcPluginOk (filter (isWanted . ctEvidence . snd) evs) <$> mapM substItemToCt (filter (isWanted . ctEvidence . siNote) subst) Impossible eq -> return (TcPluginContradiction [fromNatEquality eq]) substItemToCt :: SubstItem TyVar Type Ct -> TcPluginM Ct substItemToCt si | isGiven (ctEvidence ct) = newSimpleGiven loc predicate (ty1,ty2) | otherwise = newSimpleWanted (ctLocOrigin loc) predicate where predicate = mkEqPred ty1 ty2 ty1 = mkTyVarTy (siVar si) ty2 = reifySOP (siSOP si) ct = siNote si loc = ctLoc ct type NatEquality = (Ct,CoreSOP,CoreSOP) fromNatEquality :: NatEquality -> Ct fromNatEquality (ct, _, _) = ct data SimplifyResult = Simplified CoreSubst [(EvTerm,Ct)] | Impossible NatEquality instance Outputable SimplifyResult where ppr (Simplified subst evs) = text "Simplified" $$ ppr subst $$ ppr evs ppr (Impossible eq) = text "Impossible" <+> ppr eq simplifyNats :: [NatEquality] -> TcPluginM SimplifyResult simplifyNats eqs = tcPluginTrace "simplifyNats" (ppr eqs) >> simples [] [] [] eqs where simples :: CoreSubst -> [Maybe (EvTerm, Ct)] -> [NatEquality] -> [NatEquality] -> TcPluginM SimplifyResult simples subst evs _xs [] = return (Simplified subst (catMaybes evs)) simples subst evs xs (eq@(ct,u,v):eqs') = do ur <- unifyNats ct (substsSOP subst u) (substsSOP subst v) tcPluginTrace "unifyNats result" (ppr ur) case ur of Win -> simples subst (((,) <$> evMagic ct <*> pure ct):evs) [] (xs ++ eqs') Lose -> return (Impossible eq) Draw [] -> simples subst evs (eq:xs) eqs' Draw subst' -> simples (substsSubst subst' subst ++ subst') evs [eq] (xs ++ eqs') -- Extract the Nat equality constraints toNatEquality :: Ct -> Maybe NatEquality toNatEquality ct = case classifyPredType $ ctEvPred $ ctEvidence ct of EqPred NomEq t1 t2 | isNatKind (typeKind t1) || isNatKind (typeKind t1) -> Just (ct,normaliseNat t1,normaliseNat t2) _ -> Nothing where isNatKind :: Kind -> Bool isNatKind = (== typeNatKind) -- Utils newSimpleWanted :: CtOrigin -> PredType -> TcPluginM Ct #if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 711 newSimpleWanted orig = fmap mkNonCanonical . unsafeTcPluginTcM . Inst.newWanted orig #else newSimpleWanted orig = unsafeTcPluginTcM . TcMType.newSimpleWanted orig #endif newSimpleGiven :: CtLoc -> PredType -> (Type,Type) -> TcPluginM Ct newSimpleGiven loc predicate (ty1,ty2)= do (ev,_) <- unsafeTcPluginTcM $ runTcS $ newGivenEvVar loc (predicate, evByFiat "units" (ty1, ty2)) return (mkNonCanonical ev) evMagic :: Ct -> Maybe EvTerm evMagic ct = case classifyPredType $ ctEvPred $ ctEvidence ct of EqPred NomEq t1 t2 -> Just (evByFiat "tylits_magic" (t1, t2)) _ -> Nothing evByFiat :: String -> (Type, Type) -> EvTerm evByFiat name (t1,t2) = EvCoercion $ TcCoercion $ mkUnivCo (fsLit name) Nominal t1 t2 -- workaround for https://ghc.haskell.org/trac/ghc/ticket/10301 initializeStaticFlags :: TcPluginM () initializeStaticFlags = tcPluginIO $ do r <- readIORef v_opt_C_ready unless r initStaticOpts