{-# LANGUAGE CPP                       #-}

module Polysemy.Plugin.Fundep.Unification where

import           Data.Bool
import           Data.Function (on)
import qualified Data.Set as S
#if __GLASGOW_HASKELL__ >= 810
import           Constraint
#else
import           TcRnTypes
#endif

import           Type


------------------------------------------------------------------------------
-- | The context in which we're attempting to solve a constraint.
--
-- The constructor order matters for the derived 'Ord' instance:
-- 'FunctionDef' sorts before every 'InterpreterUse'.
data SolveContext
  = -- | In the context of a function definition.
    FunctionDef
    -- | In the context of running an interpreter. The 'Bool' records whether
    -- we are only trying to solve a single 'Member' constraint right now. If
    -- so, we *must* produce a unification wanted.
  | InterpreterUse Bool
  deriving (Eq, Ord, Show)


------------------------------------------------------------------------------
-- | Depending on the context in which we're solving a constraint, we may or
-- may not want to force a unification of effects. For example, when defining
-- user code whose type is @Member (State Int) r => ...@, if we see @get :: Sem
-- r s@, we should unify @s ~ Int@.
--
-- A 'FunctionDef' context always forces unification; an 'InterpreterUse'
-- context forces it exactly when its 'Bool' flag is set.
mustUnify :: SolveContext -> Bool
mustUnify FunctionDef        = True
mustUnify (InterpreterUse b) = b


------------------------------------------------------------------------------
-- | Determine whether or not two effects are unifiable. This is nuanced.
--
-- There are several cases:
--
-- 1. [W] ∀ e1. e1   [G] ∀ e2. e2
--    Always fails, because we never want to unify two effects if effect names
--    are polymorphic.
--
-- 2. [W] State s    [G] State Int
--    Always succeeds. It's safe to take our given as a fundep annotation.
--
-- 3. [W] State Int  [G] State s
--        (when the [G] is a given that comes from a type signature)
--
--    This should fail, because it means we wrote the type signature @Member
--    (State s) r => ...@, but are trying to use @s@ as an @Int@. Clearly
--    bogus!
--
-- 4. [W] State Int  [G] State s
--        (when the [G] was generated by running an interpreter)
--
--    Sometimes OK, but only if the [G] is the only thing we're trying to solve
--    right now. Consider the case:
--
--      runState 5 $ pure @(Sem (State Int ': r)) ()
--
--    Here we have  [G] forall a. Num a => State a  and  [W] State Int. Clearly
--    the typechecking should flow "backwards" here, out of the row and into
--    the type of 'runState'.
--
--    What happens if there are multiple [G]s in scope for the same @r@? Then
--    we'd emit multiple unification constraints for the same effect but with
--    different polymorphic variables, which would unify a bunch of effects
--    that shouldn't be!
canUnifyRecursive
    :: SolveContext
    -> Type  -- ^ wanted
    -> Type  -- ^ given
    -> Bool
canUnifyRecursive solve_ctx = go True
  where
    -- It's only OK to solve a polymorphic "given" if we're in the context of
    -- an interpreter, because it's not really a given!
    poly_given_ok :: Bool
    poly_given_ok =
      case solve_ctx of
        InterpreterUse _ -> True
        FunctionDef      -> False

    -- On the first go around, we don't want to unify effects with tyvars, but
    -- we _do_ want to unify their arguments, thus 'is_first'.
    go :: Bool -> Type -> Type -> Bool
    go is_first wanted given =
      let -- Split each side into its head and the arguments applied to it.
          (w, ws) = splitAppTys wanted
          (g, gs) = splitAppTys given

          -- At the top level the heads must be equal according to 'eqType';
          -- on nested calls the laxer 'canUnify' check is used instead.
          heads_ok = bool (canUnify poly_given_ok) eqType is_first w g

          -- An argument pair is fine if it unifies directly, or if it does
          -- so structurally after splitting it further.
          arg_ok (wt, gt) = canUnify poly_given_ok wt gt || go False wt gt

          -- Arguments are compared pairwise. 'zip' stops at the shorter
          -- list, so surplus arguments on either side are not inspected.
       in all arg_ok (zip ws gs) && heads_ok


------------------------------------------------------------------------------
-- | A non-recursive version of 'canUnifyRecursive'.
--
-- The two types are considered unifiable when any of the following holds:
--
-- * the wanted type is a type variable;
--
-- * the given type is a type variable and @poly_given_ok@ is set;
--
-- * the two types are equal according to 'eqType'.
canUnify
    :: Bool  -- ^ whether a given that is a bare type variable is acceptable
    -> Type  -- ^ wanted
    -> Type  -- ^ given
    -> Bool
canUnify poly_given_ok wt gt =
  or [ isTyVarTy wt
     , isTyVarTy gt && poly_given_ok
     , eqType wt gt
     ]


------------------------------------------------------------------------------
-- | A wrapper for two types that we want to say have been unified.
--
-- Both sides are stored as 'OrdType' so that the derived 'Eq' and 'Ord'
-- instances are available, which in turn lets values of this type be kept in
-- a set.
data Unification = Unification
  { _unifyLHS :: OrdType  -- ^ left-hand side of the unification
  , _unifyRHS :: OrdType  -- ^ right-hand side of the unification
  }
  deriving (Eq, Ord)


------------------------------------------------------------------------------
-- | 'Type's don't have 'Eq' or 'Ord' instances by default, even though there
-- are functions in GHC that implement these operations. This newtype gives us
-- those instances.
newtype OrdType = OrdType
  { getOrdType :: Type  -- ^ unwrap the underlying 'Type'
  }

-- | Equality of the wrapped 'Type's, as decided by GHC's 'eqType'.
instance Eq OrdType where
  (==) = eqType `on` getOrdType

-- | Ordering of the wrapped 'Type's, as decided by GHC's 'nonDetCmpType'.
--
-- NOTE(review): as the name of 'nonDetCmpType' warns, the resulting order is
-- presumably not stable across compiler runs; it should only be relied on for
-- in-memory containers, never for anything observable. Confirm against the
-- GHC API documentation.
instance Ord OrdType where
  compare = nonDetCmpType `on` getOrdType


------------------------------------------------------------------------------
-- | Filter out the unifications we've already emitted, and then give back the
-- things we should put into the @S.Set Unification@, and the new constraints
-- we should emit.
--
-- Only membership in the supplied set is checked: pairs whose 'Unification'
-- is already in it are dropped, everything else is kept in its original
-- order. Duplicates within the input list itself are not removed.
unzipNewWanteds
    :: S.Set Unification      -- ^ unifications that were already emitted
    -> [(Unification, Ct)]    -- ^ candidate unifications with their constraints
    -> ([Unification], [Ct])  -- ^ the not-yet-emitted candidates, unzipped
unzipNewWanteds old = unzip . filter ((`S.notMember` old) . fst)