{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.Core.Convert.TypeInfer
-- Description : does type inference.
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
module Jikka.Core.Convert.TypeInfer
  ( run,
    runExpr,
    runRule,

    -- * internal types and functions
    Equation (..),
    formularizeProgram,
    sortEquations,
    mergeAssertions,
    Subst (..),
    subst,
    substDefault,
    solveEquations,
    substProgram,
  )
where

import Control.Arrow (second)
import Control.Monad.State.Strict
import Control.Monad.Writer.Strict (MonadWriter, censor, execWriterT, tell)
import qualified Data.Map.Strict as M
import Data.Monoid (Dual (..))
import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Format (formatExpr, formatToplevelExpr, formatType)
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.NameCheck (namecheckExpr)
import Jikka.Core.Language.TypeCheck (literalToType, typecheckExpr)
import Jikka.Core.Language.Util

-- | `Hint` remembers the variable or expression an equation was generated
-- for. Hints are turned into error-message context by `wrapErrorFromHint`.
data Hint
  = VarHint VarName
  | ExprHint Expr
  | ToplevelExprHint ToplevelExpr
  deriving (Eq, Ord, Show, Read)

-- | `Equation` is a constraint emitted while walking the program: either an
-- equality between two types (with the hints of the enclosing expressions),
-- or a statement that a variable has a type.
data Equation
  = TypeEquation Type Type [Hint]
  | TypeAssertion VarName Type
  deriving (Eq, Ord, Show, Read)

-- | The writer output. With `Dual`, each newly told list is put in front of
-- the ones told before it.
type Eqns = Dual [Equation]

-- | `consHint` prepends a hint to a type equation. Assertions carry no hints
-- and are returned as they are.
consHint :: Hint -> Equation -> Equation
consHint hint (TypeEquation t1 t2 hints) = TypeEquation t1 t2 (hint : hints)
consHint _ eqn@(TypeAssertion _ _) = eqn

-- | `wrapHint` attaches the hint to every type equation emitted by the given action.
wrapHint :: MonadWriter Eqns m => Hint -> m a -> m a
wrapHint hint = censor (Dual . map (consHint hint) . getDual)

-- | `wrapErrorFromHint` wraps errors of the given action with a message
-- describing the hint. Multi-line renderings are cut down to their first line.
wrapErrorFromHint :: MonadError Error m => Hint -> m a -> m a
wrapErrorFromHint hint = wrapError' $ case hint of
  VarHint x -> "around variable " ++ formatVarName x
  ExprHint e -> "around expr " ++ firstLine (formatExpr e)
  ToplevelExprHint e -> "around toplevel expr " ++ firstLine (formatToplevelExpr e)
  where
    firstLine text = case lines text of
      (l : _ : _) -> l ++ " ..."
      _ -> text

-- | `wrapErrorFromHints` applies `wrapErrorFromHint` for each hint; the head
-- of the list becomes the outermost wrapper.
wrapErrorFromHints :: MonadError Error m => [Hint] -> m a -> m a
wrapErrorFromHints [] = id
wrapErrorFromHints (hint : rest) = wrapErrorFromHint hint . wrapErrorFromHints rest

-- | `formularizeType` emits an equation between two types, with no hints yet.
formularizeType :: MonadWriter Eqns m => Type -> Type -> m ()
formularizeType t1 t2 = tell (Dual [TypeEquation t1 t2 []])

-- | `formularizeVarName` emits an assertion that the variable has the type.
formularizeVarName :: MonadWriter Eqns m => VarName -> Type -> m ()
formularizeVarName x t = tell (Dual [TypeAssertion x t])

-- | `formularizeExpr` emits the equations for an expression and returns the
-- type (possibly a fresh type variable) that stands for the expression.
formularizeExpr :: (MonadWriter Eqns m, MonadAlpha m, MonadError Error m) => Expr -> m Type
formularizeExpr expr = wrapHint (ExprHint expr) $ case expr of
  Var x -> genType >>= \t -> t <$ formularizeVarName x t
  Lit lit -> case lit of
    LitBuiltin (Proj _) [] -> genType -- Proj may have an empty list.
    _ -> literalToType lit
  App f arg -> do
    ret <- genType
    argTy <- formularizeExpr arg
    formularizeExpr' f (FunTy argTy ret)
    pure ret
  Lam x t body -> FunTy t <$> (formularizeVarName x t >> formularizeExpr body)
  Let x t e1 e2 -> formularizeVarName x t >> formularizeExpr' e1 t >> formularizeExpr e2
  Assert e1 e2 -> formularizeExpr' e1 BoolTy >> formularizeExpr e2

-- | `formularizeExpr'` emits the equations for an expression, plus one more
-- equation that says the expression has the given type.
formularizeExpr' :: (MonadWriter Eqns m, MonadAlpha m, MonadError Error m) => Expr -> Type -> m ()
formularizeExpr' e expected =
  formularizeExpr e >>= wrapHint (ExprHint e) . formularizeType expected

-- | `formularizeToplevelExpr` is `formularizeExpr` for toplevel expressions.
formularizeToplevelExpr :: (MonadWriter Eqns m, MonadAlpha m, MonadError Error m) => ToplevelExpr -> m Type
formularizeToplevelExpr toplevel = wrapHint (ToplevelExprHint toplevel) $ case toplevel of
  ResultExpr e -> formularizeExpr e
  ToplevelLet x t e cont ->
    formularizeVarName x t >> formularizeExpr' e t >> formularizeToplevelExpr cont
  ToplevelLetRec f args ret body cont -> do
    formularizeVarName f (curryFunTy (map snd args) ret)
    mapM_ (\(x, t) -> formularizeVarName x t) args
    formularizeExpr' body ret
    formularizeToplevelExpr cont
  ToplevelAssert e cont ->
    formularizeExpr' e BoolTy >> formularizeToplevelExpr cont

-- | `formularizeProgram` collects all equations of a program.
formularizeProgram :: (MonadAlpha m, MonadError Error m) => Program -> m [Equation]
formularizeProgram = fmap getDual . execWriterT . formularizeToplevelExpr

-- | `sortEquations` splits equations into type equations and variable
-- assertions. Each output list is in the reverse of its input order.
sortEquations :: [Equation] -> ([(Type, Type, [Hint])], [(VarName, Type)])
sortEquations eqns =
  ( reverse [(t1, t2, hints) | TypeEquation t1 t2 hints <- eqns],
    reverse [(x, t) | TypeAssertion x t <- eqns]
  )

-- | `mergeAssertions` makes type equations from assertions: every assertion
-- on an already seen variable is equated with the first type seen for it.
mergeAssertions :: [(VarName, Type)] -> [(Type, Type, [Hint])]
mergeAssertions = go M.empty []
  where
    go _ acc [] = acc
    go seen acc ((x, t) : rest)
      | Just t' <- M.lookup x seen = go seen ((t, t', [VarHint x]) : acc) rest
      | otherwise = go (M.insert x t seen) acc rest

-- | `Subst` is type substitution. It's a mapping from type variables to their actual types.
newtype Subst = Subst {unSubst :: M.Map TypeName Type}
  deriving (Eq, Ord, Show, Read)

-- | `subst` applies a substitution to a type. A bound type variable is
-- replaced, and the replacement is substituted again.
subst :: Subst -> Type -> Type
subst sigma = go
  where
    go ty = case ty of
      VarTy x -> maybe ty go (M.lookup x (unSubst sigma))
      IntTy -> ty
      BoolTy -> ty
      ListTy t -> ListTy (go t)
      TupleTy ts -> TupleTy (map go ts)
      FunTy t ret -> FunTy (go t) (go ret)
      DataStructureTy _ -> ty

-- | `unifyTyVar` binds a type variable to a type in the current substitution.
-- It fails when the variable occurs free in the type.
unifyTyVar :: (MonadState Subst m, MonadError Error m) => TypeName -> Type -> m ()
unifyTyVar x t
  | x `elem` freeTyVars t =
    throwInternalError $ "looped type equation " ++ formatTypeName x ++ " = " ++ formatType t
  | otherwise =
    -- x is not free in t here, so this binding is not directly self-referential.
    modify' (\(Subst m) -> Subst (M.insert x t m))

-- | `unifyType` makes two types equal under the current substitution,
-- binding type variables with `unifyTyVar` where needed.
--
-- Both types are first rewritten with the current substitution. It fails when
-- the type constructors differ, or when two tuples have different lengths.
unifyType :: (MonadState Subst m, MonadError Error m) => Type -> Type -> m ()
unifyType t1 t2 = wrapError' ("failed to unify " ++ formatType t1 ++ " and " ++ formatType t2) $ do
  sigma <- get
  -- Fresh names are used instead of shadowing; `let t1 = subst sigma t1` would be recursive.
  let s1 = subst sigma t1
  let s2 = subst sigma t2
  case (s1, s2) of
    _ | s1 == s2 -> return ()
    (VarTy x1, _) -> unifyTyVar x1 s2
    (_, VarTy x2) -> unifyTyVar x2 s1
    (ListTy a, ListTy b) -> unifyType a b
    (TupleTy ts1, TupleTy ts2) ->
      if length ts1 == length ts2
        then mapM_ (uncurry unifyType) (zip ts1 ts2)
        else throwInternalError $ "different type ctors " ++ formatType s1 ++ " and " ++ formatType s2
    (FunTy a ret1, FunTy b ret2) -> do
      unifyType a b
      unifyType ret1 ret2
    _ -> throwInternalError $ "different type ctors " ++ formatType s1 ++ " and " ++ formatType s2

-- | `solveEquations` unifies the equations one by one, starting from the
-- empty substitution, and returns the resulting substitution. Errors are
-- wrapped with the hints of the failing equation.
solveEquations :: MonadError Error m => [(Type, Type, [Hint])] -> m Subst
solveEquations eqns = wrapError' "failed to solve type equations" $ do
  (`execStateT` Subst M.empty) $ do
    forM_ eqns $ \(t1, t2, hints) -> do
      wrapErrorFromHints hints $ do
        unifyType t1 t2

-- | `substDefault` replaces all undetermined type variables with the given default type.
--
-- The first argument is the default type; the second is the type to rewrite.
substDefault :: Type -> Type -> Type
substDefault t0 = \case
  VarTy _ -> t0 -- use the default given by the caller, not a hard-coded type
  IntTy -> IntTy
  BoolTy -> BoolTy
  ListTy t -> ListTy (substDefault t0 t)
  TupleTy ts -> TupleTy (map (substDefault t0) ts)
  FunTy t ret -> FunTy (substDefault t0 t) (substDefault t0 ret)
  DataStructureTy ds -> DataStructureTy ds

-- | `subst'` applies the substitution, and then, only when a default type is
-- given, replaces the remaining type variables with it.
subst' :: Maybe Type -> Subst -> Type -> Type
subst' t0 sigma = maybe id substDefault t0 . subst sigma

-- | `fixProj` fills the type list of a `Proj'` whose list is empty, using the
-- type of its argument. Other expressions are returned unchanged. It fails
-- when the argument is not of a tuple type.
fixProj :: MonadError Error m => [(VarName, Type)] -> Expr -> m Expr
fixProj env = \case
  Proj' [] i e -> do
    -- fix Proj with an empty list
    t <- typecheckExpr env e
    case t of
      TupleTy ts -> return $ Proj' ts i e
      _ -> throwInternalError $ "type of argument of proj must be a tuple: " ++ formatType t
  e -> return e

-- | `substProgram` rewrites all types in a program with `subst'`, and then
-- applies `fixProj` to its sub-expressions.
substProgram :: MonadError Error m => Maybe Type -> Subst -> Program -> m Program
substProgram t0 sigma = mapExprProgramM (mapSubExprM fixProj) . mapTypeProgram (subst' t0 sigma)

-- | `substExpr` is `substProgram` for a single expression under the given environment.
substExpr :: MonadError Error m => Maybe Type -> Subst -> [(VarName, Type)] -> Expr -> m Expr
substExpr t0 sigma env = mapSubExprM fixProj env . mapTypeExpr (subst' t0 sigma)

-- | `run` does type inference.
--
-- * This assumes that program has no name conflicts.
--
-- Before:
--
-- > let f = fun y -> y
-- > in let x = 1
-- > in f(x + x)
--
-- After:
--
-- > let f: int -> int = fun y: int -> y
-- > in let x: int = 1
-- > in f(x + x)
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.TypeInfer.run" $ do
  precondition $ do
    ensureAlphaConverted prog
  eqns <- formularizeProgram prog
  let (eqns', assertions) = sortEquations eqns
  let eqns'' = mergeAssertions assertions
  sigma <- solveEquations (eqns' ++ eqns'')
  let t0 = Just UnitTy -- undetermined type variables become UnitTy
  prog' <- substProgram t0 sigma prog
  postcondition $ do
    ensureAlphaConverted prog'
    ensureWellTyped prog'
  return prog'

-- | `runExpr` does type inference for an expression under the given
-- environment. Undetermined type variables are left as they are.
runExpr :: (MonadAlpha m, MonadError Error m) => [(VarName, Type)] -> Expr -> m Expr
runExpr env e = wrapError' "Jikka.Core.Convert.TypeInfer.runExpr" $ do
  precondition $ do
    namecheckExpr env e
  eqns <- getDual <$> execWriterT (formularizeExpr e)
  let (eqns', assertions) = sortEquations eqns
  let eqns'' = mergeAssertions (env ++ assertions)
  sigma <- solveEquations (eqns' ++ eqns'')
  let t0 = Nothing -- don't use substDefault
  let env' = map (second (subst' t0 sigma)) env
  e' <- substExpr t0 sigma env' e
  postcondition $ do
    namecheckExpr env' e'
    typecheckExpr env' e'
  return e'

-- | `runRule` does type inference for the both sides of a rewrite rule at
-- once. The two expressions are required to have the same type.
--
-- The declared types of @args@ are merged with the assertions collected from
-- the expressions (as `runExpr` does with its environment), so that the
-- returned @args@ reflect the inferred types. Undetermined type variables are
-- left as they are.
runRule :: (MonadAlpha m, MonadError Error m) => [(VarName, Type)] -> Expr -> Expr -> m ([(VarName, Type)], Expr, Expr)
runRule args e1 e2 = wrapError' "Jikka.Core.Convert.TypeInfer.runRule" $ do
  precondition $ do
    -- Underscores are allowed for names, so we don't use namecheckExpr here.
    return ()
  eqns <- (getDual <$>) . execWriterT $ do
    t <- formularizeExpr e1
    formularizeExpr' e2 t
  let (eqns', assertions) = sortEquations eqns
  let eqns'' = mergeAssertions (args ++ assertions)
  sigma <- solveEquations (eqns' ++ eqns'')
  let args' = map (second (subst sigma)) args -- don't use substDefault
  let e1' = mapTypeExpr (subst sigma) e1 -- don't use substDefault
  let e2' = mapTypeExpr (subst sigma) e2 -- don't use substDefault
  postcondition $ do
    -- Underscores are allowed for names, so we don't use namecheckExpr here.
    -- Type variables can remain, so we don't use typecheckExpr here.
    return ()
  return (args', e1', e2')