-- | Generalization of type schemes.
--
-- 'generalize' converts a unification term into a 'GTerm', turning
-- suitably-scoped unification variables into quantified skolems.
-- 'instantiate' / 'instantiateWith' convert a 'GTerm' back into
-- unification terms, with fresh variables for the quantified parts.
{-# LANGUAGE UndecidableInstances, TemplateHaskell, FlexibleInstances #-}

module AST.Unify.Generalize
    ( generalize, instantiate
    , GTerm(..), _GMono, _GPoly, _GBody, KWitness(..)
    , instantiateWith, instantiateForAll
    , -- | Exports for @SPECIALIZE@ pragmas.
      instantiateH
    ) where

import Algebra.PartialOrd (PartialOrd(..))
import AST
import AST.Class.Unify (Unify(..), UVarOf, BindingDict(..))
import AST.Class.Traversable
import AST.Combinator.Flip
import AST.Recurse
import AST.TH.Internal.Instances (makeCommonInstances)
import AST.Unify.Constraints
import AST.Unify.Lookup (semiPruneLookup)
import AST.Unify.New
import AST.Unify.Occurs (occursError)
import AST.Unify.Term (UTerm(..), uBody)
import qualified Control.Lens as Lens
import Control.Lens.Operators
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Writer (WriterT(..), tell)
import Data.Constraint (withDict)
import Data.Monoid (All(..))
import Data.Proxy (Proxy(..))
import GHC.Generics (Generic)
import Prelude.Compat

-- | An efficient representation of a type scheme arising from
-- generalizing a unification term. Type subexpressions which are
-- completely monomorphic are tagged as such, to avoid redundant
-- instantiation and unification work.
data GTerm v ast
    = GMono (v ast) -- ^ Completely monomorphic term
    | GPoly (v ast)
        -- ^ Points to a quantified variable (instantiation will
        -- create fresh unification terms) (`AST.Unify.Term.USkolem`
        -- or `AST.Unify.Term.UResolved`)
    | GBody (ast # GTerm v) -- ^ Term with some polymorphic parts
    deriving Generic

-- Generates the prisms '_GMono', '_GPoly' and '_GBody' (exported above).
Lens.makePrisms ''GTerm
-- Derives the project's common instances for 'GTerm'
-- (see "AST.TH.Internal.Instances").
makeCommonInstances [''GTerm]

-- Views a 'GTerm' with its tree type @a@ fixed (via 'Flip') as a container
-- of its @v@-values. Each contained value is identified by a recursive
-- witness into @a@: 'KRecSelf' for a top-level 'GMono' / 'GPoly' leaf, and
-- 'KRecSub' for leaves nested inside a 'GBody'.
instance RNodes a => KNodes (Flip GTerm a) where
    -- A constraint @c@ must hold for @a@ and be 'Recursive', so that it
    -- can be propagated to nested node types (see 'kLiftConstraintH').
    type KNodesConstraint (Flip GTerm a) c = (c a, Recursive c)
    data KWitness (Flip GTerm a) n = E_Flip_GTerm (KRecWitness a n)
    {-# INLINE kLiftConstraint #-}
    -- At the root, @c a@ is already provided by 'KNodesConstraint'.
    kLiftConstraint (E_Flip_GTerm KRecSelf) = const id
    kLiftConstraint (E_Flip_GTerm (KRecSub c n)) = kLiftConstraintH c n

-- Lifts constraint @c@ to a node type @n@ reached through a 'KRecSub'
-- witness: first 'RNodes' and @c@ are lifted from @a@ to its direct child
-- @b@ (using the dictionaries obtained with 'recurse'), then the
-- @Flip GTerm b@ instance continues with the remaining witness @n@.
--
-- Note: @c@ names both the constraint type variable and the first
-- (witness) argument; they live in different namespaces.
kLiftConstraintH ::
    forall a c b n r.
    (RNodes a, KNodesConstraint (Flip GTerm a) c) =>
    KWitness a b -> KRecWitness b n -> Proxy c -> (c n => r) -> r
kLiftConstraintH c n =
    withDict (recurse (Proxy @(RNodes a))) $
    withDict (recurse (Proxy @(c a))) $
    kLiftConstraint c (Proxy @RNodes)
    ( kLiftConstraint c (Proxy @c)
        (kLiftConstraint (E_Flip_GTerm n))
    )

-- Maps over every @v@-value of the term. 'GMono' / 'GPoly' leaves are
-- mapped directly; for 'GBody' each child is mapped recursively, with the
-- witness passed to @f@ extended by 'KRecSub'.
instance Recursively KFunctor ast => KFunctor (Flip GTerm ast) where
    {-# INLINE mapK #-}
    mapK f =
        _Flip %~
        \case
            GMono x -> f (E_Flip_GTerm KRecSelf) x & GMono
            GPoly x -> f (E_Flip_GTerm KRecSelf) x & GPoly
            GBody x ->
                withDict (recursively (Proxy @(KFunctor ast))) $
                mapK
                ( \cw ->
                    kLiftConstraint cw (Proxy @(Recursively KFunctor)) $
                    Lens.from _Flip %~
                    mapK (f . (\(E_Flip_GTerm nw) -> E_Flip_GTerm (KRecSub cw nw)))
                ) x
                & GBody

-- Folds over every @v@-value of the term, using the same witness scheme as
-- the 'KFunctor' instance ('KRecSelf' at leaves, 'KRecSub' for nesting).
instance Recursively KFoldable ast => KFoldable (Flip GTerm ast) where
    {-# INLINE foldMapK #-}
    foldMapK f =
        \case
            GMono x -> f (E_Flip_GTerm KRecSelf) x
            GPoly x -> f (E_Flip_GTerm KRecSelf) x
            GBody x ->
                withDict (recursively (Proxy @(KFoldable ast))) $
                foldMapK
                ( \cw ->
                    kLiftConstraint cw (Proxy @(Recursively KFoldable)) $
                    foldMapK (f . (\(E_Flip_GTerm nw) -> E_Flip_GTerm (KRecSub cw nw))) .
                    (_Flip #)
                ) x
        -- Unwrap the 'Flip' before dispatching on the 'GTerm' constructor.
        . (^. _Flip)

-- Sequences the effects held in every @v@-value of the term, rebuilding
-- the same 'GTerm' shape and re-wrapping it in 'Flip'.
instance RTraversable ast => KTraversable (Flip GTerm ast) where
    {-# INLINE sequenceK #-}
    sequenceK (MkFlip fx) =
        case fx of
            GMono x -> runContainedK x <&> GMono
            GPoly x -> runContainedK x <&> GPoly
            GBody x ->
                withDict (recurse (Proxy @(RTraversable ast))) $
                -- KTraversable will be required when not implied by Recursively
                traverseK
                ( Proxy @RTraversable #>
                    Lens.from _Flip sequenceK
                ) x
                <&> GBody
        <&> MkFlip

-- | Generalize a unification term pointed by the given variable to a `GTerm`.
-- Unification variables that are scoped within the term
-- become universally quantified skolems.
--
-- The variable's binding is looked up with 'semiPruneLookup', then:
--
-- * 'UUnbound' whose scope constraints are `leq` the current
--   'scopeConstraints': the variable is rebound to a 'USkolem' (with
--   @generalizeConstraints@ applied) and returned as 'GPoly'.
-- * 'USkolem' passing the same scope check: returned as 'GPoly'.
-- * 'UTerm': children are generalized recursively. The result is 'GMono'
--   when every child is 'GMono', otherwise 'GBody'.
-- * 'UResolving' (the term refers back to itself): reported with
--   'occursError'.
-- * Anything else (including unbound variables and skolems failing the
--   scope check): 'GMono'.
generalize ::
    forall m t.
    Unify m t =>
    Tree (UVarOf m) t -> m (Tree (GTerm (UVarOf m)) t)
generalize v0 =
    do
        (v1, u) <- semiPruneLookup v0
        c <- scopeConstraints
        case u of
            UUnbound l | toScopeConstraints l `leq` c ->
                GPoly v1 <$
                -- We set the variable to a skolem,
                -- so additional unifications after generalization
                -- (for example hole resumptions where supported)
                -- cannot unify it with anything.
                bindVar binding v1 (USkolem (generalizeConstraints l))
            USkolem l | toScopeConstraints l `leq` c -> pure (GPoly v1)
            UTerm t ->
                withDict (unifyRecursive (Proxy @m) (Proxy @t)) $
                do
                    -- Mark the variable as in progress, so a cyclic reference
                    -- back to it reaches the 'UResolving' case below.
                    bindVar binding v1 (UResolving t)
                    r <- traverseK (Proxy @(Unify m) #> generalize) (t ^. uBody)
                    -- Restore the original term binding.
                    r <$ bindVar binding v1 (UTerm t)
                <&>
                \b ->
                -- A term whose children are all monomorphic collapses into a
                -- single 'GMono' leaf pointing at this variable.
                if foldMapK (Proxy @(Unify m) #> All . Lens.has _GMono) b ^. Lens._Wrapped
                then GMono v1
                else GBody b
            UResolving t -> GMono v1 <$ occursError v1 t
            _ -> pure (GMono v1)

-- | Instantiate a single quantified variable (the target of a 'GPoly').
--
-- If the variable is a 'USkolem', a fresh variable is created. Its binding
-- is @cons@ applied to the current 'scopeConstraints' combined (with '<>')
-- with the skolem's constraints. The quantified variable is then rebound to
-- 'UInstantiated', so later occurrences of it reuse the same fresh variable.
-- An action that rebinds the original 'USkolem' is recorded in the
-- 'WriterT' output for the caller to run later (see 'instantiateWith').
--
-- Any other binding state is treated as an internal error.
{-# INLINE instantiateForAll #-}
instantiateForAll ::
    Unify m t =>
    (TypeConstraintsOf t -> Tree (UTerm (UVarOf m)) t) ->
    Tree (UVarOf m) t ->
    WriterT [m ()] m (Tree (UVarOf m) t)
instantiateForAll cons x =
    lookupVar binding x & lift
    >>=
    \case
        USkolem l ->
            do
                tell [bindVar binding x (USkolem l)]
                r <- scopeConstraints <&> (<> l) >>= newVar binding . cons & lift
                UInstantiated r & bindVar binding x & lift
                pure r
        UInstantiated v -> pure v
        _ -> error "unexpected state at instantiate's forall"

-- TODO: Better name?
-- | Instantiate a 'GTerm', collecting skolem-restoring actions in the
-- 'WriterT' output. 'GMono' parts are reused as-is, 'GPoly' parts go
-- through 'instantiateForAll', and 'GBody' parts are rebuilt with 'newTerm'
-- after their children are instantiated.
{-# INLINE instantiateH #-}
instantiateH ::
    forall m t.
    Unify m t =>
    (forall n. TypeConstraintsOf n -> Tree (UTerm (UVarOf m)) n) ->
    Tree (GTerm (UVarOf m)) t ->
    WriterT [m ()] m (Tree (UVarOf m) t)
instantiateH _ (GMono x) = pure x
instantiateH cons (GPoly x) = instantiateForAll cons x
instantiateH cons (GBody x) =
    withDict (unifyRecursive (Proxy @m) (Proxy @t)) $
    traverseK (Proxy @(Unify m) #> instantiateH cons) x >>= lift . newTerm

-- | Instantiate a generalized type, then run @action@ while the quantified
-- variables are still bound to their instantiations.
--
-- @cons@ builds the initial binding of each fresh variable from its
-- constraints ('instantiate' passes 'UUnbound'). After @action@, every
-- quantified variable is rebound to its original skolem. Returns the
-- instantiated variable paired with @action@'s result.
--
-- NOTE(review): the restore step is sequenced after @action@ with '<*', so
-- if @action@ short-circuits in @m@ the skolems are not restored.
{-# INLINE instantiateWith #-}
instantiateWith ::
    forall m t a.
    Unify m t =>
    m a ->
    (forall n. TypeConstraintsOf n -> Tree (UTerm (UVarOf m)) n) ->
    Tree (GTerm (UVarOf m)) t ->
    m (Tree (UVarOf m) t, a)
instantiateWith action cons g =
    do
        (r, recover) <- instantiateH cons g & runWriterT
        action <* sequence_ recover <&> (r, )

-- | Instantiate a generalized type with fresh unification variables
-- for the quantified variables.
--
-- The fresh variables start as 'UUnbound', and the quantified variables
-- are restored to skolems immediately afterwards.
{-# INLINE instantiate #-}
instantiate ::
    Unify m t =>
    Tree (GTerm (UVarOf m)) t -> m (Tree (UVarOf m) t)
instantiate g = instantiateWith (pure ()) UUnbound g <&> (^. Lens._1)