{-# LANGUAGE CPP #-}
module Polysemy.Plugin.Fundep.Unification where
import Data.Bool
import Data.Function (on)
import Data.Set (Set)
import qualified Data.Set as S
#if __GLASGOW_HASKELL__ >= 900
import GHC.Tc.Types.Constraint
#elif __GLASGOW_HASKELL__ >= 810
import Constraint
#else
import TcRnTypes
#endif
#if __GLASGOW_HASKELL__ >= 900
import GHC.Core.Type
import GHC.Core.Unify
import GHC.Plugins (Outputable, ppr, parens, text, (<+>))
#else
import Type
import Unify
import GhcPlugins (Outputable, ppr, parens, text, (<+>))
#endif
#if __GLASGOW_HASKELL__ >= 906
#define SUBST Subst
import GHC.Core.TyCo.Subst (SUBST)
import GHC.Core.TyCo.Compare (eqType, nonDetCmpType)
#else
#define SUBST TCvSubst
#endif
------------------------------------------------------------------------------
-- | The context in which a constraint is being solved. Both constructors
-- carry the set of type variables that 'unify' treats as skolems, i.e. that
-- are never bound during unification.
data SolveContext
  = -- | Solving inside a function definition. 'mustUnify' is always 'True'
    -- in this context.
    FunctionDef (Set TyVar)
  | -- | Solving at an interpreter use site. The 'Bool' is the value that
    -- 'mustUnify' reports for this context.
    InterpreterUse Bool (Set TyVar)
  deriving (Eq, Ord)
-- | Debug rendering: the constructor name followed by its fields, wrapped in
-- parentheses.
instance Outputable SolveContext where
  ppr (FunctionDef skolems) =
    parens $ text "FunctionDef" <+> ppr skolems
  ppr (InterpreterUse must skolems) =
    parens $ text "InterpreterUse" <+> ppr must <+> ppr skolems
------------------------------------------------------------------------------
-- | Whether a unification must be forced in the given context. This is
-- always 'True' inside a function definition. At an interpreter use site it
-- is the flag stored in 'InterpreterUse'.
mustUnify :: SolveContext -> Bool
mustUnify (FunctionDef _)      = True
mustUnify (InterpreterUse b _) = b
------------------------------------------------------------------------------
-- | Attempt to unify two types. Only type variables that are not skolems of
-- the given 'SolveContext' may be bound.
--
-- Returns 'Just' the substitution when the types are definitely unifiable,
-- and 'Nothing' otherwise.
unify
    :: SolveContext
    -> Type  -- ^ goal
    -> Type  -- ^ instance
    -> Maybe SUBST
unify solve_ctx = tryUnifyUnivarsButNotSkolems skolems
  where
    -- Both contexts carry their skolem set in the same position of interest.
    skolems :: Set TyVar
    skolems =
      case solve_ctx of
        InterpreterUse _ s -> s
        FunctionDef s      -> s
#if __GLASGOW_HASKELL__ >= 902
#define BINDME (const BindMe)
#define APART (const Apart)
#else
#define BINDME BindMe
#define APART Skolem
#endif

------------------------------------------------------------------------------
-- | Unify @inst@ (as the template) against @goal@ with 'tcUnifyTysFG'.
-- Type variables in the given set are never bound; every other type
-- variable may be bound.
--
-- Returns the substitution only for a definite 'Unifiable' result. Every
-- other outcome yields 'Nothing'.
tryUnifyUnivarsButNotSkolems :: Set TyVar -> Type -> Type -> Maybe SUBST
tryUnifyUnivarsButNotSkolems skolems goal inst =
  case tcUnifyTysFG bindFlag [inst] [goal] of
    Unifiable subst -> pure subst
    _               -> Nothing
  where
    -- Skolems stay rigid; everything else is free to be bound. From GHC 9.2
    -- the bind function also receives the target type, which is ignored here
    -- (hence the const-wrapped flags chosen by the defines above).
    bindFlag = bool BINDME APART . flip S.member skolems
------------------------------------------------------------------------------
-- | A pair of types recorded as having been unified. The sides are stored as
-- 'OrdType' so that unifications can be compared and kept in a 'Set'.
data Unification = Unification
  { _unifyLHS :: OrdType
  , _unifyRHS :: OrdType
  }
  deriving (Eq, Ord)
------------------------------------------------------------------------------
-- | Wrapper that provides 'Eq' and 'Ord' for 'Type', implemented with
-- 'eqType' and 'nonDetCmpType' respectively.
newtype OrdType = OrdType
  { getOrdType :: Type
  }
-- | Equality of the wrapped types as decided by 'eqType'.
instance Eq OrdType where
  (==) = eqType `on` getOrdType
-- | Ordering of the wrapped types via 'nonDetCmpType'. GHC labels this
-- comparison non-deterministic, so treat the resulting order as arbitrary.
-- It is fine for 'Set' membership, but do not let it drive any observable
-- output order.
instance Ord OrdType where
  compare = nonDetCmpType `on` getOrdType
------------------------------------------------------------------------------
-- | Drop every candidate whose 'Unification' is already in @old@. Then split
-- the survivors into their unifications and their constraints, preserving
-- relative order.
--
-- Only @old@ is consulted: duplicates within the input list itself are not
-- removed.
unzipNewWanteds
    :: S.Set Unification
    -> [(Unification, Ct)]
    -> ([Unification], [Ct])
unzipNewWanteds old = unzip . filter ((`S.notMember` old) . fst)