module Polysemy.Plugin.Fundep.Unification where

import           Data.Bool
import           Data.Function (on)
import qualified Data.Set as S
import           TcRnTypes
import           Type


------------------------------------------------------------------------------
-- | Where the constraint we are currently trying to solve comes from.
data SolveContext
  = FunctionDef
    -- ^ We are solving inside a function definition.
  | InterpreterUse Bool
    -- ^ We are solving while running an interpreter. The 'Bool' records
    -- whether exactly one 'Member' constraint is being solved right now; when
    -- it is 'True', we *must* produce a unification wanted.
  deriving (Eq, Ord, Show)


------------------------------------------------------------------------------
-- | Whether the given solving context obliges us to force a unification of
-- effects. For example, when defining user code whose type is
-- @Member (State Int) r => ...@, if we see @get :: Sem r s@, we should unify
-- @s ~ Int@.
--
-- Inside a function definition this is always required; while running an
-- interpreter it is required only when a single 'Member' constraint is being
-- solved (the flag carried by 'InterpreterUse').
mustUnify :: SolveContext -> Bool
mustUnify ctx =
  case ctx of
    FunctionDef             -> True
    InterpreterUse only_one -> only_one


------------------------------------------------------------------------------
-- | Determine whether or not two effects are unifiable. This is nuanced.
--
-- There are several cases:
--
-- 1. [W] ∀ e1. e1   [G] ∀ e2. e2
--    Always fails, because we never want to unify two effects if effect names
--    are polymorphic.
--
-- 2. [W] State s    [G] State Int
--    Always succeeds. It's safe to take our given as a fundep annotation.
--
-- 3. [W] State Int  [G] State s
--        (when the [G] is a given that comes from a type signature)
--
--    This should fail, because it means we wrote the type signature @Member
--    (State s) r => ...@, but are trying to use @s@ as an @Int@. Clearly
--    bogus!
--
-- 4. [W] State Int  [G] State s
--        (when the [G] was generated by running an interpreter)
--
--    Sometimes OK, but only if the [G] is the only thing we're trying to solve
--    right now. Consider the case:
--
--      runState 5 $ pure @(Sem (State Int ': r)) ()
--
--    Here we have  [G] forall a. Num a => State a  and  [W] State Int. Clearly
--    the typechecking should flow "backwards" here, out of the row and into
--    the type of 'runState'.
--
--    What happens if there are multiple [G]s in scope for the same @r@? Then
--    we'd emit multiple unification constraints for the same effect but with
--    different polymorphic variables, which would unify a bunch of effects
--    that shouldn't be unified!
--
-- Type applications are compared argument by argument, with the arguments
-- paired from the /right/. This mirrors how GHC decomposes @t1 t2 ~ s1 s2@
-- into @t1 ~ s1@ and @t2 ~ s2@.
--
-- When one side has more arguments than the other (e.g. @m Int@ vs.
-- @Either String Int@), its surplus leading arguments stay attached to its
-- head. That example therefore compares @m@ with @Either String@ and @Int@
-- with @Int@, rather than pairing @Int@ with @String@.
canUnifyRecursive
    :: SolveContext
    -> Type  -- ^ wanted
    -> Type  -- ^ given
    -> Bool
canUnifyRecursive solve_ctx = go True
  where
    -- It's only OK to solve a polymorphic "given" if we're in the context of
    -- an interpreter, because it's not really a given!
    poly_given_ok :: Bool
    poly_given_ok =
      case solve_ctx of
        InterpreterUse _ -> True
        FunctionDef      -> False

    -- On the first go around, we don't want to unify effects with tyvars, but
    -- we _do_ want to unify their arguments, thus 'is_first'.
    go :: Bool -> Type -> Type -> Bool
    go is_first wanted given =
      let (w0, ws0) = splitAppTys wanted
          (g0, gs0) = splitAppTys given
          -- Both sides are trimmed to the same number of trailing arguments,
          -- so 'zipWith' below never drops or misaligns anything.
          common    = min (length ws0) (length gs0)
          (w, ws)   = alignRight common w0 ws0
          (g, gs)   = alignRight common g0 gs0
          -- At the top level the effect heads must match exactly; deeper
          -- down a (possibly re-applied) head may unify with a tyvar.
          heads_ok  = bool (canUnify poly_given_ok) eqType is_first w g
          arg_ok wt gt = canUnify poly_given_ok wt gt || go False wt gt
       in heads_ok && and (zipWith arg_ok ws gs)

    -- Keep only the last @n@ arguments (requires @n <= length args@). Any
    -- surplus leading arguments are re-applied to the head, so @f a b c@
    -- with @n = 1@ becomes @(f a b, [c])@. With @n == length args@ this is
    -- the identity (@mkAppTys hd [] = hd@).
    alignRight :: Int -> Type -> [Type] -> (Type, [Type])
    alignRight n hd args =
      let (surplus, kept) = splitAt (length args - n) args
       in (mkAppTys hd surplus, kept)


------------------------------------------------------------------------------
-- | A non-recursive version of 'canUnifyRecursive'.
--
-- The two types are accepted in three cases:
--
-- * the wanted type is a type variable;
-- * the given type is a type variable and @poly_given_ok@ is set;
-- * the two types are equal according to 'eqType'.
canUnify :: Bool -> Type -> Type -> Bool
canUnify poly_given_ok wt gt
  | isTyVarTy wt                   = True
  | poly_given_ok && isTyVarTy gt  = True
  | otherwise                      = eqType wt gt


------------------------------------------------------------------------------
-- | A record of two types that we want to treat as unified with each other.
-- Both sides are stored as 'OrdType', which gives this type its derived
-- 'Eq' and 'Ord' instances (e.g. for use as a 'S.Set' element).
data Unification = Unification
  { _unifyLHS :: OrdType  -- ^ left-hand side of the unification
  , _unifyRHS :: OrdType  -- ^ right-hand side of the unification
  }
  deriving (Ord, Eq)


------------------------------------------------------------------------------
-- | GHC's 'Type' comes without 'Eq' or 'Ord' instances, even though GHC
-- exports functions implementing both operations. Wrapping a 'Type' in this
-- newtype supplies those instances.
newtype OrdType = OrdType
  { getOrdType :: Type
  }

-- | Equality of the wrapped types, decided by 'eqType'.
instance Eq OrdType where
  OrdType lhs == OrdType rhs = eqType lhs rhs

-- | Ordering of the wrapped types, decided by 'nonDetCmpType'.
-- NOTE(review): as the name suggests, this ordering is presumably not stable
-- across compiler runs. It is fine for set membership, but do not rely on
-- the resulting iteration order.
instance Ord OrdType where
  compare (OrdType lhs) (OrdType rhs) = nonDetCmpType lhs rhs


------------------------------------------------------------------------------
-- | Filter out the unifications we've already emitted, and then give back the
-- things we should put into the @S.Set Unification@, and the new constraints
-- we should emit.
--
-- A 'Unification' counts as already emitted if either:
--
-- * it is in @old@, or
-- * it appeared earlier in the same input list.
--
-- As a result, each new unification is returned, and its constraint emitted,
-- at most once per call. The relative order of the surviving entries is
-- preserved.
unzipNewWanteds
    :: S.Set Unification
    -> [(Unification, Ct)]
    -> ([Unification], [Ct])
unzipNewWanteds old = unzip . keepNew old
  where
    -- Walk the list, threading the set of unifications seen so far so that
    -- duplicates within this batch are dropped as well.
    keepNew :: S.Set Unification -> [(Unification, Ct)] -> [(Unification, Ct)]
    keepNew _ [] = []
    keepNew seen (entry@(u, _) : rest)
      | u `S.member` seen = keepNew seen rest
      | otherwise         = entry : keepNew (S.insert u seen) rest