{-# LANGUAGE CPP #-}

module Polysemy.Plugin.Fundep.Unification where

import           Data.Bool
import           Data.Function (on)
import           Data.Set (Set)
import qualified Data.Set as S

#if __GLASGOW_HASKELL__ >= 900
import           GHC.Tc.Types.Constraint
#elif __GLASGOW_HASKELL__ >= 810
import           Constraint
#else
import           TcRnTypes
#endif

#if __GLASGOW_HASKELL__ >= 900
import           GHC.Core.Type
import           GHC.Core.Unify
import           GHC.Plugins (Outputable, ppr, parens, text, (<+>))
#else
import           Type
import           Unify
import           GhcPlugins (Outputable, ppr, parens, text, (<+>))
#endif

#if __GLASGOW_HASKELL__ >= 906
#define SUBST Subst
import           GHC.Core.TyCo.Subst (SUBST)
import           GHC.Core.TyCo.Compare (eqType, nonDetCmpType)
#else
#define SUBST TCvSubst
#endif


------------------------------------------------------------------------------
-- | The context in which we're attempting to solve a constraint.
data SolveContext
  = -- | In the context of a function definition. The @Set TyVar@ is all of the
    -- skolems that exist in the [G] constraints for this function.
    FunctionDef (Set TyVar)
    -- | In the context of running an interpreter. The 'Bool' corresponds to
    -- whether we are only trying to solve a single 'Member' constraint right
    -- now. If so, we *must* produce a unification wanted. The @Set TyVar@
    -- holds the type variables that 'unify' treats as skolems (never bound).
  | InterpreterUse Bool (Set TyVar)
  deriving (Eq, Ord)

-- Pretty-printing for debugging output: renders the constructor name followed
-- by its fields, wrapped in parentheses.
instance Outputable SolveContext where
  ppr (FunctionDef skolems) =
    parens $ text "FunctionDef" <+> ppr skolems
  ppr (InterpreterUse single_member skolems) =
    parens $ text "InterpreterUse" <+> ppr single_member <+> ppr skolems


------------------------------------------------------------------------------
-- | Depending on the context in which we're solving a constraint, we may or
-- may not want to force a unification of effects. For example, when defining
-- user code whose type is @Member (State Int) r => ...@, if we see
-- @get :: Sem r s@, we should unify @s ~ Int@.
--
-- Inside a function definition we always unify; while running an interpreter
-- we unify only when the 'InterpreterUse' flag is set.
mustUnify :: SolveContext -> Bool
mustUnify (FunctionDef _)      = True
mustUnify (InterpreterUse b _) = b


------------------------------------------------------------------------------
-- | Determine whether or not two effects are unifiable.
--
-- The @Set TyVar@ carried by the 'SolveContext' lists the type variables that
-- are treated as skolems: they are not allowed to unify with anything but
-- themselves. Every other type variable may be bound by the resulting
-- substitution. This properly handles all cases in which we are unifying
-- ambiguous [W] constraints (which are true type variables) against [G]
-- constraints.
--
-- Returns 'Nothing' when the two types are not definitely unifiable.
unify
    :: SolveContext
    -> Type  -- ^ wanted
    -> Type  -- ^ given
    -> Maybe SUBST
unify solve_ctx = tryUnifyUnivarsButNotSkolems skolems
  where
    -- Both constructors carry the skolem set; only the Bool differs.
    skolems :: Set TyVar
    skolems =
      case solve_ctx of
        InterpreterUse _ s -> s
        FunctionDef s      -> s

#if __GLASGOW_HASKELL__ >= 902
#define BINDME (const BindMe)
#define APART (const Apart)
#else
#define BINDME BindMe
#define APART Skolem
#endif

------------------------------------------------------------------------------
-- | Two-way unification of @goal@ and @inst@ via 'tcUnifyTysFG'.
--
-- Type variables that are members of @skolems@ are reported to GHC as not
-- bindable, so the unifier will refuse to instantiate them; every other type
-- variable may be bound. Only a definite 'Unifiable' result produces a
-- substitution; any other outcome (including "maybe apart") yields 'Nothing'.
tryUnifyUnivarsButNotSkolems :: Set TyVar -> Type -> Type -> Maybe SUBST
tryUnifyUnivarsButNotSkolems skolems goal inst =
  case tcUnifyTysFG
         -- BINDME / APART paper over the BindFun shape change in GHC 9.2.
         (bool BINDME APART . flip S.member skolems)
         [inst]
         [goal] of
    Unifiable subst -> pure subst
    _               -> Nothing

------------------------------------------------------------------------------
-- | A wrapper for two types that we want to say have been unified.
--
-- 'Eq' and 'Ord' are derived through 'OrdType', which lets unifications be
-- stored in a 'Set' to remember which ones have already been emitted.
data Unification = Unification
  { _unifyLHS :: OrdType
  , _unifyRHS :: OrdType
  }
  deriving (Eq, Ord)


------------------------------------------------------------------------------
-- | 'Type's don't have 'Eq' or 'Ord' instances by default, even though there
-- are functions in GHC that implement these operations. This newtype gives us
-- those instances.
newtype OrdType = OrdType
  { getOrdType :: Type
  }

-- Equality is GHC's structural type equality, 'eqType', on the wrapped types.
instance Eq OrdType where
  (==) = eqType `on` getOrdType

-- Ordering delegates to GHC's 'nonDetCmpType'. As its name says, the ordering
-- is not deterministic across compilations, so it is only suitable for
-- keying sets/maps, not for anything whose order is observable.
instance Ord OrdType where
  compare = nonDetCmpType `on` getOrdType


------------------------------------------------------------------------------
-- | Filter out the unifications we've already emitted, and then give back the
-- things we should put into the @S.Set Unification@, and the new constraints
-- we should emit.
--
-- Only membership in @old@ is checked: if the same 'Unification' appears
-- twice in the input list, both copies are kept. The two output lists are
-- index-aligned, in input order.
unzipNewWanteds
    :: S.Set Unification
    -> [(Unification, Ct)]
    -> ([Unification], [Ct])
unzipNewWanteds old = unzip . filter ((`S.notMember` old) . fst)