{-# LANGUAGE TemplateHaskell, FlexibleInstances, FlexibleContexts #-}

module Hyper.Class.Infer
    ( InferOf
    , Infer(..)
    , InferChild(..), _InferChild
    , InferredChild(..), inType, inRep
    ) where

import qualified Control.Lens as Lens
import           GHC.Generics
import           Hyper
import           Hyper.Class.Unify
import           Hyper.Recurse

import           Hyper.Internal.Prelude

-- | @InferOf e@ is the inference result of @e@.
--
-- Most commonly it is an inferred type, declared with
--
-- > type instance InferOf MyTerm = ANode MyType
--
-- But it may also be other things, for example:
--
-- * An inferred value (for types inside terms)
-- * An inferred type together with a scope
--
-- This is an open type family: each term type supplies its own @type instance@.
-- The result is itself a 'HyperType', so it is applied to a node type;
-- 'inferBody' produces it with unification-variable nodes (@InferOf t # UVarOf m@).
type family InferOf (t :: HyperType) :: HyperType

-- | A 'HyperType' containing an inferred child node.
--
-- Type parameters:
--
-- * @v@ - the node type used inside the inference result
--   (running an 'InferChild' action yields one with @v ~ UVarOf m@)
-- * @h@ - the 'HyperType' wrapping the inferred child node
-- * @t@ - the child node's type, whose 'GetHyperType' selects its 'InferOf'
data InferredChild v h t = InferredChild
    { _inRep :: !(h t)
        -- ^ Inferred node.
        --
        -- An 'inferBody' implementation needs to place this value in the corresponding child node of the inferred term body.
    , _inType :: !(InferOf (GetHyperType t) # v)
        -- ^ The inference result for the child node.
        --
        -- An 'inferBody' implementation may use it to perform unifications.
    }

-- Generates the 'inRep' and 'inType' lenses for the fields above.
makeLenses ''InferredChild

-- | A 'HyperType' containing an inference action.
--
-- Running the action in @m@ yields the child's 'InferredChild', whose
-- inference result uses @UVarOf m@ as its node type.
--
-- The caller may modify the scope before invoking the action via
-- 'Hyper.Class.Infer.Env.localScopeType' or 'Hyper.Infer.ScopeLevel.localLevel'
newtype InferChild m h t =
    InferChild { inferChild :: m (InferredChild (UVarOf m) h t) }

-- Generates the '_InferChild' prism (an isomorphism for this newtype).
makePrisms ''InferChild

-- | @Infer m t@ enables 'Hyper.Infer.infer' to perform type-inference for @t@ in the 'Monad' @m@.
--
-- The 'inferContext' method represents the following constraints on @t@:
--
-- * @HNodesConstraint (InferOf t) (UnifyGen m)@ - The child nodes of the inference result can unify in the @m@ 'Monad'
-- * @HNodesConstraint t (Infer m)@ - @Infer m@ is also available for child nodes
--
-- It stands in for a superclass context on the 'Infer' class, to avoid @UndecidableSuperClasses@.
--
-- Instances usually don't need to implement this method as the default implementation works for them,
-- but infinitely polymorphic trees such as 'Hyper.Type.AST.NamelessScope.Scope' do need to implement the method,
-- because the required context is infinite.
--
-- 'inferBody' also has a default implementation for 'Generic1' term types
-- whose 'InferOf' equals that of their 'Rep1' representation.
class (Monad m, HFunctor t) => Infer m t where
    -- | Infer the body of an expression given the inference actions for its child nodes.
    inferBody ::
        t # InferChild m h ->
        m (t # h, InferOf t # UVarOf m)
    -- Generic default: convert the body to its 'Rep1' form with 'from1',
    -- infer that, and convert the resulting body back with 'to1'.
    -- The inference result needs no conversion because
    -- @InferOf t ~ InferOf (Rep1 t)@.
    default inferBody ::
        (Generic1 t, Infer m (Rep1 t), InferOf t ~ InferOf (Rep1 t)) =>
        t # InferChild m h ->
        m (t # h, InferOf t # UVarOf m)
    inferBody = fmap (Lens._1 %~ to1) . inferBody . from1

    -- TODO: Putting documentation here causes duplication in the haddock documentation
    inferContext ::
        proxy0 m ->
        proxy1 t ->
        Dict (HNodesConstraint t (Infer m), HNodesConstraint (InferOf t) (UnifyGen m))
    {-# INLINE inferContext #-}
    -- Default: the constraints are required of the instance itself, so the
    -- dictionary can be built directly.
    default inferContext ::
        (HNodesConstraint t (Infer m), HNodesConstraint (InferOf t) (UnifyGen m)) =>
        proxy0 m ->
        proxy1 t ->
        Dict (HNodesConstraint t (Infer m), HNodesConstraint (InferOf t) (UnifyGen m))
    inferContext _ _ = Dict

-- 'Infer' is recursive: a node's 'inferContext' provides the 'Infer'
-- constraints for all of its child nodes.
instance Recursive (Infer m) where
    {-# INLINE recurse #-}
    recurse p =
        -- 'inferContext' also provides the 'UnifyGen' constraints of the
        -- inference result; only the children's 'Infer' constraints are
        -- needed to build the resulting 'Dict'.
        withDict (inferContext (Proxy @m) (proxyArgument p)) Dict

-- A generic sum infers to the result of its left alternative; the instance
-- below requires both alternatives to agree (@InferOf a ~ InferOf b@).
type instance InferOf (a :+: _) = InferOf a

instance (InferOf a ~ InferOf b, Infer m a, Infer m b) => Infer m (a :+: b) where
    {-# INLINE inferBody #-}
    -- Infer whichever alternative is present and re-wrap the inferred body
    -- in the same constructor; the inference result passes through unchanged.
    inferBody (L1 x) = inferBody x <&> Lens._1 %~ L1
    inferBody (R1 x) = inferBody x <&> Lens._1 %~ R1

    {-# INLINE inferContext #-}
    -- The sum's context is the union of both alternatives' contexts.
    inferContext p _ =
        withDict (inferContext p (Proxy @a)) $
        withDict (inferContext p (Proxy @b)) Dict

-- Generic metadata wrappers are transparent for inference.
type instance InferOf (M1 _ _ h) = InferOf h

instance Infer m h => Infer m (M1 i c h) where
    {-# INLINE inferBody #-}
    -- Infer the wrapped body, then re-wrap the inferred body with 'M1'.
    inferBody (M1 x) = inferBody x <&> Lens._1 %~ M1

    {-# INLINE inferContext #-}
    -- The context is exactly that of the wrapped node type.
    inferContext p _ = withDict (inferContext p (Proxy @h)) Dict

-- Generic 'Rec1' wrappers are transparent for inference.
type instance InferOf (Rec1 h) = InferOf h

instance Infer m h => Infer m (Rec1 h) where
    {-# INLINE inferBody #-}
    -- Infer the wrapped body, then re-wrap the inferred body with 'Rec1'.
    inferBody (Rec1 x) = inferBody x <&> Lens._1 %~ Rec1

    {-# INLINE inferContext #-}
    -- The context is exactly that of the wrapped node type.
    inferContext p _ = withDict (inferContext p (Proxy @h)) Dict