{-|
Copyright  :  (C) 2016, University of Twente
License    :  BSD2 (see the file LICENSE)
Maintainer :  Christiaan Baaij

A type checker plugin for GHC that can derive \"complex\" @KnownNat@
constraints from other simple/variable @KnownNat@ constraints. i.e. without this
plugin, you must have both a @KnownNat n@ and a @KnownNat (n+2)@ constraint in
the type signature of the following function:

@
f :: forall n . (KnownNat n, KnownNat (n+2)) => Proxy n -> Integer
f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))
@

Using the plugin you can omit the @KnownNat (n+2)@ constraint:

@
f :: forall n . KnownNat n => Proxy n -> Integer
f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))
@

The plugin can only derive @KnownNat@ constraints consisting of:

* Type-level naturals
* Type variables
* Applications of the arithmetic expression: @{+,*,^}@.

i.e. it /cannot/ derive a @KnownNat (n-1)@ constraint from a @KnownNat n@
constraint.

To use the plugin, add the

@
OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver
@

Pragma to the header of your file.
-}

{-# LANGUAGE TupleSections #-}

{-# LANGUAGE Trustworthy   #-}

{-# OPTIONS_HADDOCK show-extensions #-}

module GHC.TypeLits.KnownNat.Solver (plugin) where

-- external
import Data.Maybe          (catMaybes,mapMaybe)
import GHC.TcPluginM.Extra (lookupModule, lookupName, tracePlugin)

-- GHC API
import Class      (Class, classMethods, className, classTyCon)
import FamInst    (tcInstNewTyCon_maybe)
import FastString (fsLit)
import Id         (idType)
import InstEnv    (instanceDFunId,lookupUniqueInstEnv)
import Module     (mkModuleName)
import OccName    (mkTcOcc)
import Outputable (Outputable (..), (<+>), integer, text, vcat)
import Panic      (panicDoc, pgmErrorDoc)
import Plugins    (Plugin (..), defaultPlugin)
import PrelNames  (knownNatClassName)
import TcEvidence (EvTerm (..), EvLit (EvNum), mkEvCast, mkTcSymCo, mkTcTransCo)
import TcPluginM  (TcPluginM, tcLookupClass, getInstEnvs, zonkCt)
import TcRnTypes  (Ct, CtEvidence (..), TcPlugin(..), TcPluginResult (..),
                   ctEvidence, ctEvPred, isWanted)
import TcTypeNats (typeNatAddTyCon, typeNatMulTyCon, typeNatExpTyCon)
import Type       (PredTree (ClassPred), TyVar, classifyPredType, dropForAlls,
                   funResultTy, tyConAppTyCon_maybe, mkNumLitTy, mkTyVarTy,
                   mkTyConApp)
import TyCoRep    (Type (..), TyLit (..))
import Var        (DFunId)

-- | Classes and instances from "GHC.TypeLits.KnownNat"
data KnownNatDefs = KnownNatDefs
  { knAddDFunId :: (Class,DFunId) -- ^ KnownNatAdd class and its only instance
  , knMulDFunId :: (Class,DFunId) -- ^ KnownNatMul class and its only instance
  , knExpDFunId :: (Class,DFunId) -- ^ KnownNatExp class and its only instance
  }

instance Outputable KnownNatDefs where
  ppr d = text "{" <+> ppr (knAddDFunId d) <+> text "," <+>
                       ppr (knMulDFunId d) <+> text "," <+>
                       ppr (knExpDFunId d) <+> text "}"

-- | KnownNat constraints
type KnConstraint = (Ct    -- The constraint
                    ,Class -- KnownNat class
                    ,KnOp  -- The argument to KnownNat
                    )

-- | Reified argument of a KnownNat
data KnOp
  = I Integer      -- ^ Type-level natural literal
  | V TyVar        -- ^ Type variable
  | Add KnOp KnOp  -- ^ Application of the type-level @+@
  | Mul KnOp KnOp  -- ^ Application of the type-level @*@
  | Exp KnOp KnOp  -- ^ Application of the type-level @^@

instance Outputable KnOp where
  ppr (I i)     = integer i
  ppr (V v)     = ppr v
  ppr (Add x y) = text "(" <+> ppr x <+> text "+" <+> ppr y <+> text ")"
  ppr (Mul x y) = text "(" <+> ppr x <+> text "*" <+> ppr y <+> text ")"
  ppr (Exp x y) = text "(" <+> ppr x <+> text "^" <+> ppr y <+> text ")"

{-|
A type checker plugin for GHC that can derive \"complex\" @KnownNat@
constraints from other simple/variable @KnownNat@ constraints. i.e. without this
plugin, you must have both a @KnownNat n@ and a @KnownNat (n+2)@ constraint in
the type signature of the following function:

@
f :: forall n . (KnownNat n, KnownNat (n+2)) => Proxy n -> Integer
f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))
@

Using the plugin you can omit the @KnownNat (n+2)@ constraint:

@
f :: forall n . KnownNat n => Proxy n -> Integer
f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))
@

The plugin can only derive @KnownNat@ constraints consisting of:

* Type-level naturals
* Type variables
* Applications of the arithmetic expression: @{+,*,^}@.

i.e. it /cannot/ derive a @KnownNat (n-1)@ constraint from a @KnownNat n@
constraint

To use the plugin, add the

@
OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver
@

Pragma to the header of your file.
-}
plugin :: Plugin
plugin = defaultPlugin { tcPlugin = const $ Just normalisePlugin }

-- | The actual type checker plugin: initialised by 'lookupKnownNatDefs',
-- solving with 'solveKnownNat', and wrapped in 'tracePlugin' under the name
-- @ghc-typelits-knownnat@. Nothing needs to be done when the plugin stops.
normalisePlugin :: TcPlugin
normalisePlugin = tracePlugin "ghc-typelits-knownnat"
  TcPlugin { tcPluginInit  = lookupKnownNatDefs
           , tcPluginSolve = solveKnownNat
           , tcPluginStop  = const (return ())
           }

-- | Try to solve the wanted @KnownNat@ constraints from the given ones.
--
-- Only wanteds that 'toKnConstraint' recognises are considered; every wanted
-- for which 'constraintToEvTerm' can build evidence is reported as solved.
-- The second component of the returned 'TcPluginOk' is always empty.
solveKnownNat :: KnownNatDefs -> [Ct] -> [Ct] -> [Ct]
              -> TcPluginM TcPluginResult
-- Nothing to solve: answer immediately without inspecting the givens.
solveKnownNat _defs _givens _deriveds []      = return (TcPluginOk [] [])
solveKnownNat defs  givens  _deriveds wanteds = do
  -- GHC 7.10 puts deriveds with the wanteds, so filter them out
  let wanteds'   = filter (isWanted . ctEvidence) wanteds
      kn_wanteds = mapMaybe toKnConstraint wanteds'
  case kn_wanteds of
    [] -> return (TcPluginOk [] [])
    _  -> do
      -- The givens are zonked before they are classified
      kn_givens <- catMaybes <$> mapM (fmap toKnConstraint . zonkCt) givens
      -- Make a lookup table of the [G]iven KnownNat constraints
      let kn_map = mapMaybe toKnEntry kn_givens
      -- Try to solve the wanted KnownNat constraints given the [G]iven
      -- KnownNat constraints
      let solved = mapMaybe (constraintToEvTerm defs kn_map) kn_wanteds
      return (TcPluginOk solved [])

-- | Get the KnownNat constraints
--
-- Yields 'Nothing' unless the constraint is a single-argument class predicate
-- whose class is @KnownNat@ and whose argument 'toKnOp' can reify.
toKnConstraint :: Ct -> Maybe KnConstraint
toKnConstraint ct = case classifyPredType $ ctEvPred $ ctEvidence ct of
  ClassPred cls [ty]
    | className cls == knownNatClassName
    -> ((ct,cls,) <$> toKnOp ty)
  _ -> Nothing

{- |
The plugin can only derive @KnownNat@ constraints consisting of:

* Type-level naturals
* Type variables
* Applications of the arithmetic expression: @{+,*,^}@.
-}
toKnOp :: Type -> Maybe KnOp
toKnOp (LitTy (NumTyLit i)) = pure (I i)
toKnOp (TyVarTy v)          = pure (V v)
toKnOp (TyConApp tc [x,y])
  | tc == typeNatAddTyCon = Add <$> toKnOp x <*> toKnOp y
  | tc == typeNatMulTyCon = Mul <$> toKnOp x <*> toKnOp y
  | tc == typeNatExpTyCon = Exp <$> toKnOp x <*> toKnOp y
-- Anything else (including other binary type constructors, which fall
-- through the guards above) is not supported.
toKnOp _ = Nothing

-- | Create a look-up entry for @n@ given a [G]iven @KnownNat n@ constraint.
toKnEntry :: KnConstraint -> Maybe (TyVar,KnConstraint)
-- Only constraints on a bare type variable become look-up entries.
toKnEntry kn@(_,_,V v) = Just (v,kn)
toKnEntry _            = Nothing

-- | Find the \"magic\" classes and instances in "GHC.TypeLits.KnownNat"
--
-- Aborts with 'pgmErrorDoc' when a class does not have a unique instance
-- matching the probe arguments.
lookupKnownNatDefs :: TcPluginM KnownNatDefs
lookupKnownNatDefs = do
    md    <- lookupModule myModule myPackage
    addDF <- look md "KnownNatAdd"
    mulDF <- look md "KnownNatMul"
    expDF <- look md "KnownNatExp"
    return $ KnownNatDefs addDF mulDF expDF
  where
    -- Look up the class named @s@ in module @md@ together with the dictionary
    -- function of the instance selected by probing with the arguments @0 0@.
    look md s = do
      nm   <- lookupName md (mkTcOcc s)
      cls  <- tcLookupClass nm
      ienv <- getInstEnvs
      case lookupUniqueInstEnv ienv cls [mkNumLitTy 0, mkNumLitTy 0] of
        Right (inst, _) -> return (cls,instanceDFunId inst)
        Left err -> pgmErrorDoc "Initialising GHC.TypeLits.KnownNat.Solver failed"
                      (vcat [text "Cannot find: " <+> text s
                            ,text "Reason: "
                            ,err
                            ])

    myModule  = mkModuleName "GHC.TypeLits.KnownNat"
    myPackage = fsLit "ghc-typelits-knownnat"

-- | Convert a reified argument of a KnownNat constraint back to a type
reifyOp :: KnOp -> Type
reifyOp (I i)     = mkNumLitTy i
reifyOp (V v)     = mkTyVarTy v
reifyOp (Add x y) = mkTyConApp typeNatAddTyCon [reifyOp x, reifyOp y]
reifyOp (Mul x y) = mkTyConApp typeNatMulTyCon [reifyOp x, reifyOp y]
reifyOp (Exp x y) = mkTyConApp typeNatExpTyCon [reifyOp x, reifyOp y]

-- | Try to create evidence for a wanted constraint
--
-- The result pairs the evidence with the wanted constraint it solves. It is
-- 'Nothing' when a type variable in the argument has no entry in the look-up
-- table of givens, or when building a dictionary fails.
constraintToEvTerm :: KnownNatDefs -> [(TyVar,KnConstraint)] -> KnConstraint
                   -> Maybe (EvTerm,Ct)
constraintToEvTerm defs kn_map (ct,cls,op) = (,ct) <$> go op
  where
    -- One equation per 'KnOp' constructor, so the compiler can check that
    -- every shape of argument is handled.

    -- Literal: build the dictionary directly from the number.
    go (I i) = makeLitDict cls (mkNumLitTy i) i
    -- Variable: reuse the evidence variable of the matching given.
    go (V v) = case lookup v kn_map of
      Just (ct',_,_) -> Just (EvId (ctev_evar (ctEvidence ct')))
      Nothing        -> Nothing
    -- Arithmetic: solve both operands, then apply the "magic" instance.
    go e@(Add x y) = goOp (knAddDFunId defs) e x y
    go e@(Mul x y) = goOp (knMulDFunId defs) e x y
    go e@(Exp x y) = goOp (knExpDFunId defs) e x y

    -- Evidence for the binary operation @e@ with operands @x@ and @y@, using
    -- the class and dictionary function @df@ of that operation.
    goOp df e x y = do
      x' <- go x
      y' <- go y
      makeOpDict df cls (reifyOp x) (reifyOp y) (reifyOp e) x' y'

{-
Given:

* A "magic" class, and corresponding instance dictionary
function, for a type-level arithmetic operation
* Two KnownNat dictionaries

makeOpDict instantiates the dictionary function with the KnownNat dictionaries,
and coerces it to a KnownNat dictionary. i.e. for KnownNatAdd, the "magic"
dictionary for addition, the coercion happens in the following steps:

1. KnownNatAdd a b           -> SNatKn (a + b)
2. SNatKn (a + b)            -> Integer
3. Integer                   -> SNat (a + b)
4. SNat (a + b)              -> KnownNat (a + b)

The process is mirrored for KnownNatMul, and KnownNatExp, the classes
representing multiplication and exponentiation.

The result is Nothing when any of the newtype instantiations fails, or when
either class does not have exactly one method.
-}
makeOpDict :: (Class,DFunId)
           -- ^ "magic" class function and dictionary function id
           -> Class
           -- ^ KnownNat class
           -> Type
           -- ^ Type of the first argument
           -> Type
           -- ^ Type of the second argument
           -> Type
           -- ^ Type of the result
           -> EvTerm
           -- ^ KnownNat dictionary for the first argument
           -> EvTerm
           -- ^ KnownNat dictionary for the second argument
           -> Maybe EvTerm
makeOpDict (opCls,dfid) knCls x y z xEv yEv
  | Just (_, kn_co_dict) <- tcInstNewTyCon_maybe (classTyCon knCls) [z]
    -- KnownNat n ~ SNat n
  , [ kn_meth ] <- classMethods knCls
  , Just kn_tcRep <- tyConAppTyCon_maybe -- SNat
                      $ funResultTy      -- SNat n
                      $ dropForAlls      -- KnownNat n => SNat n
                      $ idType kn_meth   -- forall n. KnownNat n => SNat n
  , Just (_, kn_co_rep) <- tcInstNewTyCon_maybe kn_tcRep [z]
    -- SNat n ~ Integer
  , Just (_, op_co_dict) <- tcInstNewTyCon_maybe (classTyCon opCls) [x,y]
    -- KnownNatAdd a b ~ SNatKn (a+b)
  , [ op_meth ] <- classMethods opCls
  , Just op_tcRep <- tyConAppTyCon_maybe -- SNatKn
                      $ funResultTy      -- SNatKn (a+b)
                      $ dropForAlls      -- KnownNatAdd a b => SNatKn (a + b)
                      $ idType op_meth   -- forall a b . KnownNatAdd a b => SNatKn (a+b)
  , Just (_, op_co_rep) <- tcInstNewTyCon_maybe op_tcRep [z]
    -- SNatKn (a+b) ~ Integer
  , let dfun_inst = EvDFunApp dfid [x,y] [xEv,yEv]
        -- KnownNatAdd a b
        op_to_kn  = mkTcTransCo (mkTcTransCo op_co_dict op_co_rep)
                                (mkTcSymCo (mkTcTransCo kn_co_dict kn_co_rep))
        -- KnownNatAdd a b ~ KnownNat (a+b)
        ev_tm     = mkEvCast dfun_inst op_to_kn
  = Just ev_tm
  | otherwise
  = Nothing

-- | THIS CODE IS COPIED FROM:
-- https://github.com/ghc/ghc/blob/8035d1a5dc7290e8d3d61446ee4861e0b460214e/compiler/typecheck/TcInteract.hs#L1973
--
-- makeLitDict adds a coercion that will convert the literal into a dictionary
-- of the appropriate type. See Note [KnownNat & KnownSymbol and EvLit]
-- in TcEvidence. The coercion happens in 2 steps:
--
--     Integer -> SNat n     -- representation of literal to singleton
--     SNat n  -> KnownNat n -- singleton to dictionary
--
-- The result is 'Nothing' when either newtype instantiation fails, or when
-- the class does not have exactly one method.
makeLitDict :: Class -> Type -> Integer -> Maybe EvTerm
makeLitDict clas ty i
  | Just (_, co_dict) <- tcInstNewTyCon_maybe (classTyCon clas) [ty]
    -- co_dict :: KnownNat n ~ SNat n
  , [ meth ]   <- classMethods clas
  , Just tcRep <- tyConAppTyCon_maybe -- SNat
                    $ funResultTy     -- SNat n
                    $ dropForAlls     -- KnownNat n => SNat n
                    $ idType meth     -- forall n. KnownNat n => SNat n
  , Just (_, co_rep) <- tcInstNewTyCon_maybe tcRep [ty]
    -- SNat n ~ Integer
  , let ev_tm = mkEvCast (EvLit (EvNum i)) (mkTcSymCo (mkTcTransCo co_dict co_rep))
  = Just ev_tm
  | otherwise
  = Nothing