{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}

{-# OPTIONS_GHC -Wno-name-shadowing #-}

-- | Constraint-based type inference for Nix expressions.
--
-- An expression is folded (see 'infer') into a 'Judgment': the assumptions
-- made about free variables, the constraints collected along the way and the
-- inferred type.  The constraints are then handed to the backtracking
-- 'Solver' defined at the bottom of this module.
module Nix.Type.Infer
  ( Constraint(..)
  , TypeError(..)
  , InferError(..)
  , Subst(..)
  , inferTop
  )
where

import           Control.Applicative
import           Control.Arrow
import           Control.Monad.Catch
import           Control.Monad.Except
import           Control.Monad.Fail
import           Control.Monad.Logic
import           Control.Monad.Reader
import           Control.Monad.Ref
import           Control.Monad.ST
import           Control.Monad.State.Strict
import           Data.Fix                       ( cata )
import           Data.Foldable
import qualified Data.HashMap.Lazy             as M
import           Data.List                      ( delete
                                                , find
                                                , nub
                                                , intersect
                                                , (\\)
                                                )
import           Data.Map                       ( Map )
import qualified Data.Map                      as Map
import           Data.Maybe                     ( fromJust )
import qualified Data.Set                      as Set
import           Data.Text                      ( Text )
import           Nix.Atoms
import           Nix.Convert
import           Nix.Eval                       ( MonadEval(..) )
import qualified Nix.Eval                      as Eval
import           Nix.Expr.Types
import           Nix.Expr.Types.Annotated
import           Nix.Fresh
import           Nix.String
import           Nix.Scope
-- import           Nix.Thunk
-- import           Nix.Thunk.Basic
import qualified Nix.Type.Assumption           as As
import           Nix.Type.Env
import qualified Nix.Type.Env                  as Env
import           Nix.Type.Type
import           Nix.Utils
import           Nix.Value.Monad
import           Nix.Var

-------------------------------------------------------------------------------
-- Classes
-------------------------------------------------------------------------------

-- | Inference monad.
--
-- A stack of: a reader holding a set of type variables (extended by
-- 'extendMSet') together with the lexical scopes, a state holding the
-- fresh-name counter ('InferState'), and 'InferError' as the error type.
newtype InferT s m a = InferT
  { getInfer ::
      ReaderT (Set.Set TVar, Scopes (InferT s m) (Judgment s))
        (StateT InferState (ExceptT InferError m)) a
  }
  deriving
    ( Functor
    , Applicative
    , Alternative
    , Monad
    , MonadPlus
    , MonadFix
    , MonadReader (Set.Set TVar, Scopes (InferT s m) (Judgment s))
    , MonadFail
    , MonadState InferState
    , MonadError InferError
    )

instance MonadTrans (InferT s) where
  lift = InferT . lift . lift . lift

-- instance MonadThunkId m => MonadThunkId (InferT s m) where
--   type ThunkId (InferT s m) = ThunkId m

-- | Inference state: the number of fresh type variables handed out so far.
newtype InferState = InferState { count :: Int }

-- | Initial inference state
initInfer :: InferState
initInfer = InferState { count = 0 }

-- | Constraints produced by inference and consumed by 'solve'.
data Constraint
  = EqConst Type Type
    -- ^ The two types must unify.
  | ExpInstConst Type Scheme
    -- ^ The type must equal a fresh instantiation of the scheme.
  | ImpInstConst Type (Set.Set TVar) Type
    -- ^ The first type must be an instance of the second type generalized
    -- over every variable that is not in the given set.
  deriving (Show, Eq, Ord)

-- | A substitution from type variables to types.
newtype Subst = Subst (Map TVar Type)
  deriving (Eq, Ord, Show, Semigroup, Monoid)

-- | Things a 'Subst' can be applied to.
class Substitutable a where
  apply :: Subst -> a -> a

instance Substitutable TVar where
  -- NOTE(review): the lazy pattern below is partial.  If the variable is
  -- mapped to something other than a 'TVar', forcing the result fails.
  apply (Subst s) a = tv
   where
    t         = TVar a
    (TVar tv) = Map.findWithDefault t a s

instance Substitutable Type where
  apply _         (  TCon a   ) = TCon a
  apply s         (  TSet b a ) = TSet b (M.map (apply s) a)
  apply s         (  TList a  ) = TList (map (apply s) a)
  apply (Subst s) t@(TVar  a  ) = Map.findWithDefault t a s
  apply s         (  t1 :~> t2) = apply s t1 :~> apply s t2
  apply s         (  TMany ts ) = TMany (map (apply s) ts)

instance Substitutable Scheme where
  -- Variables bound by the scheme are removed from the substitution first.
  apply (Subst s) (Forall as t) = Forall as $ apply s' t
    where s' = Subst $ foldr Map.delete s as

instance Substitutable Constraint where
  apply s (EqConst      t1 t2) = EqConst (apply s t1) (apply s t2)
  apply s (ExpInstConst t  sc) = ExpInstConst (apply s t) (apply s sc)
  apply s (ImpInstConst t1 ms t2) =
    ImpInstConst (apply s t1) (apply s ms) (apply s t2)

instance Substitutable a => Substitutable [a] where
  apply = map . apply

instance (Ord a, Substitutable a) => Substitutable (Set.Set a) where
  apply = Set.map . apply

-- | Free type variables of a value.
class FreeTypeVars a where
  ftv :: a -> Set.Set TVar

instance FreeTypeVars Type where
  ftv TCon{}      = Set.empty
  ftv (TVar a   ) = Set.singleton a
  ftv (TSet _ a ) = Set.unions (map ftv (M.elems a))
  ftv (TList a  ) = Set.unions (map ftv a)
  ftv (t1 :~> t2) = ftv t1 `Set.union` ftv t2
  ftv (TMany ts ) = Set.unions (map ftv ts)

instance FreeTypeVars TVar where
  ftv = Set.singleton

instance FreeTypeVars Scheme where
  ftv (Forall as t) = ftv t `Set.difference` Set.fromList as

instance FreeTypeVars a => FreeTypeVars [a] where
  ftv = foldr (Set.union . ftv) Set.empty

instance (Ord a, FreeTypeVars a) => FreeTypeVars (Set.Set a) where
  ftv = foldr (Set.union . ftv) Set.empty

-- | Active type variables of a constraint (used by the solver to decide when
-- an 'ImpInstConst' may be solved).
class ActiveTypeVars a where
  atv :: a -> Set.Set TVar

instance ActiveTypeVars Constraint where
  atv (EqConst t1 t2) = ftv t1 `Set.union` ftv t2
  atv (ImpInstConst t1 ms t2) =
    ftv t1 `Set.union` (ftv ms `Set.intersection` ftv t2)
  atv (ExpInstConst t s) = ftv t `Set.union` ftv s

instance ActiveTypeVars a => ActiveTypeVars [a] where
  atv = foldr (Set.union . atv) Set.empty

-- | Errors reported by the constraint solver.
data TypeError
  = UnificationFail Type Type
  | InfiniteType TVar Type
  | UnboundVariables [Text]
  | Ambigious [Constraint]
  | UnificationMismatch [Type] [Type]
  deriving (Eq, Show)

-- | Errors raised by the inference monad.
data InferError
  = TypeInferenceErrors [TypeError]
  | TypeInferenceAborted
  | forall s. Exception s => EvaluationError s

-- | Throw a single 'TypeError' wrapped in 'TypeInferenceErrors'.
typeError :: MonadError InferError m => TypeError -> m ()
typeError err = throwError $ TypeInferenceErrors [err]

deriving instance Show InferError
instance Exception InferError

-- The first error wins.
instance Semigroup InferError where
  x <> _ = x

instance Monoid InferError where
  mempty  = TypeInferenceAborted
  mappend = (<>)

-------------------------------------------------------------------------------
-- Inference
-------------------------------------------------------------------------------

-- | Run the inference monad
runInfer' :: MonadInfer m => InferT s m a -> m (Either InferError a)
runInfer' =
  runExceptT
    . (`evalStateT` initInfer)
    . (`runReaderT` (Set.empty, emptyScopes))
    . getInfer

-- | Run the inference monad purely, on top of 'ST' with a fresh-id supply
-- starting at 1.
runInfer :: (forall s . InferT s (FreshIdT Int (ST s)) a) -> Either InferError a
runInfer m = runST $ do
  i <- newVar (1 :: Int)
  runFreshIdT i (runInfer' m)

-- | Infer an expression and solve its constraints against the environment.
--
-- Fails with 'UnboundVariables' when the expression assumes variables the
-- environment does not provide.  Otherwise every solution found by the solver
-- is returned as a substitution paired with the substituted result type.
inferType
  :: forall s m . MonadInfer m => Env -> NExpr -> InferT s m [(Subst, Type)]
inferType env ex = do
  Judgment as cs t <- infer ex
  let unbounds =
        Set.fromList (As.keys as) `Set.difference` Set.fromList (Env.keys env)
  unless (Set.null unbounds) $ typeError $ UnboundVariables
    (nub (Set.toList unbounds))
  -- Tie each assumption about an environment variable to the schemes the
  -- environment records for it.
  let cs' =
        [ ExpInstConst t s
        | (x, ss) <- Env.toList env
        , s       <- ss
        , t       <- As.lookup x as
        ]
  inferState <- get
  -- The solver runs in a pure 'State' seeded with the current counter; the
  -- resulting counter is discarded.
  let eres = (`evalState` inferState) $ runSolver $ do
        subst <- solve (cs ++ cs')
        return (subst, subst `apply` t)
  case eres of
    Left  errs -> throwError $ TypeInferenceErrors errs
    Right xs   -> pure xs

-- | Solve for the toplevel type of an expression in a given environment
inferExpr :: Env -> NExpr -> Either InferError [Scheme]
inferExpr env ex = case runInfer (inferType env ex) of
  Left  err -> Left err
  Right xs  -> Right $ map (\(subst, ty) -> closeOver (subst `apply` ty)) xs

-- | Canonicalize and return the polymorphic toplevel type.
closeOver :: Type -> Scheme
closeOver = normalizeScheme . generalize Set.empty

-- | Run an action with the given type variable added to the reader's
-- variable set.
extendMSet :: Monad m => TVar -> InferT s m a -> InferT s m a
extendMSet x = InferT . local (first (Set.insert x)) . getInfer

-- | Infinite supply of names: @a@ .. @z@, @aa@, @ab@, ...
letters :: [String]
letters = [1 ..] >>= flip replicateM ['a' .. 'z']

-- | Allocate a fresh type variable, named by the current counter.
freshTVar :: MonadState InferState m => m TVar
freshTVar = do
  s <- get
  put s { count = count s + 1 }
  return $ TV (letters !! count s)

-- | A fresh type variable as a 'Type'.
fresh :: MonadState InferState m => m Type
fresh = TVar <$> freshTVar

-- | Replace the quantified variables of a scheme with fresh ones.
instantiate :: MonadState InferState m => Scheme -> m Type
instantiate (Forall as t) = do
  as' <- mapM (const fresh) as
  let s = Subst $ Map.fromList $ zip as as'
  return $ apply s t

-- | Quantify over every free variable of the type that is not in the given
-- set.
generalize :: Set.Set TVar -> Type -> Scheme
generalize free t = Forall as t
  where as = Set.toList $ ftv t `Set.difference` free

-- | Constraints on the function type @u1@ of a unary operator.
unops :: Type -> NUnaryOp -> [Constraint]
unops u1 = \case
  NNot -> [EqConst u1 (typeFun [typeBool, typeBool])]
  NNeg ->
    [ EqConst
        u1
        (TMany [typeFun [typeInt, typeInt], typeFun [typeFloat, typeFloat]])
    ]

-- | Constraints on the function type @u1@ of a binary operator.
binops :: Type -> NBinaryOp -> [Constraint]
binops u1 = \case
  NApp  -> []                -- this is handled separately

  -- Equality tells you nothing about the types, because any two types are
  -- allowed.
  NEq   -> []
  NNEq  -> []

  NGt   -> inequality
  NGte  -> inequality
  NLt   -> inequality
  NLte  -> inequality

  NAnd  -> [EqConst u1 (typeFun [typeBool, typeBool, typeBool])]
  NOr   -> [EqConst u1 (typeFun [typeBool, typeBool, typeBool])]
  NImpl -> [EqConst u1 (typeFun [typeBool, typeBool, typeBool])]

  NConcat -> [EqConst u1 (typeFun [typeList, typeList, typeList])]

  NUpdate ->
    [ EqConst
        u1
        (TMany
          [ typeFun [typeSet, typeSet, typeSet]
          , typeFun [typeSet, typeNull, typeSet]
          , typeFun [typeNull, typeSet, typeSet]
          ]
        )
    ]

  NPlus ->
    [ EqConst
        u1
        (TMany
          [ typeFun [typeInt, typeInt, typeInt]
          , typeFun [typeFloat, typeFloat, typeFloat]
          , typeFun [typeInt, typeFloat, typeFloat]
          , typeFun [typeFloat, typeInt, typeFloat]
          , typeFun [typeString, typeString, typeString]
          , typeFun [typePath, typePath, typePath]
          , typeFun [typeString, typeString, typePath]
          ]
        )
    ]
  NMinus -> arithmetic
  NMult  -> arithmetic
  NDiv   -> arithmetic
 where
  inequality =
    [ EqConst
        u1
        (TMany
          [ typeFun [typeInt, typeInt, typeBool]
          , typeFun [typeFloat, typeFloat, typeBool]
          , typeFun [typeInt, typeFloat, typeBool]
          , typeFun [typeFloat, typeInt, typeBool]
          ]
        )
    ]

  arithmetic =
    [ EqConst
        u1
        (TMany
          [ typeFun [typeInt, typeInt, typeInt]
          , typeFun [typeFloat, typeFloat, typeFloat]
          , typeFun [typeInt, typeFloat, typeFloat]
          , typeFun [typeFloat, typeInt, typeFloat]
          ]
        )
    ]

-- | Lift an action of the base monad through all three transformer layers.
liftInfer :: Monad m => m a -> InferT s m a
liftInfer = InferT . lift . lift . lift

instance MonadRef m => MonadRef (InferT s m) where
  type Ref (InferT s m) = Ref m
  newRef x = liftInfer $ newRef x
  readRef x = liftInfer $ readRef x
  writeRef x y = liftInfer $ writeRef x y

instance MonadAtomicRef m => MonadAtomicRef (InferT s m) where
  -- Delegate to the base monad so the update really is a single atomic step
  -- and @f@ is applied once.  (@Ref (InferT s m)@ is @Ref m@, so the
  -- reference can be passed straight through.)
  atomicModifyRef x f = liftInfer $ atomicModifyRef x f

-- newtype JThunkT s m = JThunk (NThunkF (InferT s m) (Judgment s))

-- Exceptions are carried inside 'EvaluationError'.
instance Monad m => MonadThrow (InferT s m) where
  throwM = throwError . EvaluationError

instance Monad m => MonadCatch (InferT s m) where
  -- Only an 'EvaluationError' whose payload has the handler's exception type
  -- is handled.  Every other error is rethrown unchanged so that it keeps
  -- propagating instead of aborting the program.
  catch m h = catchError m $ \err -> case err of
    EvaluationError e ->
      maybe (throwError err) h (fromException (toException e))
    _ -> throwError err

-- | Requirements on the base monad of 'InferT'.
type MonadInfer m
  = ({- MonadThunkId m,-}
     MonadVar m, MonadFix m)

-- Judgments are never deferred: 'defer' runs the action as is and 'demand'
-- simply applies the continuation.
instance Monad m => MonadValue (Judgment s) (InferT s m) where
  defer  = id
  demand = flip ($)
  inform j f = f (pure j)

{-
instance MonadInfer m
  => MonadThunk (JThunkT s m) (InferT s m) (Judgment s) where
  thunk = fmap JThunk . thunk
  thunkId (JThunk x) = thunkId x

  queryM (JThunk x) b f = queryM x b f

  -- If we have a thunk loop, we just don't know the type.
  force (JThunk t) f = catch (force t f)
    $ \(_ :: ThunkLoop) ->
                           f =<< Judgment As.empty [] <$> fresh

  -- If we have a thunk loop, we just don't know the type.
  forceEff (JThunk t) f = catch (forceEff t f)
    $ \(_ :: ThunkLoop) ->
                           f =<< Judgment As.empty [] <$> fresh
-}

-- Each evaluation step yields a 'Judgment' instead of a value.
instance MonadInfer m => MonadEval (Judgment s) (InferT s m) where
  -- A free variable gets a fresh type, recorded as an assumption.
  freeVariable var = do
    tv <- fresh
    return $ Judgment (As.singleton var tv) [] tv

  synHole var = do
    tv <- fresh
    return $ Judgment (As.singleton var tv) [] tv

  -- If we fail to look up an attribute, we just don't know the type.
  attrMissing _ _ = Judgment As.empty [] <$> fresh

  evaledSym _ = pure

  evalCurPos = return $ Judgment As.empty [] $ TSet False $ M.fromList
    [("file", typePath), ("line", typeInt), ("col", typeInt)]

  evalConstant c = return $ Judgment As.empty [] (go c)
   where
    go = \case
      NInt   _ -> typeInt
      NFloat _ -> typeFloat
      NBool  _ -> typeBool
      NNull    -> typeNull

  evalString      = const $ return $ Judgment As.empty [] typeString
  evalLiteralPath = const $ return $ Judgment As.empty [] typePath
  evalEnvPath     = const $ return $ Judgment As.empty [] typePath

  evalUnary op (Judgment as1 cs1 t1) = do
    tv <- fresh
    return $ Judgment as1 (cs1 ++ unops (t1 :~> tv) op) tv

  evalBinary op (Judgment as1 cs1 t1) e2 = do
    Judgment as2 cs2 t2 <- e2
    tv                  <- fresh
    return $ Judgment (as1 `As.merge` as2)
                      (cs1 ++ cs2 ++ binops (t1 :~> t2 :~> tv) op)
                      tv

  evalWith = Eval.evalWithAttrSet

  -- The condition must be a boolean and both branches must agree.
  evalIf (Judgment as1 cs1 t1) t f = do
    Judgment as2 cs2 t2 <- t
    Judgment as3 cs3 t3 <- f
    return $ Judgment
      (as1 `As.merge` as2 `As.merge` as3)
      (cs1 ++ cs2 ++ cs3 ++ [EqConst t1 typeBool, EqConst t2 t3])
      t2

  evalAssert (Judgment as1 cs1 t1) body = do
    Judgment as2 cs2 t2 <- body
    return
      $ Judgment (as1 `As.merge` as2) (cs1 ++ cs2 ++ [EqConst t1 typeBool]) t2

  evalApp (Judgment as1 cs1 t1) e2 = do
    Judgment as2 cs2 t2 <- e2
    tv                  <- fresh
    return $ Judgment (as1 `As.merge` as2)
                      (cs1 ++ cs2 ++ [EqConst t1 (t2 :~> tv)])
                      tv

  evalAbs (Param x) k = do
    a <- freshTVar
    let tv = TVar a
    ((), Judgment as cs t) <- extendMSet
      a
      (k (pure (Judgment (As.singleton x tv) [] tv)) (\_ b -> ((), ) <$> b))
    -- Every use of the parameter in the body must have the parameter's type;
    -- the parameter itself is no longer an open assumption.
    return $ Judgment (as `As.remove` x)
                      (cs ++ [ EqConst t' tv | t' <- As.lookup x as ])
                      (tv :~> t)

  evalAbs (ParamSet ps variadic _mname) k = do
    -- One fresh type variable per named parameter.
    js <- fmap concat $ forM ps $ \(name, _) -> do
      tv <- fresh
      pure [(name, tv)]

    let (env, tys) =
          (\f -> foldl' f (As.empty, M.empty) js) $ \(as1, t1) (k, t) ->
            (as1 `As.merge` As.singleton k t, M.insert k t t1)
        arg   = pure $ Judgment env [] (TSet True tys)
        call  = k arg $ \args b -> (args, ) <$> b
        names = map fst js

    -- 'fresh' always returns a 'TVar', so the lambda pattern cannot fail.
    (args, Judgment as cs t) <- foldr (\(_, TVar a) -> extendMSet a) call js

    ty <- TSet variadic <$> traverse (inferredType <$>) args

    return $ Judgment
      (foldl' As.remove as names)
      (cs ++ [ EqConst t' (tys M.! x) | x <- names, t' <- As.lookup x as ])
      (ty :~> t)

  evalError = throwError . EvaluationError

-- | The result of inferring an expression.
data Judgment s = Judgment
  { assumptions     :: As.Assumption  -- ^ types assumed for free variables
  , typeConstraints :: [Constraint]   -- ^ constraints collected so far
  , inferredType    :: Type           -- ^ type of the expression
  }
  deriving Show

-- A judgment is never converted back into a string.
instance Monad m => FromValue NixString (InferT s m) (Judgment s) where
  fromValueMay _ = return Nothing
  fromValue _ = error "Unused"

instance MonadInfer m
  => FromValue (AttrSet (Judgment s), AttrSet SourcePos) (InferT s m) (Judgment s) where
  -- Each member of a set type becomes a judgment with no assumptions or
  -- constraints; source positions are not tracked.
  fromValueMay (Judgment _ _ (TSet _ xs)) = do
    let sing _ = Judgment As.empty []
    pure $ Just (M.mapWithKey sing xs, M.empty)
  fromValueMay _ = pure Nothing
  -- Anything that is not a set type is treated as an empty attribute set.
  fromValue = fromValueMay >=> \case
    Just v  -> pure v
    Nothing -> pure (M.empty, M.empty)

instance MonadInfer m
  => ToValue (AttrSet (Judgment s), AttrSet SourcePos) (InferT s m) (Judgment s) where
  -- Merge the members' assumptions and constraints; the result is a
  -- @TSet True@ of the member types.
  toValue (xs, _) =
    Judgment
      <$> foldrM go As.empty xs
      <*> (concat <$> traverse (`demand` (pure . typeConstraints)) xs)
      <*> (TSet True <$> traverse (`demand` (pure . inferredType)) xs)
    where go x rest = demand x $ \x' -> pure $ As.merge (assumptions x') rest

instance MonadInfer m => ToValue [Judgment s] (InferT s m) (Judgment s) where
  -- Same as for attribute sets, but the result is a 'TList'.
  toValue xs =
    Judgment
      <$> foldrM go As.empty xs
      <*> (concat <$> traverse (`demand` (pure . typeConstraints)) xs)
      <*> (TList <$> traverse (`demand` (pure . inferredType)) xs)
    where go x rest = demand x $ \x' -> pure $ As.merge (assumptions x') rest

instance MonadInfer m => ToValue Bool (InferT s m) (Judgment s) where
  toValue _ = pure $ Judgment As.empty [] typeBool

-- | Fold an expression into its 'Judgment' using the 'MonadEval' instance
-- above.
infer :: MonadInfer m => NExpr -> InferT s m (Judgment s)
infer = cata Eval.eval

-- | Infer a list of named top-level expressions in order, extending the
-- environment with each result.  Stops at the first error.
inferTop :: Env -> [(Text, NExpr)] -> Either InferError Env
inferTop env []                = Right env
inferTop env ((name, ex) : xs) = case inferExpr env ex of
  Left  err -> Left err
  Right ty  -> inferTop (extend env (name, ty)) xs

-- | Rename the variables of a scheme to @a@, @b@, ... in order of first
-- appearance in the body, and quantify over exactly those variables.
normalizeScheme :: Scheme -> Scheme
normalizeScheme (Forall _ body) = Forall (map snd ord) (normtype body)
 where
  ord = zip (nub $ fv body) (map TV letters)

  fv (TVar a  ) = [a]
  fv (a :~> b ) = fv a ++ fv b
  fv (TCon _  ) = []
  fv (TSet _ a) = concatMap fv (M.elems a)
  fv (TList a ) = concatMap fv a
  fv (TMany ts) = concatMap fv ts

  normtype (a :~> b ) = normtype a :~> normtype b
  normtype (TCon a  ) = TCon a
  normtype (TSet b a) = TSet b (M.map normtype a)
  normtype (TList a ) = TList (map normtype a)
  normtype (TMany ts) = TMany (map normtype ts)
  normtype (TVar  a ) = case Prelude.lookup a ord of
    Just x  -> TVar x
    Nothing -> error "type variable not in signature"

-------------------------------------------------------------------------------
-- Constraint Solver
-------------------------------------------------------------------------------

-- | Backtracking solver: a logic monad over a state that accumulates the
-- errors of failed branches.
newtype Solver m a = Solver (LogicT (StateT [TypeError] m) a)
  deriving
    ( Functor
    , Applicative
    , Alternative
    , Monad
    , MonadPlus
    , MonadLogic
    , MonadState [TypeError]
    )

instance MonadTrans Solver where
  lift = Solver . lift . lift

instance Monad m => MonadError TypeError (Solver m) where
  -- Record the error and abandon the current branch.
  throwError err = Solver $ lift (modify (err :)) >> mzero
  catchError _ _ = error "This is never used"

-- | Run the solver.  Returns every solution when there is at least one,
-- otherwise the de-duplicated errors that were recorded.
runSolver :: Monad m => Solver m a -> m (Either [TypeError] [a])
runSolver (Solver s) = do
  res <- runStateT (observeAllT s) []
  pure $ case res of
    (x : xs, _ ) -> Right (x : xs)
    (_     , es) -> Left (nub es)

-- | The empty substitution
emptySubst :: Subst
emptySubst = mempty

-- | Compose substitutions
compose :: Subst -> Subst -> Subst
Subst s1 `compose` Subst s2 =
  Subst $ Map.map (apply (Subst s1)) s2 `Map.union` s1

-- | Unify two lists of types pairwise, threading the substitution through.
-- Lists of different length fail with 'UnificationMismatch'.
unifyMany :: Monad m => [Type] -> [Type] -> Solver m Subst
unifyMany []         []         = return emptySubst
unifyMany (t1 : ts1) (t2 : ts2) = do
  su1 <- unifies t1 t2
  su2 <- unifyMany (apply su1 ts1) (apply su1 ts2)
  return (su2 `compose` su1)
unifyMany t1 t2 = throwError $ UnificationMismatch t1 t2

-- | Whether all elements of the list are equal.
allSameType :: [Type] -> Bool
allSameType []           = True
allSameType [_         ] = True
allSameType (x : y : ys) = x == y && allSameType (y : ys)

-- | Unify two types.
unifies :: Monad m => Type -> Type -> Solver m Subst
unifies t1 t2 | t1 == t2 = return emptySubst
unifies (TVar v) t        = v `bind` t
unifies t        (TVar v) = v `bind` t
unifies (TList xs) (TList ys)
  | allSameType xs && allSameType ys = case (xs, ys) of
      (x : _, y : _) -> unifies x y
      _              -> return emptySubst
  | length xs == length ys = unifyMany xs ys
-- We assume that lists of different lengths containing various types cannot
-- be unified.
unifies t1@(TList _) t2@(TList _) = throwError $ UnificationFail t1 t2
unifies (TSet True _) (TSet True _) = return emptySubst
unifies (TSet False b) (TSet True s)
  | M.keys b `intersect` M.keys s == M.keys s = return emptySubst
unifies (TSet True s) (TSet False b)
  | M.keys b `intersect` M.keys s == M.keys b = return emptySubst
unifies (TSet False s) (TSet False b)
  | null (M.keys b \\ M.keys s) = return emptySubst
unifies (t1 :~> t2) (t3 :~> t4) = unifyMany [t1, t2] [t3, t4]
-- For 'TMany', try each alternative in turn.
unifies (TMany t1s) t2          = considering t1s >>- unifies ?? t2
unifies t1          (TMany t2s) = considering t2s >>- unifies t1
unifies t1          t2          = throwError $ UnificationFail t1 t2

-- | Bind a variable to a type, failing with 'InfiniteType' when the variable
-- occurs in that type.
bind :: Monad m => TVar -> Type -> Solver m Subst
bind a t | t == TVar a     = return emptySubst
         | occursCheck a t = throwError $ InfiniteType a t
         | otherwise       = return (Subst $ Map.singleton a t)

-- | Whether the variable is free in the given value.
occursCheck :: FreeTypeVars a => TVar -> a -> Bool
occursCheck a t = a `Set.member` ftv t

-- | Pick the first constraint that can be solved now, together with the
-- remaining constraints.  Returns 'Nothing' when no constraint qualifies.
nextSolvable :: [Constraint] -> Maybe (Constraint, [Constraint])
nextSolvable xs = find solvable (chooseOne xs)
 where
  chooseOne xs = [ (x, ys) | x <- xs, let ys = delete x xs ]

  solvable (EqConst{}     , _) = True
  solvable (ExpInstConst{}, _) = True
  solvable (ImpInstConst _t1 ms t2, cs) =
    Set.null ((ftv t2 `Set.difference` ms) `Set.intersection` atv cs)

-- | Offer each element of the list as an alternative.
considering :: [a] -> Solver m a
considering xs = Solver $ LogicT $ \c n -> foldr c n xs

-- | Solve a list of constraints, producing a substitution.
--
-- When constraints remain but none of them is solvable, the branch fails
-- with 'Ambigious' carrying the remaining constraints.
solve :: MonadState InferState m => [Constraint] -> Solver m Subst
solve [] = return emptySubst
solve cs = maybe (throwError (Ambigious cs)) solve' (nextSolvable cs)
 where
  solve' (EqConst t1 t2, cs) =
    unifies t1 t2
      >>- \su1 -> solve (apply su1 cs)
      >>- \su2 -> return (su2 `compose` su1)

  solve' (ImpInstConst t1 ms t2, cs) =
    solve (ExpInstConst t1 (generalize ms t2) : cs)

  solve' (ExpInstConst t s, cs) = do
    s' <- lift $ instantiate s
    solve (EqConst t s' : cs)

instance Monad m => Scoped (Judgment s) (InferT s m) where
  currentScopes = currentScopesReader
  clearScopes   = clearScopesReader @(InferT s m) @(Judgment s)
  pushScopes    = pushScopesReader
  lookupVar     = lookupVarReader