-- | Hindley-Milner type inference with ergonomic blame assignment.
--
-- 'blame' is a type-error blame assignment algorithm for languages with Hindley-Milner type inference,
-- but __/without generalization of intermediate terms/__.
-- This means that it is not suitable for languages with let-generalization.
-- 'Hyper.Type.AST.Let.Let' is an example of a term that is not suitable for this algorithm.
--
-- With the contemporary knowledge that
-- ["Let Should Not Be Generalised"](https://www.microsoft.com/en-us/research/publication/let-should-not-be-generalised/),
-- as argued by luminaries such as Simon Peyton Jones,
-- optimistically this limitation shouldn't apply to new programming languages.
-- This blame assignment algorithm can also be used in a limited sense for existing languages,
-- which do have let-generalization, to provide better type errors
-- in specific definitions which don't happen to use generalizing terms.
--
-- The algorithm is pretty simple:
--
-- * Invoke all the 'inferBody' calls as 'Hyper.Infer.infer' normally would,
--   but with one important difference:
--   where 'inferBody' would normally get the actual inference results of its child nodes,
--   placeholders are generated in their place
-- * Globally sort all of the tree nodes according to a given node prioritization
--   (this prioritization would be custom for each language)
-- * According to the order of prioritization,
--   attempt to unify each infer-result with its placeholder using 'inferOfUnify'.
--   If a unification fails, roll back its state changes.
--   The nodes whose unification failed are the ones blamed for the type errors.
--
-- [Lamdu](https://github.com/lamdu/lamdu) uses this algorithm for its "insist type" feature,
-- which moves around the blame for type mismatches.
--
-- Note: If a similar algorithm already existed somewhere,
-- [I](https://github.com/yairchu/) would very much like to know!

{-# LANGUAGE FlexibleContexts, TemplateHaskell, FlexibleInstances, UndecidableInstances #-}

module Hyper.Infer.Blame
    ( blame
    , Blame(..)
    , BlameResult(..), _Good, _Mismatch
    , InferOf'
    ) where

import qualified Control.Lens as Lens
import           Control.Monad.Except (MonadError(..))
import           Data.List (sortOn)
import           Hyper
import           Hyper.Class.Infer
import           Hyper.Class.Traversable (ContainedH(..))
import           Hyper.Class.Unify (UnifyGen, UVarOf)
import           Hyper.Infer.Result
import           Hyper.Recurse
import           Hyper.Unify.New (newUnbound)
import           Hyper.Unify.Occurs (occursCheck)

import           Hyper.Internal.Prelude

-- | Class implementing some primitives needed by the 'blame' algorithm
--
-- The 'HTraversable' and 'HPointed' superclass constraints on @'InferOf' t@
-- are used by the algorithm to create a fresh unbound unification variable
-- for every component of a node's inference result (its placeholder), and to
-- traverse those components afterwards.
--
-- The 'blamableRecursive' method represents that 'Blame' applies to all recursive child nodes.
-- It replaces context for 'Blame' to avoid @UndecidableSuperClasses@.
class
    (Infer m t, RTraversable t, HTraversable (InferOf t), HPointed (InferOf t)) =>
    Blame m t where

    -- | Unify the types/values in infer results
    --
    -- 'blame' uses this to unify a node's placeholder result with its actual
    -- result. An error thrown here is caught by 'blame', and its state changes
    -- are expected to be rolled back by the monad.
    inferOfUnify ::
        Proxy t ->
        InferOf t # UVarOf m ->
        InferOf t # UVarOf m ->
        m ()

    -- | Check whether two infer results are the same
    --
    -- Used to decide whether a node's result is 'Good' or a 'Mismatch'.
    inferOfMatches ::
        Proxy t ->
        InferOf t # UVarOf m ->
        InferOf t # UVarOf m ->
        m Bool

    -- TODO: Putting documentation here causes duplication in the haddock documentation
    blamableRecursive ::
        Proxy m -> Proxy t -> Dict (HNodesConstraint t (Blame m))
    {-# INLINE blamableRecursive #-}
    default blamableRecursive ::
        HNodesConstraint t (Blame m) =>
        Proxy m -> Proxy t -> Dict (HNodesConstraint t (Blame m))
    -- With the default signature's context in scope, the dictionary is
    -- available directly.
    blamableRecursive _ _ = Dict

-- The recursive-constraint evidence for 'Blame' comes from the class's own
-- 'blamableRecursive' method, rather than from a superclass (which would
-- require @UndecidableSuperClasses@).
instance Recursive (Blame m) where
    recurse = blamableRecursive (Proxy @m) . proxyArgument

-- | A type synonym to help 'BlameResult' be more succinct
--
-- @InferOf' e v@ is the 'InferOf' of the hyper-type carried by @e@
-- (extracted with 'GetHyperType'), with its nodes in @v@.
type InferOf' e v = InferOf (GetHyperType e) # v

-- | Prepare a child node for blame assignment.
--
-- Fresh unbound unification variables are created for every component of the
-- node's 'InferOf'. These form the node's placeholder result, which 'prepare'
-- records alongside the node's actual inference result.
prepareH ::
    forall m exp a.
    Blame m exp =>
    Ann a # exp ->
    m (Ann (a :*: InferResult (UVarOf m) :*: InferResult (UVarOf m)) # exp)
prepareH t =
    withDict (inferContext (Proxy @m) (Proxy @exp)) $
    -- One fresh unbound variable per 'InferOf' component: the placeholder.
    hpure (Proxy @(UnifyGen m) #> MkContainedH newUnbound)
    & hsequence
    >>= (`prepare` t)

-- | Run 'inferBody' on every node of a term without unifying children with
-- their parents, annotating each node with two inference results:
--
-- * the first is the placeholder @resFromPosition@ supplied for this node
--   (for children, the fresh variables created by 'prepareH');
-- * the second is the result actually produced by this node's 'inferBody'.
--
-- Each child's placeholder, rather than its actual result, is what the
-- parent's 'inferBody' receives for that child.
prepare ::
    forall m exp a.
    Blame m exp =>
    InferOf exp # UVarOf m ->
    Ann a # exp ->
    m (Ann (a :*: InferResult (UVarOf m) :*: InferResult (UVarOf m)) # exp)
prepare resFromPosition (Ann a x) =
    withDict (recurse (Proxy @(Blame m exp))) $
    hmap
    ( Proxy @(Blame m) #>
        InferChild .
        -- Hand the child's placeholder (first of the pair) to 'inferBody'.
        fmap (\t -> InferredChild t (t ^. hAnn . Lens._2 . Lens._1 . _InferResult)) .
        prepareH
    ) x
    & inferBody
    <&>
    \(xI, r) ->
    Ann (a :*: InferResult resFromPosition :*: InferResult r) xI

-- | Attempt to unify a node's placeholder result with its actual result.
--
-- The unification is followed by an occurs check on every unification
-- variable of the first result. Any error raised by either step is caught
-- and discarded. A failed attempt therefore leaves no trace, provided the
-- monad rolls back state changes when an error is caught.
--
-- The witness argument is not inspected. Presumably it is there to fix the
-- type @exp@ at the call site, since the 'InferOf' arguments alone may not
-- determine it.
tryUnify ::
    forall err m top exp.
    (MonadError err m, Blame m exp) =>
    HWitness top exp ->
    InferOf exp # UVarOf m ->
    InferOf exp # UVarOf m ->
    m ()
tryUnify _ i0 i1 =
    withDict (inferContext (Proxy @m) (Proxy @exp)) $
    do
        inferOfUnify (Proxy @exp) i0 i1
        htraverse_ (Proxy @(UnifyGen m) #> occursCheck) i0
    -- Deliberate best effort: a failed unification only means this node
    -- will be reported as a mismatch.
    & (`catchError` const (pure ()))

-- | The blame verdict for a single node.
data BlameResult v e
    = Good (InferOf' e v)
    -- ^ The node's placeholder and actual inference results match.
    -- Holds the placeholder result.
    | Mismatch (InferOf' e v, InferOf' e v)
    -- ^ The node's results do not match: @(placeholder, actual)@.
    -- This node is blamed for a type error.
    deriving Generic
makePrisms ''BlameResult
makeCommonInstances [''BlameResult]

-- | Replace each node's pair of inference results with a 'BlameResult'.
--
-- A node is 'Good' (holding the first, placeholder result) when
-- 'inferOfMatches' reports that its two results are the same. Otherwise it
-- is a 'Mismatch' holding both results. The term is processed recursively.
finalize ::
    forall a m exp.
    Blame m exp =>
    Ann (a :*: InferResult (UVarOf m) :*: InferResult (UVarOf m)) # exp ->
    m (Ann (a :*: BlameResult (UVarOf m)) # exp)
finalize (Ann (a :*: InferResult i0 :*: InferResult i1) x) =
    withDict (recurse (Proxy @(Blame m exp))) $
    do
        match <- inferOfMatches (Proxy @exp) i0 i1
        let result
                | match = Good i0
                | otherwise = Mismatch (i0, i1)
        htraverse (Proxy @(Blame m) #> finalize) x
            <&> Ann (a :*: result)

-- | Perform Hindley-Milner type inference with prioritised blame for type errors,
-- given a prioritisation for the different nodes.
--
-- The purpose of the prioritisation is to place the errors in nodes where
-- the resulting errors will be easier to understand.
--
-- The expected `MonadError` behavior is that catching errors rolls back their state changes
-- (i.e. @StateT s (Either e)@ is suitable but @EitherT e (State s)@ is not).
--
-- Takes the expected top-level type of the term, to support recursive definitions,
-- where the term's own type may already be in scope in the inference monad.
blame ::
    forall priority err m exp a.
    ( Ord priority
    , MonadError err m
    , Blame m exp
    ) =>
    (forall n. a # n -> priority) ->
    InferOf exp # UVarOf m ->
    Ann a # exp ->
    m (Ann (a :*: BlameResult (UVarOf m)) # exp)
blame :: (forall (n :: HyperType). (a # n) -> priority)
-> (InferOf exp # UVarOf m)
-> (Ann a # exp)
-> m (Ann (a :*: BlameResult (UVarOf m)) # exp)
blame forall (n :: HyperType). (a # n) -> priority
order InferOf exp # UVarOf m
topLevelType Ann a # exp
e =
    do
        Ann (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m)))
# exp
p <- (InferOf exp # UVarOf m)
-> (Ann a # exp)
-> m (Ann
        (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m)))
      # exp)
forall (m :: * -> *) (exp :: HyperType) (a :: HyperType).
Blame m exp =>
(InferOf exp # UVarOf m)
-> (Ann a # exp)
-> m (Ann
        (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m)))
      # exp)
prepare InferOf exp # UVarOf m
topLevelType Ann a # exp
e
        (forall (n :: HyperType).
 HWitness (HFlip Ann exp) n
 -> ((a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m)))
     # n)
 -> [(priority, m ())])
-> (HFlip Ann exp
    # (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m))))
-> [(priority, m ())]
forall (h :: HyperType) a (p :: HyperType).
(HFoldable h, Monoid a) =>
(forall (n :: HyperType). HWitness h n -> (p # n) -> a)
-> (h # p) -> a
hfoldMap
            ( Proxy (Blame m)
forall k (t :: k). Proxy t
Proxy @(Blame m) Proxy (Blame m)
-> (Blame m n =>
    HWitness (HFlip Ann exp) n
    -> ((a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m)))
        # n)
    -> [(priority, m ())])
-> HWitness (HFlip Ann exp) n
-> ((a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m)))
    # n)
-> [(priority, m ())]
forall (h :: HyperType) (c :: HyperType -> Constraint)
       (n :: HyperType) r.
(HNodes h, HNodesConstraint h c) =>
Proxy c -> (c n => HWitness h n -> r) -> HWitness h n -> r
#*#
                \HWitness (HFlip Ann exp) n
w (a ('AHyperType n)
a :*: InferResult InferOf (GetHyperType ('AHyperType n)) # UVarOf m
i0 :*: InferResult InferOf (GetHyperType ('AHyperType n)) # UVarOf m
i1) ->
                [(a ('AHyperType n) -> priority
forall (n :: HyperType). (a # n) -> priority
order a ('AHyperType n)
a, HWitness (HFlip Ann exp) n
-> (InferOf n # UVarOf m) -> (InferOf n # UVarOf m) -> m ()
forall err (m :: * -> *) (top :: HyperType) (exp :: HyperType).
(MonadError err m, Blame m exp) =>
HWitness top exp
-> (InferOf exp # UVarOf m) -> (InferOf exp # UVarOf m) -> m ()
tryUnify HWitness (HFlip Ann exp) n
w InferOf n # UVarOf m
InferOf (GetHyperType ('AHyperType n)) # UVarOf m
i0 InferOf n # UVarOf m
InferOf (GetHyperType ('AHyperType n)) # UVarOf m
i1)]
            ) (Tagged
  (Ann (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m)))
   # exp)
  (Identity
     (Ann (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m)))
      # exp))
-> Tagged
     (HFlip Ann exp
      # (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m))))
     (Identity
        (HFlip Ann exp
         # (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m)))))
forall (f0 :: HyperType -> HyperType) (x0 :: HyperType)
       (k0 :: HyperType) (f1 :: HyperType -> HyperType) (x1 :: HyperType)
       (k1 :: HyperType).
Iso (HFlip f0 x0 # k0) (HFlip f1 x1 # k1) (f0 k0 # x0) (f1 k1 # x1)
_HFlip (Tagged
   (Ann (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m)))
    # exp)
   (Identity
      (Ann (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m)))
       # exp))
 -> Tagged
      (HFlip Ann exp
       # (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m))))
      (Identity
         (HFlip Ann exp
          # (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m))))))
-> (Ann (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m)))
    # exp)
-> HFlip Ann exp
   # (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m)))
forall t b. AReview t b -> b -> t
# Ann (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m)))
# exp
p)
            [(priority, m ())]
-> ([(priority, m ())] -> [(priority, m ())]) -> [(priority, m ())]
forall a b. a -> (a -> b) -> b
& ((priority, m ()) -> priority)
-> [(priority, m ())] -> [(priority, m ())]
forall b a. Ord b => (a -> b) -> [a] -> [a]
sortOn (priority, m ()) -> priority
forall a b. (a, b) -> a
fst [(priority, m ())] -> ([(priority, m ())] -> m ()) -> m ()
forall a b. a -> (a -> b) -> b
& ((priority, m ()) -> m ()) -> [(priority, m ())] -> m ()
forall (t :: * -> *) (f :: * -> *) a b.
(Foldable t, Applicative f) =>
(a -> f b) -> t a -> f ()
traverse_ (priority, m ()) -> m ()
forall a b. (a, b) -> b
snd
        (Ann (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m)))
 # exp)
-> m (Ann (a :*: BlameResult (UVarOf m)) # exp)
forall (a :: HyperType) (m :: * -> *) (exp :: HyperType).
Blame m exp =>
(Ann (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m)))
 # exp)
-> m (Ann (a :*: BlameResult (UVarOf m)) # exp)
finalize Ann (a :*: (InferResult (UVarOf m) :*: InferResult (UVarOf m)))
# exp
p