{-# LANGUAGE TupleSections #-}
{-# LANGUAGE Trustworthy #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module GHC.TypeLits.KnownNat.Solver (plugin) where
import Data.Maybe (catMaybes,mapMaybe)
import GHC.TcPluginM.Extra (lookupModule, lookupName, tracePlugin)
import Class (Class, classMethods, className, classTyCon)
import FamInst (tcInstNewTyCon_maybe)
import FastString (fsLit)
import Id (idType)
import InstEnv (instanceDFunId,lookupUniqueInstEnv)
import Module (mkModuleName)
import OccName (mkTcOcc)
import Outputable (Outputable (..), (<+>), integer, text, vcat)
import Panic (panicDoc, pgmErrorDoc)
import Plugins (Plugin (..), defaultPlugin)
import PrelNames (knownNatClassName)
import TcEvidence (EvTerm (..), EvLit (EvNum), mkEvCast, mkTcSymCo, mkTcTransCo)
import TcPluginM (TcPluginM, tcLookupClass, getInstEnvs, zonkCt)
import TcRnTypes (Ct, CtEvidence (..), TcPlugin(..), TcPluginResult (..),
ctEvidence, ctEvPred, isWanted)
import TcTypeNats (typeNatAddTyCon, typeNatMulTyCon, typeNatExpTyCon)
import Type (PredTree (ClassPred), TyVar, classifyPredType, dropForAlls,
funResultTy, tyConAppTyCon_maybe, mkNumLitTy, mkTyVarTy,
mkTyConApp)
import TyCoRep (Type (..), TyLit (..))
import Var (DFunId)
-- | The helper classes looked up from "GHC.TypeLits.KnownNat" by
-- 'lookupKnownNatDefs', each paired with the 'DFunId' of the unique instance
-- found for it.
data KnownNatDefs = KnownNatDefs
  { knAddDFunId :: (Class,DFunId) -- ^ @KnownNatAdd@ class and its instance's dictionary function
  , knMulDFunId :: (Class,DFunId) -- ^ @KnownNatMul@ class and its instance's dictionary function
  , knExpDFunId :: (Class,DFunId) -- ^ @KnownNatExp@ class and its instance's dictionary function
  }
-- | Debug rendering of the looked-up definitions as
-- @{ add , mul , exp }@, each entry printed as a (class, dfun) pair.
instance Outputable KnownNatDefs where
  ppr defs = text "{" <+> field knAddDFunId
         <+> text "," <+> field knMulDFunId
         <+> text "," <+> field knExpDFunId
         <+> text "}"
    where
      field sel = ppr (sel defs)
-- | A @KnownNat@ constraint recognised by 'toKnConstraint'.
type KnConstraint = (Ct       -- the original constraint
                    ,Class    -- the KnownNat class it mentions
                    ,KnOp     -- its single argument, parsed by toKnOp
                    )
-- | Type-level natural-number expressions the solver understands; produced
-- by 'toKnOp' and turned back into a 'Type' by 'reifyOp'.
data KnOp
  = I Integer      -- ^ A numeric type-level literal
  | V TyVar        -- ^ A type variable
  | Add KnOp KnOp  -- ^ Type-level addition (@+@)
  | Mul KnOp KnOp  -- ^ Type-level multiplication (@*@)
  | Exp KnOp KnOp  -- ^ Type-level exponentiation (@^@)
-- | Debug rendering: literals and variables as-is, operators fully
-- parenthesised in infix form, e.g. @( a + 1 )@.
instance Outputable KnOp where
  ppr op = case op of
      I i     -> integer i
      V v     -> ppr v
      Add x y -> infixOp "+" x y
      Mul x y -> infixOp "*" x y
      Exp x y -> infixOp "^" x y
    where
      infixOp sym l r = text "(" <+> ppr l <+> text sym <+> ppr r <+> text ")"
-- | Plugin entry point: installs 'normalisePlugin' as the type-checker
-- plugin. Command-line options are ignored.
plugin :: Plugin
plugin = defaultPlugin { tcPlugin = \_opts -> Just normalisePlugin }
-- | The type-checker plugin: looks up the helper classes on init, solves
-- @KnownNat@ wanteds on each solver run, and has no teardown work. It is
-- wrapped in 'tracePlugin' under the name @ghc-typelits-knownnat@
-- (presumably for diagnostic tracing; see ghc-tcplugins-extra).
normalisePlugin :: TcPlugin
normalisePlugin =
  tracePlugin "ghc-typelits-knownnat" $
    TcPlugin
      { tcPluginInit  = lookupKnownNatDefs
      , tcPluginSolve = solveKnownNat
      , tcPluginStop  = \_ -> return ()
      }
-- | Solver step: try to discharge every wanted @KnownNat@ constraint whose
-- argument is an expression of literals, variables, @+@, @*@ and @^@,
-- using the (zonked) given @KnownNat v@ constraints for variables.
-- Never emits new constraints.
solveKnownNat :: KnownNatDefs -> [Ct] -> [Ct] -> [Ct]
              -> TcPluginM TcPluginResult
solveKnownNat _defs _givens _deriveds [] = return (TcPluginOk [] [])
solveKnownNat defs givens _deriveds wanteds
  | null knWanteds = return (TcPluginOk [] [])
  | otherwise = do
      givenKns <- catMaybes <$> traverse (fmap toKnConstraint . zonkCt) givens
      let givenMap = mapMaybe toKnEntry givenKns
          solved   = mapMaybe (constraintToEvTerm defs givenMap) knWanteds
      return (TcPluginOk solved [])
  where
    -- Only [W]anted constraints are considered for solving.
    knWanteds = mapMaybe toKnConstraint (filter (isWanted . ctEvidence) wanteds)
-- | Recognise a @KnownNat t@ class constraint whose argument 'toKnOp' can
-- parse; 'Nothing' for any other constraint.
toKnConstraint :: Ct -> Maybe KnConstraint
toKnConstraint ct
  | ClassPred cls [ty] <- classifyPredType (ctEvPred (ctEvidence ct))
  , className cls == knownNatClassName
  = fmap (\knop -> (ct, cls, knop)) (toKnOp ty)
  | otherwise
  = Nothing
-- | Parse a type into a 'KnOp': numeric literals, type variables, and
-- binary applications of the @+@, @*@ and @^@ type families (recursively).
-- Anything else yields 'Nothing'.
toKnOp :: Type -> Maybe KnOp
toKnOp ty = case ty of
    LitTy (NumTyLit n) -> Just (I n)
    TyVarTy tv         -> Just (V tv)
    TyConApp tc [lhs, rhs]
      | Just mk <- lookup tc binOps -> mk <$> toKnOp lhs <*> toKnOp rhs
    _ -> Nothing
  where
    -- Checked in order; first matching type family wins.
    binOps = [ (typeNatAddTyCon, Add)
             , (typeNatMulTyCon, Mul)
             , (typeNatExpTyCon, Exp)
             ]
-- | Index a constraint of the form @KnownNat v@ (bare type variable) by
-- that variable; other shapes are dropped.
toKnEntry :: KnConstraint -> Maybe (TyVar,KnConstraint)
toKnEntry entry = case entry of
    (_, _, V tv) -> Just (tv, entry)
    _            -> Nothing
-- | Plugin initialisation: find the @KnownNatAdd@, @KnownNatMul@ and
-- @KnownNatExp@ classes in "GHC.TypeLits.KnownNat" (package
-- @ghc-typelits-knownnat@) together with the dictionary function of the
-- unique instance matching each class at @0 0@. Aborts compilation with
-- 'pgmErrorDoc' if an instance cannot be found.
lookupKnownNatDefs :: TcPluginM KnownNatDefs
lookupKnownNatDefs = do
    md <- lookupModule myModule myPackage
    KnownNatDefs <$> look md "KnownNatAdd"
                 <*> look md "KnownNatMul"
                 <*> look md "KnownNatExp"
  where
    myModule  = mkModuleName "GHC.TypeLits.KnownNat"
    myPackage = fsLit "ghc-typelits-knownnat"

    -- The instance environment is fetched only after 'tcLookupClass', per
    -- class, exactly as before; do not hoist it above the class lookup.
    look md s = do
      cls  <- tcLookupClass =<< lookupName md (mkTcOcc s)
      ienv <- getInstEnvs
      either (notFound s)
             (\(inst, _) -> return (cls, instanceDFunId inst))
             (lookupUniqueInstEnv ienv cls [mkNumLitTy 0, mkNumLitTy 0])

    notFound s err =
      pgmErrorDoc "Initialising GHC.TypeLits.KnownNat.Solver failed"
        (vcat [ text "Cannot find: " <+> text s
              , text "Reason: "
              , err
              ])
-- | Turn a 'KnOp' back into the corresponding type: literals and variables
-- directly, operators as applications of the @+@, @*@ and @^@ type families.
reifyOp :: KnOp -> Type
reifyOp op = case op of
    I n     -> mkNumLitTy n
    V tv    -> mkTyVarTy tv
    Add a b -> binary typeNatAddTyCon a b
    Mul a b -> binary typeNatMulTyCon a b
    Exp a b -> binary typeNatExpTyCon a b
  where
    binary tc a b = mkTyConApp tc (map reifyOp [a, b])
-- | Try to build evidence for a wanted @KnownNat@ constraint, pairing it
-- with the constraint it solves.
--
-- * Literals get a literal dictionary via 'makeLitDict'.
-- * Type variables reuse the evidence variable of a matching given
--   @KnownNat v@ constraint from the lookup table.
-- * @+@, @*@ and @^@ first solve both operands, then combine their
--   dictionaries with the matching helper class via 'makeOpDict'.
--
-- Returns 'Nothing' as soon as any sub-expression cannot be solved.
constraintToEvTerm :: KnownNatDefs -> [(TyVar,KnConstraint)] -> KnConstraint
                   -> Maybe (EvTerm,Ct)
constraintToEvTerm defs kn_map (ct,cls,op) = (,ct) <$> go op
  where
    -- Total over 'KnOp': every constructor has its own equation, so no
    -- unreachable fallback is needed.
    go (I i)       = makeLitDict cls (mkNumLitTy i) i
    go (V v)       = givenEv <$> lookup v kn_map
    go e@(Add x y) = goOp (knAddDFunId defs) e x y
    go e@(Mul x y) = goOp (knMulDFunId defs) e x y
    go e@(Exp x y) = goOp (knExpDFunId defs) e x y

    -- Evidence variable carried by a given constraint.
    givenEv (ct',_,_) = EvId (ctev_evar (ctEvidence ct'))

    -- Solve both operands, then apply the helper class's dictionary
    -- function and cast the result to @KnownNat e@.
    goOp df e x y = do
      x' <- go x
      y' <- go y
      makeOpDict df cls (reifyOp x) (reifyOp y) (reifyOp e) x' y'
-- | Build a @KnownNat z@ dictionary for @z = x <op> y@ from the @KnownNat@
-- dictionaries of the operands.
--
-- The helper class's dictionary function is applied to the operand
-- evidence, and the result is cast to @KnownNat z@. The cast goes
-- through the newtype coercions of the helper class and of its method's
-- result type, followed by the reverse (symmetric) chain for @KnownNat z@.
-- Returns 'Nothing' if any of the newtype or single-method requirements
-- checked in the guards below does not hold.
makeOpDict :: (Class,DFunId) -- ^ helper class and its dictionary function
           -> Class          -- ^ the @KnownNat@ class
           -> Type           -- ^ left operand @x@
           -> Type           -- ^ right operand @y@
           -> Type           -- ^ result type @z@
           -> EvTerm         -- ^ evidence for @KnownNat x@
           -> EvTerm         -- ^ evidence for @KnownNat y@
           -> Maybe EvTerm
makeOpDict (opCls,dfid) knCls x y z xEv yEv
  -- The KnownNat class TyCon, instantiated at z, must be a newtype
  -- (GHC represents single-method classes without superclasses this way).
  | Just (_, kn_co_dict) <- tcInstNewTyCon_maybe (classTyCon knCls) [z]
  , [ kn_meth ] <- classMethods knCls
    -- TyCon heading the method's result type, after dropping the forall and
    -- the class constraint (presumably GHC's internal SNat — confirm).
  , Just kn_tcRep <- tyConAppTyCon_maybe
                      $ funResultTy
                      $ dropForAlls
                      $ idType kn_meth
  , Just (_, kn_co_rep) <- tcInstNewTyCon_maybe kn_tcRep [z]
    -- Same chain for the helper class: its class TyCon is instantiated at
    -- the operands [x,y], its method's result TyCon at the result z.
  , Just (_, op_co_dict) <- tcInstNewTyCon_maybe (classTyCon opCls) [x,y]
  , [ op_meth ] <- classMethods opCls
  , Just op_tcRep <- tyConAppTyCon_maybe
                      $ funResultTy
                      $ dropForAlls
                      $ idType op_meth
  , Just (_, op_co_rep) <- tcInstNewTyCon_maybe op_tcRep [z]
    -- Apply the helper instance's dfun to the operand dictionaries.
  , let dfun_inst = EvDFunApp dfid [x,y] [xEv,yEv]
        -- helper dictionary -> representation, then (reversed)
        -- representation -> KnownNat z dictionary.
        op_to_kn  = mkTcTransCo (mkTcTransCo op_co_dict op_co_rep)
                                (mkTcSymCo (mkTcTransCo kn_co_dict kn_co_rep))
        ev_tm     = mkEvCast dfun_inst op_to_kn
  = Just ev_tm
  | otherwise
  = Nothing
-- | Build a dictionary for the single-method class @clas@ at type @ty@
-- directly from the integer literal @i@.
--
-- The literal evidence is cast, via the reversed composition of the class
-- newtype coercion and the coercion of the newtype heading the method's
-- result type, into the class dictionary. Returns 'Nothing' if the class
-- is not a newtype, does not have exactly one method, or the method's
-- result is not headed by a newtype TyCon.
makeLitDict :: Class -> Type -> Integer -> Maybe EvTerm
makeLitDict clas ty i = do
  (_, dictCo) <- tcInstNewTyCon_maybe (classTyCon clas) [ty]
  meth        <- case classMethods clas of
                   [m] -> Just m
                   _   -> Nothing
  repTc       <- tyConAppTyCon_maybe (funResultTy (dropForAlls (idType meth)))
  (_, repCo)  <- tcInstNewTyCon_maybe repTc [ty]
  Just (mkEvCast (EvLit (EvNum i)) (mkTcSymCo (mkTcTransCo dictCo repCo)))