module Language.PureScript.TypeChecker.Entailment
( InstanceContext
, SolverOptions(..)
, replaceTypeClassDictionaries
, newDictionaries
, entails
) where
import Prelude.Compat
import Protolude (ordNub)
import Control.Arrow (second)
import Control.Monad.Error.Class (MonadError(..))
import Control.Monad.State
import Control.Monad.Supply.Class (MonadSupply(..))
import Control.Monad.Writer
import Data.Foldable (for_, fold, toList)
import Data.Function (on)
import Data.Functor (($>))
import Data.List (minimumBy)
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
import qualified Data.Map as M
import qualified Data.Set as S
import Data.Text (Text)
import Language.PureScript.AST
import Language.PureScript.Crash
import Language.PureScript.Environment
import Language.PureScript.Errors
import Language.PureScript.Names
import Language.PureScript.TypeChecker.Monad
import Language.PureScript.TypeChecker.Unify
import Language.PureScript.TypeClassDictionaries
import Language.PureScript.Types
import Language.PureScript.Label (Label(..))
import Language.PureScript.PSString (PSString, mkString)
import qualified Language.PureScript.Constants as C
-- | Evidence that a type class constraint has been solved. The solver turns
-- each kind of evidence into a dictionary expression (see @mkDictionary@ in
-- 'entails').
data Evidence
  = NamedInstance (Qualified Ident)
    -- ^ A dictionary referred to by name: an instance found in the instance
    -- context, or a dictionary introduced for a generalized constraint.
  | WarnInstance Type
    -- ^ Built-in evidence for @Warn@, carrying the message type to report.
  | IsSymbolInstance PSString
    -- ^ Built-in evidence for @IsSymbol@ on the given symbol.
  | CompareSymbolInstance
    -- ^ Built-in evidence for @CompareSymbol@.
  | AppendSymbolInstance
    -- ^ Built-in evidence for @AppendSymbol@.
  | UnionInstance
    -- ^ Built-in evidence for @Union@.
  deriving (Show, Eq)
-- | The name of the instance behind a piece of evidence, if it has one.
-- Only 'NamedInstance' evidence carries a name; built-in evidence yields
-- 'Nothing'.
namedInstanceIdentifier :: Evidence -> Maybe (Qualified Ident)
namedInstanceIdentifier evidence =
  case evidence of
    NamedInstance instanceName -> Just instanceName
    _ -> Nothing
-- | A dictionary in scope whose value is a piece of 'Evidence'.
type TypeClassDict = TypeClassDictionaryInScope Evidence

-- | Dictionaries available to the solver, keyed first by module (dictionaries
-- built by 'mkContext' are stored under 'Nothing'), then by class name, then
-- by dictionary name.
type InstanceContext =
  M.Map (Maybe ModuleName)
    (M.Map (Qualified (ProperName 'ClassName))
      (M.Map (Qualified Ident) NamedDict))

-- | A mapping from type variable names to values.
type Matching a = M.Map Text a
-- | Merge two instance contexts at every level of nesting (module, then
-- class, then dictionary name). When both contain the same dictionary name,
-- the entry from the first context is kept, because 'M.union' is left-biased.
combineContexts :: InstanceContext -> InstanceContext -> InstanceContext
combineContexts left right = M.unionWith mergeClasses left right
  where
    mergeClasses byClass byClass' = M.unionWith M.union byClass byClass'
-- | Replace every 'TypeClassDictionary' placeholder in an expression with a
-- dictionary expression computed by 'entails'.
--
-- Solving has two phases.
--
-- 1. A pass that defers errors is repeated for as long as it reports that it
--    solved at least one constraint.
-- 2. One final pass without deferral then runs. For each remaining
--    placeholder it either solves the constraint, generalizes it (when
--    @shouldGeneralize@ permits), or throws.
--
-- The returned list holds the generalized constraints. Each entry carries
-- the identifier introduced for the constraint and the instance context that
-- was recorded on its placeholder.
replaceTypeClassDictionaries
  :: forall m
   . (MonadState CheckState m, MonadError MultipleErrors m, MonadWriter MultipleErrors m, MonadSupply m)
  => Bool
  -> Expr
  -> m (Expr, [(Ident, InstanceContext, Constraint)])
replaceTypeClassDictionaries shouldGeneralize expr =
    evalStateT (untilNoProgress expr >>= generalizePass) M.empty
  where
    -- Repeat the deferring pass until it stops solving constraints.
    untilNoProgress :: Expr -> StateT InstanceContext m Expr
    untilNoProgress e = do
      (e', progress) <- deferPass e
      if getAny progress
        then untilNoProgress e'
        else return e'

    -- One deferring pass; only the "solved something" flag is kept.
    deferPass :: Expr -> StateT InstanceContext m (Expr, Any)
    deferPass e = second fst <$> runWriterT (deferring e)

    -- The final pass; only the list of generalized constraints is kept.
    generalizePass :: Expr -> StateT InstanceContext m (Expr, [(Ident, InstanceContext, Constraint)])
    generalizePass e = second snd <$> runWriterT (generalizing e)

    -- Top-down traversals applying 'solveDictionary' to every expression.
    deferring, generalizing :: Expr -> WriterT (Any, [(Ident, InstanceContext, Constraint)]) (StateT InstanceContext m) Expr
    (_, deferring, _) = everywhereOnValuesTopDownM return (solveDictionary True) return
    (_, generalizing, _) = everywhereOnValuesTopDownM return (solveDictionary False) return

    -- Solve a single placeholder; every other expression is returned as is.
    solveDictionary :: Bool -> Expr -> WriterT (Any, [(Ident, InstanceContext, Constraint)]) (StateT InstanceContext m) Expr
    solveDictionary deferErrors e = case e of
      TypeClassDictionary constraint context hints ->
        rethrow (addHints hints) $
          entails (SolverOptions shouldGeneralize deferErrors) constraint context hints
      other -> return other
-- | The outcome of choosing among the instances that match a constraint.
data EntailsResult a
  = Solved a TypeClassDict
    -- ^ An instance was chosen, together with the value paired with it
    -- (the match substitution, in 'entails').
  | Unsolved Constraint
    -- ^ No instance matched; the constraint is to be generalized.
  | Deferred
    -- ^ No instance matched; solving is postponed because errors are being
    -- deferred.
  deriving Show
-- | Options controlling what the solver does with a constraint that no
-- instance matches.
data SolverOptions = SolverOptions
  { solverShouldGeneralize :: Bool
    -- ^ Whether such a constraint may be generalized, provided it has no
    -- arguments or one of its arguments is an unknown type.
  , solverDeferErrors :: Bool
    -- ^ Whether such a constraint should be left for a later pass instead of
    -- being generalized or reported.
  }
-- | Attempt to solve a single type class constraint, producing an expression
-- for its dictionary.
--
-- The 'WriterT' output carries two things:
--
-- * an 'Any' flag, set whenever a constraint is solved;
-- * the list of constraints that were generalized, each with the fresh
--   identifier bound for it and the context from its placeholder.
--
-- The 'StateT' state accumulates the dictionaries introduced for generalized
-- constraints. They are combined with the given context when later
-- constraints are solved.
entails
  :: forall m
   . (MonadState CheckState m, MonadError MultipleErrors m, MonadWriter MultipleErrors m, MonadSupply m)
  => SolverOptions
  -> Constraint
  -> InstanceContext
  -> [ErrorMessageHint]
  -> WriterT (Any, [(Ident, InstanceContext, Constraint)]) (StateT InstanceContext m) Expr
entails SolverOptions{..} constraint context hints =
    solve constraint
  where
    -- Candidate dictionaries for a class applied to the given arguments.
    -- Warn, IsSymbol, CompareSymbol, AppendSymbol and Union get synthesized
    -- evidence when their arguments have the shapes matched below.
    -- Every other case is looked up in the instance context under three
    -- kinds of key:
    --   * the 'Nothing' key;
    --   * the class's own module;
    --   * the modules of the arguments' head type constructors.
    forClassName :: InstanceContext -> Qualified (ProperName 'ClassName) -> [Type] -> [TypeClassDict]
    forClassName _ C.Warn [msg] =
      [TypeClassDictionaryInScope (WarnInstance msg) [] C.Warn [msg] Nothing]
    forClassName _ C.IsSymbol [TypeLevelString sym] =
      [TypeClassDictionaryInScope (IsSymbolInstance sym) [] C.IsSymbol [TypeLevelString sym] Nothing]
    -- The third argument is fixed to the ordering constructor obtained by
    -- comparing the two symbols.
    forClassName _ C.CompareSymbol [arg0@(TypeLevelString lhs), arg1@(TypeLevelString rhs), _] =
      let ordering = case compare lhs rhs of
                       LT -> C.orderingLT
                       EQ -> C.orderingEQ
                       GT -> C.orderingGT
          args = [arg0, arg1, TypeConstructor ordering]
      in [TypeClassDictionaryInScope CompareSymbolInstance [] C.CompareSymbol args Nothing]
    -- The third argument is fixed to the concatenation of the two symbols.
    forClassName _ C.AppendSymbol [arg0@(TypeLevelString lhs), arg1@(TypeLevelString rhs), _] =
      let args = [arg0, arg1, TypeLevelString (lhs <> rhs)]
      in [TypeClassDictionaryInScope AppendSymbolInstance [] C.AppendSymbol args Nothing]
    forClassName _ C.Union [l, r, u]
      | Just (lOut, rOut, uOut, cst) <- unionRows l r u
      = [ TypeClassDictionaryInScope UnionInstance [] C.Union [lOut, rOut, uOut] cst ]
    forClassName ctx cn@(Qualified (Just mn) _) tys = concatMap (findDicts ctx cn) (ordNub (Nothing : Just mn : map Just (mapMaybe ctorModules tys)))
    forClassName _ _ _ = internalError "forClassName: expected qualified class name"

    -- The module of the type constructor at the head of a type, looking
    -- through type applications and kind annotations.
    ctorModules :: Type -> Maybe ModuleName
    ctorModules (TypeConstructor (Qualified (Just mn) _)) = Just mn
    ctorModules (TypeConstructor (Qualified Nothing _)) = internalError "ctorModules: unqualified type name"
    ctorModules (TypeApp ty _) = ctorModules ty
    ctorModules (KindedType ty _) = ctorModules ty
    ctorModules _ = Nothing

    -- All dictionaries for the class stored under the given module key,
    -- wrapped as 'NamedInstance' evidence.
    findDicts :: InstanceContext -> Qualified (ProperName 'ClassName) -> Maybe ModuleName -> [TypeClassDict]
    findDicts ctx cn = fmap (fmap NamedInstance) . maybe [] M.elems . (>>= M.lookup cn) . flip M.lookup ctx

    -- A reference to the identifier 'C.undefined' in the module 'C.prim'.
    -- 'mkDictionary' returns it for Warn and Union evidence, and it is the
    -- argument passed when a superclass accessor is called.
    valUndefined :: Expr
    valUndefined = Var (Qualified (Just (ModuleName [ProperName C.prim])) (Ident C.undefined))

    solve :: Constraint -> WriterT (Any, [(Ident, InstanceContext, Constraint)]) (StateT InstanceContext m) Expr
    solve con = go 0 con
      where
        -- 'work' is the subgoal nesting depth; 'solveSubgoals' increments it.
        -- Beyond 1000 the search is abandoned with 'PossiblyInfiniteInstance'.
        go :: Int -> Constraint -> WriterT (Any, [(Ident, InstanceContext, Constraint)]) (StateT InstanceContext m) Expr
        go work (Constraint className' tys' _) | work > 1000 = throwError . errorMessage $ PossiblyInfiniteInstance className' tys'
        -- Errors raised while solving this constraint get an
        -- 'ErrorSolvingConstraint' hint for it.
        go work con'@(Constraint className' tys' conInfo) = WriterT . StateT . (withErrorMessageHint (ErrorSolvingConstraint con') .) . runStateT . runWriterT $ do
            -- Apply the current substitution to the constraint's arguments.
            latestSubst <- lift . lift $ gets checkSubstitution
            let tys'' = map (substituteType latestSubst) tys'
            -- Dictionaries introduced so far for generalized constraints.
            inferred <- lift get
            classesInScope <- lift . lift $ gets (typeClasses . checkEnv)
            TypeClassData{ typeClassDependencies } <- case M.lookup className' classesInScope of
              Nothing -> throwError . errorMessage $ UnknownClass className'
              Just tcd -> pure tcd
            -- Candidate dictionaries that match the arguments, each paired
            -- with the substitution produced by 'matches'.
            let instances =
                  [ (substs, tcd)
                  | tcd <- forClassName (combineContexts context inferred) className' tys''
                  , substs <- maybeToList (matches typeClassDependencies tcd tys'')
                  ]
            solution <- lift . lift $ unique tys'' instances
            case solution of
              Solved substs tcd -> do
                -- Record that a constraint was solved.
                tell (Any True, mempty)
                -- Types matched against the same instance type variable must
                -- unify with one another.
                lift . lift . for_ substs $ pairwiseM unifyTypes
                let subst = fmap head substs
                currentSubst <- lift . lift $ gets checkSubstitution
                -- Give fresh types to instance type variables not bound by the match.
                subst' <- lift . lift $ withFreshTypes tcd (fmap (substituteType currentSubst) subst)
                -- Unify the instantiated instance head with the arguments.
                lift . lift $ zipWithM_ (\t1 t2 -> do
                  let inferredType = replaceAllTypeVars (M.toList subst') t1
                  unifyTypes inferredType t2) (tcdInstanceTypes tcd) tys''
                currentSubst' <- lift . lift $ gets checkSubstitution
                let subst'' = fmap (substituteType currentSubst') subst'
                -- Solve the instance's own constraints and build its
                -- dictionary. Then apply one superclass accessor for each
                -- (class, index) entry in the dictionary's path.
                args <- solveSubgoals subst'' (tcdDependencies tcd)
                initDict <- lift . lift $ mkDictionary (tcdValue tcd) args
                let match = foldr (\(className, index) dict -> subclassDictionaryValue dict className index)
                                  initDict
                                  (tcdPath tcd)
                return match
              Unsolved unsolved -> do
                -- Generalize the constraint:
                --   * bind a fresh dictionary name for it;
                --   * store it and its superclass dictionaries in the state
                --     so later constraints can use them;
                --   * report it in the writer output.
                ident <- freshIdent ("dict" <> runProperName (disqualify (constraintClass unsolved)))
                let qident = Qualified Nothing ident
                newDicts <- lift . lift $ newDictionaries [] qident unsolved
                let newContext = mkContext newDicts
                modify (combineContexts newContext)
                tell (mempty, [(ident, context, unsolved)])
                return (Var qident)
              Deferred ->
                -- Leave a placeholder, with substituted arguments, for a later pass.
                return (TypeClassDictionary (Constraint className' tys'' conInfo) context hints)
          where
            -- Extend the substitution with a fresh type for every type
            -- variable that is not already bound and occurs either in the
            -- instance head or in the arguments of the instance's
            -- dependencies.
            withFreshTypes
              :: TypeClassDict
              -> Matching Type
              -> m (Matching Type)
            withFreshTypes TypeClassDictionaryInScope{..} subst = do
                let onType = everythingOnTypes S.union fromTypeVar
                    typeVarsInHead = foldMap onType tcdInstanceTypes
                                  <> foldMap (foldMap (foldMap onType . constraintArgs)) tcdDependencies
                    typeVarsInSubst = S.fromList (M.keys subst)
                    uninstantiatedTypeVars = typeVarsInHead S.\\ typeVarsInSubst
                newSubst <- traverse withFreshType (S.toList uninstantiatedTypeVars)
                return (subst <> M.fromList newSubst)
              where
                fromTypeVar (TypeVar v) = S.singleton v
                fromTypeVar _ = S.empty

                withFreshType s = do
                  t <- freshType
                  return (s, t)

            -- Choose among the matching candidates.
            --   * None: return 'Deferred' if errors are deferred. Otherwise
            --     return 'Unsolved' if generalization is allowed and the
            --     constraint has no arguments or an unknown argument.
            --     Otherwise throw 'NoInstanceFound'.
            --   * One: use it.
            --   * Several: if any two overlap, report 'OverlappingInstances'
            --     through 'tell' and take the first. Otherwise take the one
            --     with the shortest superclass path.
            unique :: [Type] -> [(a, TypeClassDict)] -> m (EntailsResult a)
            unique tyArgs []
              | solverDeferErrors = return Deferred
              | solverShouldGeneralize && (null tyArgs || any canBeGeneralized tyArgs) = return (Unsolved (Constraint className' tyArgs conInfo))
              | otherwise = throwError . errorMessage $ NoInstanceFound (Constraint className' tyArgs conInfo)
            unique _ [(a, dict)] = return $ Solved a dict
            unique tyArgs tcds
              | pairwiseAny overlapping (map snd tcds) = do
                  tell . errorMessage $ OverlappingInstances className' tyArgs (tcds >>= (toList . namedInstanceIdentifier . tcdValue . snd))
                  return $ uncurry Solved (head tcds)
              | otherwise = return $ uncurry Solved (minimumBy (compare `on` length . tcdPath . snd) tcds)

            -- An unknown type, possibly under kind annotations.
            canBeGeneralized :: Type -> Bool
            canBeGeneralized TUnknown{} = True
            canBeGeneralized (KindedType t _) = canBeGeneralized t
            canBeGeneralized _ = False

            -- Two candidates overlap only when all of the following hold:
            --   * neither has a non-empty superclass path;
            --   * both have a dependency list ('Just');
            --   * their evidence differs.
            overlapping :: TypeClassDict -> TypeClassDict -> Bool
            overlapping TypeClassDictionaryInScope{ tcdPath = _ : _ } _ = False
            overlapping _ TypeClassDictionaryInScope{ tcdPath = _ : _ } = False
            overlapping TypeClassDictionaryInScope{ tcdDependencies = Nothing } _ = False
            overlapping _ TypeClassDictionaryInScope{ tcdDependencies = Nothing } = False
            overlapping tcd1 tcd2 = tcdValue tcd1 /= tcdValue tcd2

            -- Solve the instance's constraints, if it has any, with the
            -- substitution applied, one nesting level deeper.
            solveSubgoals :: Matching Type -> Maybe [Constraint] -> WriterT (Any, [(Ident, InstanceContext, Constraint)]) (StateT InstanceContext m) (Maybe [Expr])
            solveSubgoals _ Nothing = return Nothing
            solveSubgoals subst (Just subgoals) =
              Just <$> traverse (go (work + 1) . mapConstraintArgs (map (replaceAllTypeVars (M.toList subst)))) subgoals

            -- Build the dictionary expression for a piece of evidence.
            --   * Named instances are applied to their subgoal dictionaries.
            --   * IsSymbol, CompareSymbol and AppendSymbol become dictionary
            --     constructor applications over object literals.
            --   * Warn becomes 'valUndefined' and emits 'UserDefinedWarning'.
            --   * Union becomes 'valUndefined'. With exactly one subgoal, it
            --     is a constant function returning 'valUndefined', applied to
            --     that subgoal's dictionary.
            mkDictionary :: Evidence -> Maybe [Expr] -> m Expr
            mkDictionary (NamedInstance n) args = return $ foldl App (Var n) (fold args)
            mkDictionary UnionInstance (Just [e]) =
              return $ App (Abs (VarBinder (Ident C.__unused)) valUndefined) e
            mkDictionary UnionInstance _ = return valUndefined
            mkDictionary (WarnInstance msg) _ = do
              tell . errorMessage $ UserDefinedWarning msg
              return valUndefined
            mkDictionary (IsSymbolInstance sym) _ =
              let fields = [ ("reflectSymbol", Abs (VarBinder (Ident C.__unused)) (Literal (StringLiteral sym))) ] in
              return $ TypeClassDictionaryConstructorApp C.IsSymbol (Literal (ObjectLiteral fields))
            mkDictionary CompareSymbolInstance _ =
              return $ TypeClassDictionaryConstructorApp C.CompareSymbol (Literal (ObjectLiteral []))
            mkDictionary AppendSymbolInstance _ =
              return $ TypeClassDictionaryConstructorApp C.AppendSymbol (Literal (ObjectLiteral []))

            -- Select a superclass dictionary: read the field named by
            -- 'superclassName' from the dictionary and apply it to
            -- 'valUndefined'.
            subclassDictionaryValue :: Expr -> Qualified (ProperName 'ClassName) -> Integer -> Expr
            subclassDictionaryValue dict className index =
              App (Accessor (mkString (superclassName className index)) dict) valUndefined

    -- Synthesize the arguments of a Union instance from the left and right
    -- rows. The third argument is ignored.
    --   * Closed left row: the output row is the left row's labels followed
    --     by the right row.
    --   * Open left row with at least one known label: those labels are
    --     placed ahead of the tail variable "r", and a Union subgoal is
    --     produced for the remaining tail. NOTE(review): "r" is presumably
    --     replaced with a fresh type by 'withFreshTypes' — confirm.
    --   * Open left row with no known labels: 'Nothing', so no progress is made.
    unionRows :: Type -> Type -> Type -> Maybe (Type, Type, Type, Maybe [Constraint])
    unionRows l r _ =
        guard canMakeProgress $> (l, r, rowFromList out, cons)
      where
        (fixed, rest) = rowToList l

        rowVar = TypeVar "r"

        (canMakeProgress, out, cons) =
          case rest of
            REmpty -> (True, (fixed, r), Nothing)
            _ -> (not (null fixed), (fixed, rowVar), Just [ Constraint C.Union [rest, r, rowVar] Nothing ])
-- | Check whether a dictionary's instance head matches a list of argument
-- types. Some arguments may instead be determined through the class's
-- functional dependencies.
--
-- On success, returns a substitution from the instance's type variables to
-- the types they were matched against. A variable that occurs several times
-- maps to several types, and those types are checked to be equal.
--
-- Returns 'Nothing' in any of these cases:
--
-- * the arguments that match directly, closed under the functional
--   dependencies, do not cover every argument position;
-- * a repeated type variable was matched against unequal types.
matches :: [FunctionalDependency] -> TypeClassDict -> [Type] -> Maybe (Matching [Type])
matches deps TypeClassDictionaryInScope{..} tys = do
    -- Compare each argument with the corresponding instance head type.
    let matched = zipWith typeHeadsAreEqual tys tcdInstanceTypes
    -- Positions that did not match must be determined, via the functional
    -- dependencies, by positions that did.
    guard $ covers matched
    -- Bindings from determined positions are dropped. The rest are merged,
    -- then checked for consistency.
    let determinedSet = foldMap (S.fromList . fdDetermined) deps
        solved = map snd . filter ((`S.notMember` determinedSet) . fst) $ zipWith (\(_, ts) i -> (i, ts)) matched [0..]
    verifySubstitution (M.unionsWith (++) solved)
  where
    -- Whether the matching positions, closed under the functional
    -- dependencies, include every argument position 0 .. n-1.
    covers :: [(Bool, subst)] -> Bool
    covers ms = finalSet == S.fromList [0 .. length ms - 1]
      where
        initialSet :: S.Set Int
        initialSet = S.fromList . map snd . filter (fst . fst) $ zip ms [0..]

        finalSet :: S.Set Int
        finalSet = untilFixedPoint applyAll initialSet

        untilFixedPoint :: Eq a => (a -> a) -> a -> a
        untilFixedPoint f = go
          where
            go a | a' == a = a'
                 | otherwise = go a'
              where a' = f a

        applyAll :: S.Set Int -> S.Set Int
        applyAll s = foldr applyDependency s deps

        -- When all determiners of a dependency are known, so are the
        -- positions it determines.
        applyDependency :: FunctionalDependency -> S.Set Int -> S.Set Int
        applyDependency FunctionalDependency{..} xs
          | S.fromList fdDeterminers `S.isSubsetOf` xs = xs <> S.fromList fdDetermined
          | otherwise = xs

    -- Whether an argument type (left) matches an instance head type (right),
    -- and the bindings for instance type variables this produces. A type
    -- variable on the instance side matches any type.
    --
    -- Rows are aligned by label. After alignment, the leftover tails match
    -- when either:
    --   * both are empty, or the same unknown, type variable or skolem, with
    --     no leftover labels;
    --   * the instance tail is a type variable with no leftover labels; it
    --     is then bound to the argument's leftover labels and tail.
    typeHeadsAreEqual :: Type -> Type -> (Bool, Matching [Type])
    typeHeadsAreEqual (KindedType t1 _) t2 = typeHeadsAreEqual t1 t2
    typeHeadsAreEqual t1 (KindedType t2 _) = typeHeadsAreEqual t1 t2
    typeHeadsAreEqual (TUnknown u1) (TUnknown u2) | u1 == u2 = (True, M.empty)
    typeHeadsAreEqual (Skolem _ s1 _ _) (Skolem _ s2 _ _) | s1 == s2 = (True, M.empty)
    typeHeadsAreEqual t (TypeVar v) = (True, M.singleton v [t])
    typeHeadsAreEqual (TypeConstructor c1) (TypeConstructor c2) | c1 == c2 = (True, M.empty)
    typeHeadsAreEqual (TypeLevelString s1) (TypeLevelString s2) | s1 == s2 = (True, M.empty)
    typeHeadsAreEqual (TypeApp h1 t1) (TypeApp h2 t2) =
      both (typeHeadsAreEqual h1 h2) (typeHeadsAreEqual t1 t2)
    typeHeadsAreEqual REmpty REmpty = (True, M.empty)
    typeHeadsAreEqual r1@RCons{} r2@RCons{} =
        foldr both (uncurry go rest) common
      where
        (common, rest) = alignRowsWith typeHeadsAreEqual r1 r2

        go :: ([(Label, Type)], Type) -> ([(Label, Type)], Type) -> (Bool, Matching [Type])
        go (l, KindedType t1 _) (r, t2) = go (l, t1) (r, t2)
        go (l, t1) (r, KindedType t2 _) = go (l, t1) (r, t2)
        go ([], REmpty) ([], REmpty) = (True, M.empty)
        go ([], TUnknown u1) ([], TUnknown u2) | u1 == u2 = (True, M.empty)
        go ([], TypeVar v1) ([], TypeVar v2) | v1 == v2 = (True, M.empty)
        go ([], Skolem _ sk1 _ _) ([], Skolem _ sk2 _ _) | sk1 == sk2 = (True, M.empty)
        go (sd, r) ([], TypeVar v) = (True, M.singleton v [rowFromList (sd, r)])
        go _ _ = (False, M.empty)
    typeHeadsAreEqual _ _ = (False, M.empty)

    -- Conjunction of two match results, merging their bindings.
    both :: (Bool, Matching [Type]) -> (Bool, Matching [Type]) -> (Bool, Matching [Type])
    both (b1, m1) (b2, m2) = (b1 && b2, M.unionWith (++) m1 m2)

    -- Succeed only if every variable's bound types are pairwise equal.
    verifySubstitution :: Matching [Type] -> Maybe (Matching [Type])
    verifySubstitution = traverse meet where
      meet ts | pairwiseAll typesAreEqual ts = Just ts
              | otherwise = Nothing

    -- Structural type equality, ignoring kind annotations and comparing rows
    -- label by label.
    typesAreEqual :: Type -> Type -> Bool
    typesAreEqual (KindedType t1 _) t2 = typesAreEqual t1 t2
    typesAreEqual t1 (KindedType t2 _) = typesAreEqual t1 t2
    typesAreEqual (TUnknown u1) (TUnknown u2) | u1 == u2 = True
    typesAreEqual (Skolem _ s1 _ _) (Skolem _ s2 _ _) = s1 == s2
    typesAreEqual (TypeVar v1) (TypeVar v2) = v1 == v2
    typesAreEqual (TypeLevelString s1) (TypeLevelString s2) = s1 == s2
    typesAreEqual (TypeConstructor c1) (TypeConstructor c2) = c1 == c2
    typesAreEqual (TypeApp h1 t1) (TypeApp h2 t2) = typesAreEqual h1 h2 && typesAreEqual t1 t2
    typesAreEqual REmpty REmpty = True
    typesAreEqual r1 r2 | isRCons r1 || isRCons r2 =
        let (common, rest) = alignRowsWith typesAreEqual r1 r2
        in and common && uncurry go rest
      where
        go :: ([(Label, Type)], Type) -> ([(Label, Type)], Type) -> Bool
        go (l, KindedType t1 _) (r, t2) = go (l, t1) (r, t2)
        go (l, t1) (r, KindedType t2 _) = go (l, t1) (r, t2)
        go ([], TUnknown u1) ([], TUnknown u2) | u1 == u2 = True
        go ([], Skolem _ s1 _ _) ([], Skolem _ s2 _ _) = s1 == s2
        go ([], REmpty) ([], REmpty) = True
        go ([], TypeVar v1) ([], TypeVar v2) = v1 == v2
        go _ _ = False
    typesAreEqual _ _ = False

    isRCons :: Type -> Bool
    isRCons RCons{} = True
    isRCons _ = False
-- | Create the dictionaries brought into scope by a constraint bound to the
-- given name. The result starts with the dictionary for the constraint
-- itself, followed (recursively) by dictionaries for each superclass of its
-- class. Each superclass dictionary's path gains a (superclass name,
-- superclass index) step at the front. Calls 'internalError' if the class is
-- not in the environment.
newDictionaries
  :: MonadState CheckState m
  => [(Qualified (ProperName 'ClassName), Integer)]
  -> Qualified Ident
  -> Constraint
  -> m [NamedDict]
newDictionaries path name (Constraint className instanceTy _) = do
    classes <- gets (typeClasses . checkEnv)
    let TypeClassData{..} =
          case M.lookup className classes of
            Just classData -> classData
            Nothing -> internalError "newDictionaries: type class lookup failed"
        superclassDicts (Constraint supName supArgs _, index) =
          newDictionaries ((supName, index) : path)
                          name
                          (Constraint supName (instantiateSuperclass (map fst typeClassArguments) supArgs instanceTy) Nothing)
    inherited <- concat <$> traverse superclassDicts (zip typeClassSuperclasses [0..])
    return (TypeClassDictionaryInScope name path className instanceTy Nothing : inherited)
  where
    -- Substitute the class's arguments into a superclass constraint's arguments.
    instantiateSuperclass :: [Text] -> [Type] -> [Type] -> [Type]
    instantiateSuperclass args supArgs tys = map (replaceAllTypeVars (zip args tys)) supArgs
-- | Build an instance context from a list of dictionaries. Each dictionary
-- is stored under the 'Nothing' module key, its class name and its own name.
mkContext :: [NamedDict] -> InstanceContext
mkContext = foldr (combineContexts . singletonContext) M.empty
  where
    singletonContext dict =
      M.singleton Nothing (M.singleton (tcdClassName dict) (M.singleton (tcdValue dict) dict))
-- | Whether the predicate holds for every pair of elements @(x, y)@ where
-- @x@ occurs before @y@ in the list. Trivially 'True' for fewer than two
-- elements.
pairwiseAll :: (a -> a -> Bool) -> [a] -> Bool
pairwiseAll p = go
  where
    go [] = True
    go (x : rest) = all (p x) rest && go rest
-- | Whether the predicate holds for some pair of elements @(x, y)@ where
-- @x@ occurs before @y@ in the list. Always 'False' for fewer than two
-- elements.
pairwiseAny :: (a -> a -> Bool) -> [a] -> Bool
pairwiseAny p = go
  where
    go [] = False
    go (x : rest) = any (p x) rest || go rest
-- | Run an action on every pair of elements @(x, y)@ where @x@ occurs
-- before @y@ in the list. Pairs are visited in order: all pairs starting
-- with the first element, then all pairs starting with the second, and so
-- on.
pairwiseM :: Applicative m => (a -> a -> m ()) -> [a] -> m ()
pairwiseM p = go
  where
    go [] = pure ()
    go (x : rest) = traverse (p x) rest *> go rest