{-# LANGUAGE TemplateHaskell, FlexibleInstances, FlexibleContexts #-} {-# LANGUAGE DerivingVia, UndecidableInstances, GeneralizedNewtypeDeriving #-} module TypeLang where import AST import AST.Class.Has import AST.Class.Unify import AST.Infer import AST.Recurse import AST.Term.FuncType import AST.Term.NamelessScope import AST.Term.Nominal import AST.Term.Row import AST.Term.Scheme import AST.Unify import AST.Unify.Binding import AST.Unify.QuantifiedVar import AST.Unify.Term import Algebra.PartialOrd import Control.Applicative import Control.Lens (ALens') import qualified Control.Lens as Lens import Control.Lens.Operators import Control.Monad.Reader (MonadReader) import Control.Monad.ST.Class (MonadST(..)) import Data.STRef import Data.Set (Set) import Data.String (IsString) import Generic.Data import Generics.Constraints (Constraints, makeDerivings) import GHC.Generics (Generic) import Text.PrettyPrint ((<+>)) import qualified Text.PrettyPrint as Pretty import Text.PrettyPrint.HughesPJClass (Pretty(..), maybeParens) import Prelude newtype Name = Name String deriving stock Show deriving newtype (Eq, Ord, IsString) data Typ k = TInt | TFun (FuncType Typ k) | TRec (k # Row) | TVar Name | TNom (NominalInst Name Types k) deriving Generic data Row k = REmpty | RExtend (RowExtend Name Typ Row k) | RVar Name deriving Generic data RConstraints = RowConstraints { _rForbiddenFields :: Set Name , _rScope :: ScopeLevel } deriving stock (Eq, Show, Generic) deriving (Semigroup, Monoid) via Generically RConstraints data Types k = Types { _tTyp :: k # Typ , _tRow :: k # Row } deriving Generic data TypeError k = TypError (UnifyError Typ k) | RowError (UnifyError Row k) | QVarNotInScope Name deriving Generic Lens.makePrisms ''Typ Lens.makePrisms ''Row Lens.makePrisms ''TypeError Lens.makeLenses ''RConstraints Lens.makeLenses ''Types makeZipMatch ''Types makeZipMatch ''Typ makeZipMatch ''Row makeKTraversableApplyAndBases ''Types makeKTraversableAndBases ''Typ makeKTraversableAndBases ''Row makeDerivings [''Eq, ''Ord, ''Show] [''Typ, ''Row, ''Types, ''TypeError] makeKHasPlain [''Typ, ''Row] type instance NomVarTypes Typ = Types instance HasNominalInst Name Typ where nominalInst = _TNom instance Pretty Name where pPrint (Name x) = Pretty.text x instance Pretty RConstraints where pPrintPrec _ p (RowConstraints f _) = Pretty.text "Forbidden fields:" <+> pPrint (f ^.. Lens.folded) & maybeParens (p > 10) instance Constraints (Types k) Pretty => Pretty (Types k) where pPrintPrec lvl p (Types typ row) = pPrintPrec lvl p typ <+> pPrintPrec lvl p row instance Constraints (TypeError k) Pretty => Pretty (TypeError k) where pPrintPrec lvl p (TypError x) = pPrintPrec lvl p x pPrintPrec lvl p (RowError x) = pPrintPrec lvl p x pPrintPrec _ _ (QVarNotInScope x) = Pretty.text "quantified type variable not in scope" <+> pPrint x instance Constraints (Typ k) Pretty => Pretty (Typ k) where pPrintPrec _ _ TInt = Pretty.text "Int" pPrintPrec lvl p (TFun x) = pPrintPrec lvl p x pPrintPrec lvl p (TRec x) = pPrintPrec lvl p x pPrintPrec _ _ (TVar s) = pPrint s pPrintPrec _ _ (TNom n) = pPrint n instance Constraints (Types k) Pretty => Pretty (Row k) where pPrintPrec _ _ REmpty = Pretty.text "{}" pPrintPrec lvl p (RExtend (RowExtend k v r)) = pPrintPrec lvl 20 k <+> Pretty.text ":" <+> pPrintPrec lvl 2 v <+> Pretty.text ":*:" <+> pPrintPrec lvl 1 r & maybeParens (p > 1) pPrintPrec _ _ (RVar s) = pPrint s instance HasChild Types Typ where getChild = tTyp instance HasChild Types Row where getChild = tRow instance PartialOrd RConstraints where RowConstraints f0 s0 `leq` RowConstraints f1 s1 = f0 `leq` f1 && s0 `leq` s1 instance TypeConstraints RConstraints where generalizeConstraints = rScope .~ mempty toScopeConstraints = rForbiddenFields .~ mempty instance RowConstraints RConstraints where type RowConstraintsKey RConstraints = Name forbidden = rForbiddenFields instance HasTypeConstraints Typ where type instance TypeConstraintsOf Typ = ScopeLevel verifyConstraints _ TInt = Just TInt verifyConstraints _ (TVar v) = TVar v & Just verifyConstraints c (TFun f) = f & mappedK1 %~ WithConstraint c & TFun & Just verifyConstraints c (TRec r) = WithConstraint (RowConstraints mempty c) r & TRec & Just verifyConstraints c (TNom (NominalInst n (Types t r))) = Types (t & _QVarInstances . traverse %~ WithConstraint c) (r & _QVarInstances . traverse %~ WithConstraint (RowConstraints mempty c)) & NominalInst n & TNom & Just instance HasTypeConstraints Row where type instance TypeConstraintsOf Row = RConstraints verifyConstraints _ REmpty = Just REmpty verifyConstraints _ (RVar x) = RVar x & Just verifyConstraints c (RExtend x) = verifyRowExtendConstraints (^. rScope) c x <&> RExtend type PureInferState = (Tree Types Binding, Tree Types UVar) emptyPureInferState :: PureInferState emptyPureInferState = ( Types emptyBinding emptyBinding , Types (UVar 0) (UVar 0) ) type STNameGen s = Tree Types (Const (STRef s Int)) instance (c Typ, c Row) => Recursively c Typ instance (c Typ, c Row) => Recursively c Row instance RNodes Typ instance RNodes Row instance RTraversable Typ instance RTraversable Row instance RTraversableInferOf Typ instance RTraversableInferOf Row instance (c Typ, c Row, Recursive c) => ITermVarsConstraint c Typ instance (c Typ, c Row, Recursive c) => ITermVarsConstraint c Row instance HasQuantifiedVar Typ where type QVar Typ = Name quantifiedVar = _TVar instance HasQuantifiedVar Row where type QVar Row = Name quantifiedVar = _RVar instance HasFuncType Typ where funcType = _TFun instance HasScopeTypes v Typ a => HasScopeTypes v Typ (a, x) where scopeTypes = Lens._1 . scopeTypes type instance InferOf Typ = ANode Typ type instance InferOf Row = ANode Row instance HasInferredValue Typ where inferredValue = _ANode instance HasInferredValue Row where inferredValue = _ANode instance (Monad m, MonadInstantiate m Typ, MonadInstantiate m Row) => Infer m Typ where inferBody = inferType instance (Monad m, MonadInstantiate m Typ, MonadInstantiate m Row) => Infer m Row where inferBody = inferType rStructureMismatch :: (Unify m Typ, Unify m Row) => (forall c. Unify m c => Tree (UVarOf m) c -> Tree (UVarOf m) c -> m (Tree (UVarOf m) c)) -> Tree (UTermBody (UVarOf m)) Row -> Tree (UTermBody (UVarOf m)) Row -> m () rStructureMismatch match (UTermBody c0 (RExtend r0)) (UTermBody c1 (RExtend r1)) = rowExtendStructureMismatch match _RExtend (c0, r0) (c1, r1) rStructureMismatch _ x y = unifyError (Mismatch (x ^. uBody) (y ^. uBody)) readModifySTRef :: MonadST m => STRef (World m) a -> (a -> a) -> m a readModifySTRef ref func = do old <- readSTRef ref old <$ (writeSTRef ref $! func old) & liftST newStQuantified :: (MonadReader env m, MonadST m, Enum a) => ALens' env (Const (STRef (World m) a) (ast :: Knot)) -> m a newStQuantified l = Lens.view (Lens.cloneLens l . Lens._Wrapped) >>= (`readModifySTRef` succ)