module GHC.TypeLits.Extra.Solver.Unify
( ExtraDefs (..)
, UnifyResult (..)
, normaliseNat
, unifyExtra
)
where
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Maybe (MaybeT (..))
import Data.Function (on)
import GHC.TypeLits.Normalise.Unify (CType (..))
import Outputable (Outputable (..), ($$), text)
import TcPluginM (TcPluginM, matchFam, tcPluginTrace)
import TcRnMonad (Ct)
import TcTypeNats (typeNatExpTyCon)
import Type (TyVar, coreView, mkNumLitTy, mkTyConApp, mkTyVarTy)
import TyCon (TyCon)
import TyCoRep (Type (..), TyLit (..))
import UniqSet (UniqSet, emptyUniqSet, unionUniqSets, unitUniqSet)
import GHC.TypeLits.Extra.Solver.Operations
-- | The type constructors of the type-level operators this solver knows
-- about. 'normaliseNat' compares against them to recognise applications,
-- and 'reifyEOP' uses them to rebuild those applications as 'Type's.
--
-- NOTE(review): the values are constructed outside this module (presumably
-- when the plugin initialises) — confirm against the caller.
data ExtraDefs = ExtraDefs
  { divTyCon :: TyCon -- ^ Recognised as 'Div'; combined with 'mergeDiv'.
  , modTyCon :: TyCon -- ^ Recognised as 'Mod'; combined with 'mergeMod'.
  , flogTyCon :: TyCon -- ^ Recognised as 'FLog'; combined with 'mergeFLog'.
  , clogTyCon :: TyCon -- ^ Recognised as 'CLog'; combined with 'mergeCLog'.
  , logTyCon :: TyCon -- ^ Recognised as 'Log'; combined with 'mergeLog'.
  , gcdTyCon :: TyCon -- ^ Recognised as 'GCD'; combined with 'mergeGCD'.
  , lcmTyCon :: TyCon -- ^ Recognised as 'LCM'; combined with 'mergeLCM'.
  }
-- | Normalise a type-level natural-number expression into an 'ExtraOp'.
--
-- * Type synonyms are expanded with 'coreView' first.
-- * Type variables become 'V' and numeric literals become 'I'.
-- * A binary application of one of the operators in 'ExtraDefs', or of
--   GHC's built-in exponentiation ('typeNatExpTyCon'), has both arguments
--   normalised left to right and is then combined with the matching
--   @merge*@ function.
-- * Any other type-constructor application has its arguments normalised
--   and reified, then is offered to 'matchFam'. If it reduces, the result
--   is normalised in turn. Otherwise it is kept as an opaque constant 'C'.
-- * Anything else is kept as an opaque constant 'C'.
--
-- The 'MaybeT' layer fails when 'mergeDiv', 'mergeMod', 'mergeFLog',
-- 'mergeCLog' or 'mergeLog' yields 'Nothing' for the normalised arguments,
-- or when normalising any sub-term fails.
normaliseNat :: ExtraDefs -> Type -> MaybeT TcPluginM ExtraOp
normaliseNat defs ty | Just ty1 <- coreView ty = normaliseNat defs ty1
normaliseNat _ (TyVarTy v) = pure (V v)
normaliseNat _ (LitTy (NumTyLit i)) = pure (I i)
normaliseNat defs (TyConApp tc [x,y])
  | tc == divTyCon defs   = partialOp mergeDiv
  | tc == modTyCon defs   = partialOp mergeMod
  | tc == flogTyCon defs  = partialOp mergeFLog
  | tc == clogTyCon defs  = partialOp mergeCLog
  | tc == logTyCon defs   = partialOp mergeLog
  | tc == gcdTyCon defs   = totalOp mergeGCD
  | tc == lcmTyCon defs   = totalOp mergeLCM
  | tc == typeNatExpTyCon = totalOp mergeExp
  -- If no guard matches, Haskell falls through to the general
  -- 'TyConApp' equation below.
  where
    -- Operators whose merge function returns a 'Maybe': a 'Nothing'
    -- result makes the whole normalisation fail.
    partialOp merge = do
      x' <- normaliseNat defs x
      y' <- normaliseNat defs y
      MaybeT (return (merge x' y'))
    -- Operators whose merge function always produces a result.
    totalOp merge = merge <$> normaliseNat defs x <*> normaliseNat defs y
normaliseNat defs (TyConApp tc tys) = do
  tys' <- mapM (fmap (reifyEOP defs) . normaliseNat defs) tys
  tyM <- lift (matchFam tc tys')
  case tyM of
    Just (_,ty) -> normaliseNat defs ty
    -- Irreducible application: keep it opaque, but built from the
    -- normalised arguments. Applications that differ only in reducible
    -- arguments (e.g. @F (Div 4 2)@ and @F 2@) then yield equal constants.
    Nothing -> return (C (CType (TyConApp tc tys')))
normaliseNat _ t = return (C (CType t))
-- | The verdict 'unifyExtra' reaches when comparing two normalised
-- expressions.
data UnifyResult
  = Win
    -- ^ Both sides are structurally equal ('==' on 'ExtraOp').
  | Lose
    -- ^ Both sides have the same free variables and no opaque constants,
    -- yet they are not structurally equal.
  | Draw
    -- ^ Neither 'Win' nor 'Lose' could be established.
-- | Pretty-print a verdict as its bare constructor name (used in traces).
instance Outputable UnifyResult where
  ppr result = text $ case result of
    Win -> "Win"
    Lose -> "Lose"
    Draw -> "Draw"
-- | Compare two normalised expressions arising from the constraint @ct@.
--
-- First emits a plugin trace showing the constraint and both sides, then
-- returns the pure verdict computed by 'unifyExtra''.
unifyExtra :: Ct -> ExtraOp -> ExtraOp -> TcPluginM UnifyResult
unifyExtra ct u v =
  unifyExtra' ct u v <$ tcPluginTrace "unifyExtra" (ppr ct $$ ppr u $$ ppr v)
-- | Pure core of 'unifyExtra'. The constraint argument is ignored.
--
-- * The sides mention different free variables: 'Draw'.
-- * The sides are structurally equal: 'Win'.
-- * Either side contains an opaque constant ('C'): 'Draw'.
--   NOTE(review): presumably because an opaque constant might still reduce
--   later, so inequality cannot be concluded — confirm.
-- * Otherwise (same variables, different structure, no constants): 'Lose'.
unifyExtra' :: Ct -> ExtraOp -> ExtraOp -> UnifyResult
unifyExtra' _ lhs rhs
  | not (eqFV lhs rhs) = Draw
  | lhs == rhs = Win
  | containsConstants lhs || containsConstants rhs = Draw
  | otherwise = Lose
-- | The set of type variables occurring in an expression.
--
-- Literals ('I') and opaque constants ('C') contribute nothing. Every
-- binary operator contributes the union of its two operands' variables.
fvOP :: ExtraOp -> UniqSet TyVar
fvOP op = case op of
    I _ -> emptyUniqSet
    V v -> unitUniqSet v
    C _ -> emptyUniqSet
    Div x y -> both x y
    Mod x y -> both x y
    FLog x y -> both x y
    CLog x y -> both x y
    Log x y -> both x y
    GCD x y -> both x y
    LCM x y -> both x y
    Exp x y -> both x y
  where
    both l r = fvOP l `unionUniqSets` fvOP r
-- | Whether two expressions mention exactly the same set of type variables.
eqFV :: ExtraOp -> ExtraOp -> Bool
eqFV u v = on (==) fvOP u v
-- | Convert an 'ExtraOp' back into a GHC 'Type'.
--
-- Literals and variables become the corresponding 'Type' forms, and opaque
-- constants are unwrapped unchanged. Each binary operator becomes an
-- application of its type constructor from 'ExtraDefs' (or
-- 'typeNatExpTyCon' for 'Exp') to its two reified operands.
reifyEOP :: ExtraDefs -> ExtraOp -> Type
reifyEOP defs = go
  where
    go (I i) = mkNumLitTy i
    go (V v) = mkTyVarTy v
    go (C (CType c)) = c
    go (Div x y) = bin (divTyCon defs) x y
    go (Mod x y) = bin (modTyCon defs) x y
    go (CLog x y) = bin (clogTyCon defs) x y
    go (FLog x y) = bin (flogTyCon defs) x y
    go (Log x y) = bin (logTyCon defs) x y
    go (GCD x y) = bin (gcdTyCon defs) x y
    go (LCM x y) = bin (lcmTyCon defs) x y
    go (Exp x y) = bin typeNatExpTyCon x y
    -- Apply a binary type constructor to two reified operands.
    bin tc x y = mkTyConApp tc [go x, go y]
-- | Whether an opaque constant ('C') occurs anywhere in the expression.
-- Operands of a binary operator are searched left first, with
-- short-circuiting.
containsConstants :: ExtraOp -> Bool
containsConstants op = case op of
    I _ -> False
    V _ -> False
    C _ -> True
    Div x y -> anyConst x y
    Mod x y -> anyConst x y
    FLog x y -> anyConst x y
    CLog x y -> anyConst x y
    Log x y -> anyConst x y
    GCD x y -> anyConst x y
    LCM x y -> anyConst x y
    Exp x y -> anyConst x y
  where
    anyConst l r = containsConstants l || containsConstants r