module Language.PureScript.TypeChecker.Entailment (Context, replaceTypeClassDictionaries) where
import Prelude ()
import Prelude.Compat
import Data.Function (on)
import Data.List (minimumBy, sortBy, groupBy)
import Data.Maybe (maybeToList, mapMaybe)
import qualified Data.Map as M
import Control.Arrow (Arrow(..))
import Control.Monad.State
import Control.Monad.Writer
import Control.Monad.Error.Class (MonadError(..))
import Control.Monad.Supply.Class (MonadSupply(..))
import Language.PureScript.Crash
import Language.PureScript.AST
import Language.PureScript.Errors
import Language.PureScript.Names
import Language.PureScript.TypeChecker.Unify
import Language.PureScript.TypeClassDictionaries
import Language.PureScript.Types
import qualified Language.PureScript.Constants as C
-- | The type class dictionaries that are in scope.
--
-- Dictionaries are indexed by three nested keys: the module they are looked
-- up under (this module files deferred dictionaries under 'Nothing' and
-- always searches that key), the class they are an instance of, and finally
-- the name of the dictionary itself.
type Context =
  M.Map
    (Maybe ModuleName)
    ( M.Map
        (Qualified (ProperName 'ClassName))
        (M.Map (Qualified Ident) TypeClassDictionaryInScope)
    )
-- | Merge two contexts level by level (module, then class, then dictionary
-- name).
--
-- The innermost merge is 'M.union', which is left-biased: when both contexts
-- hold a dictionary under the same module, class and name, the entry from the
-- first context is kept.
combineContexts :: Context -> Context -> Context
combineContexts preferred fallback = M.unionWith mergeClasses preferred fallback
  where
    -- Merge the per-class maps found under the same module key.
    mergeClasses classesL classesR = M.unionWith mergeDictionaries classesL classesR

    -- Merge the dictionaries found under the same class key.
    mergeDictionaries dictsL dictsR = M.union dictsL dictsR
-- | Replace every 'TypeClassDictionary' placeholder inside an expression with
-- the expression produced by solving its constraint.
--
-- The expression is traversed top-down. Each placeholder is handed to
-- 'entails' together with the dictionaries it carries; every other
-- expression is left unchanged. The result pairs the rewritten expression
-- with all constraints that 'entails' reported as unsolved, each with the
-- identifier generated for its dictionary.
--
-- The 'Bool' and the 'ModuleName' are passed straight through to 'entails'.
replaceTypeClassDictionaries
  :: (MonadError MultipleErrors m, MonadWriter MultipleErrors m, MonadSupply m)
  => Bool
  -> ModuleName
  -> Expr
  -> m (Expr, [(Ident, Constraint)])
replaceTypeClassDictionaries shouldGeneralize mn expr =
  -- The solver state starts empty; the writer layer accumulates the
  -- unsolved constraints reported for each placeholder.
  evalStateT (runWriterT (rewriteExpr expr)) M.empty
  where
    -- Only the expression traversal is needed; declarations and binders are
    -- passed through with 'return'.
    (_, rewriteExpr, _) =
      everywhereOnValuesTopDownM return (WriterT . solvePlaceholder) return

    solvePlaceholder (TypeClassDictionary constraint dicts) =
      entails shouldGeneralize mn dicts constraint
    solvePlaceholder other = return (other, [])
-- | Solve a single type class constraint, producing an expression for its
-- dictionary together with the constraints that had to be deferred.
--
-- Arguments, in order:
--
-- * whether goals may be deferred: when set, a goal with no matching
--   instance whose type arguments are all unification variables or skolems
--   is not an error; a fresh dictionary name is generated for it and the goal
--   is returned as unsolved;
-- * the current module name (passed on to 'typeHeadsAreEqual');
-- * the dictionaries in scope for this constraint;
-- * the constraint to solve.
--
-- The 'StateT' 'Context' holds the dictionaries created for deferred goals so
-- that later goals can be solved with them.
--
-- Errors: 'NoInstanceFound' and 'PossiblyInfiniteInstance' are thrown;
-- 'OverlappingInstances' is only reported through the writer, and solving
-- continues with the first candidate.
entails
  :: forall m
   . (MonadError MultipleErrors m, MonadWriter MultipleErrors m, MonadSupply m)
  => Bool
  -> ModuleName
  -> Context
  -> Constraint
  -> StateT Context m (Expr, [(Ident, Constraint)])
entails shouldGeneralize moduleName context = solve
  where
    -- All candidate dictionaries for a class. Three places are searched: the
    -- 'Nothing' scope, the module of the class itself, and the module of the
    -- type constructor heading each type argument.
    forClassName :: Context -> Qualified (ProperName 'ClassName) -> [Type] -> [TypeClassDictionaryInScope]
    forClassName ctx cn@(Qualified (Just mn) _) tys =
      concatMap (findDicts ctx cn) (Nothing : Just mn : map Just (mapMaybe ctorModules tys))
    forClassName _ _ _ = internalError "forClassName: expected qualified class name"

    -- The module of the type constructor at the head of a (possibly applied)
    -- type, if the head is a type constructor at all.
    ctorModules :: Type -> Maybe ModuleName
    ctorModules (TypeConstructor (Qualified (Just mn) _)) = Just mn
    ctorModules (TypeConstructor (Qualified Nothing _)) = internalError "ctorModules: unqualified type name"
    ctorModules (TypeApp ty _) = ctorModules ty
    ctorModules _ = Nothing

    -- Every dictionary stored for the given class under the given module key;
    -- empty when either lookup misses.
    findDicts :: Context -> Qualified (ProperName 'ClassName) -> Maybe ModuleName -> [TypeClassDictionaryInScope]
    findDicts ctx cn = maybe [] M.elems . (>>= M.lookup cn) . flip M.lookup ctx

    -- Solve the top-level constraint and turn the resulting dictionary
    -- description into an expression.
    solve :: Constraint -> StateT Context m (Expr, [(Ident, Constraint)])
    solve (className, tys) = do
      (dict, unsolved) <- go 0 className tys
      return (dictionaryValueToValue dict, unsolved)
      where
        -- Solve one goal. @work@ is the nesting depth of subgoals; it is
        -- bumped by 'solveSubgoals' and bounded to stop runaway instance
        -- chains.
        go :: Int -> Qualified (ProperName 'ClassName) -> [Type] -> StateT Context m (DictionaryValue, [(Ident, Constraint)])
        go work className' tys' | work > 1000 = throwError . errorMessage $ PossiblyInfiniteInstance className' tys'
        go work className' tys' = do
          -- Dictionaries created so far for deferred goals are searched
          -- alongside the dictionaries supplied by the caller.
          inferred <- get
          let instances = do
                tcd <- forClassName (combineContexts context inferred) className' tys'
                -- Keep only candidates whose instance types match the goal's
                -- types with a consistent substitution.
                subst <- maybeToList . (>>= verifySubstitution) . fmap concat $ zipWithM (typeHeadsAreEqual moduleName) tys' (tcdInstanceTypes tcd)
                return (subst, tcd)
          solution <- lift $ unique instances
          case solution of
            Left (subst, tcd) -> do
              (args, unsolved) <- solveSubgoals subst (tcdDependencies tcd)
              -- Wrap the dictionary in one superclass projection per entry
              -- of the candidate's path.
              let match = foldr (\(superclassName, index) dict -> SubclassDictionaryValue dict superclassName index)
                                (mkDictionary (tcdName tcd) args)
                                (tcdPath tcd)
              return (match, unsolved)
            Right unsolved@(unsolvedClassName@(Qualified _ pn), unsolvedTys) -> do
              ident <- freshIdent ("dict" ++ runProperName pn)
              let qident = Qualified Nothing ident
              -- Record the new dictionary under the 'Nothing' scope so the
              -- same goal can be solved with it from now on.
              let newDict = TypeClassDictionaryInScope qident [] unsolvedClassName unsolvedTys Nothing
                  newContext = M.singleton Nothing (M.singleton unsolvedClassName (M.singleton qident newDict))
              modify (combineContexts newContext)
              return (LocalDictionaryValue qident, [(ident, unsolved)])
          where
            -- Pick the single candidate to use, or decide to defer the goal.
            unique :: [(a, TypeClassDictionaryInScope)] -> m (Either (a, TypeClassDictionaryInScope) Constraint)
            unique []
              -- The deferred constraint must be the goal being solved right
              -- now. Inside a subgoal this differs from the top-level
              -- constraint bound by 'solve', and it is the one the guard
              -- has just inspected.
              | shouldGeneralize && all canBeGeneralized tys' = return $ Right (className', tys')
              | otherwise = throwError . errorMessage $ NoInstanceFound className' tys'
            unique [a] = return $ Left a
            unique candidates@(firstCandidate : _)
              | pairwise overlapping (map snd candidates) = do
                  -- Reported with 'tell', not thrown: solving carries on
                  -- with the first candidate.
                  tell . errorMessage $ OverlappingInstances className' tys' (map (tcdName . snd) candidates)
                  return $ Left firstCandidate
              -- Otherwise prefer the candidate with the shortest path.
              | otherwise = return $ Left (minimumBy (compare `on` length . tcdPath . snd) candidates)

            -- Only unification variables and skolems may be deferred.
            canBeGeneralized :: Type -> Bool
            canBeGeneralized TUnknown{} = True
            canBeGeneralized Skolem{} = True
            canBeGeneralized _ = False

            -- Two candidates never count as overlapping when either has a
            -- non-empty path or has no dependency list; otherwise they
            -- overlap exactly when their names differ.
            overlapping :: TypeClassDictionaryInScope -> TypeClassDictionaryInScope -> Bool
            overlapping TypeClassDictionaryInScope{ tcdPath = _ : _ } _ = False
            overlapping _ TypeClassDictionaryInScope{ tcdPath = _ : _ } = False
            overlapping TypeClassDictionaryInScope{ tcdDependencies = Nothing } _ = False
            overlapping _ TypeClassDictionaryInScope{ tcdDependencies = Nothing } = False
            overlapping tcd1 tcd2 = tcdName tcd1 /= tcdName tcd2

            -- Solve the candidate's dependencies, if it has any, after
            -- applying the substitution to their type arguments. Each
            -- subgoal is solved one level deeper than the current goal.
            solveSubgoals :: [(String, Type)] -> Maybe [Constraint] -> StateT Context m (Maybe [DictionaryValue], [(Ident, Constraint)])
            solveSubgoals _ Nothing = return (Nothing, [])
            solveSubgoals subst (Just subgoals) = do
              zipped <- traverse (uncurry (go (work + 1)) . second (map (replaceAllTypeVars subst))) subgoals
              let (dicts, unsolved) = unzip zipped
              return (Just dicts, concat unsolved)
-- | Build the description of a dictionary from its name and the dictionaries
-- solved for its dependencies.
--
-- With no dependency list at all the result is a local dictionary; with an
-- empty list it is a global dictionary; otherwise it is a dictionary that
-- depends on the given sub-dictionaries.
mkDictionary :: Qualified Ident -> Maybe [DictionaryValue] -> DictionaryValue
mkDictionary fnName solvedDeps = case solvedDeps of
  Nothing -> LocalDictionaryValue fnName
  Just [] -> GlobalDictionaryValue fnName
  Just dicts -> DependentDictionaryValue fnName dicts
-- | Turn the description of a dictionary into the expression that evaluates
-- to it.
--
-- Local and global dictionaries become a reference to their name. A
-- dependent dictionary becomes its name applied to each sub-dictionary in
-- turn. A superclass dictionary becomes the superclass accessor of the
-- underlying dictionary, applied to 'valUndefined'.
dictionaryValueToValue :: DictionaryValue -> Expr
dictionaryValueToValue dictValue = case dictValue of
  LocalDictionaryValue fnName -> Var fnName
  GlobalDictionaryValue fnName -> Var fnName
  DependentDictionaryValue fnName dicts ->
    foldl (\fn arg -> App fn (dictionaryValueToValue arg)) (Var fnName) dicts
  SubclassDictionaryValue dict superclassName index ->
    let accessorName =
          C.__superclass_ ++ showQualified runProperName superclassName ++ "_" ++ show index
    in App (Accessor accessorName (dictionaryValueToValue dict)) valUndefined
-- | Check that a substitution is consistent and remove duplicate bindings.
--
-- The bindings are sorted and grouped by variable name. The substitution is
-- accepted only if, for every variable, all the types bound to it unify with
-- one another pairwise. The result keeps the first binding of each variable,
-- ordered by variable name.
verifySubstitution :: [(String, Type)] -> Maybe [(String, Type)]
verifySubstitution subst
  | all consistent byVariable = Just (map head byVariable)
  | otherwise = Nothing
  where
    -- 'groupBy' never yields an empty group, so 'head' above is safe.
    byVariable = groupBy ((==) `on` fst) (sortBy (compare `on` fst) subst)

    consistent = pairwise unifiesWith . map snd
-- | A reference to the identifier named by 'C.undefined', qualified with the
-- module named by 'C.prim'.
valUndefined :: Expr
valUndefined = Var (Qualified (Just primModule) (Ident C.undefined))
  where
    primModule = ModuleName [ProperName C.prim]
-- | Match a type (second argument) against a pattern type (third argument),
-- returning the bindings for the pattern's type variables on success.
--
-- A type variable in the pattern matches anything and is bound to the
-- matched type. Unification variables, skolems and type constructors match
-- only themselves. Type applications match component-wise. Empty rows match
-- each other. Every other combination fails with 'Nothing'.
--
-- The same variable may be bound more than once in the result; checking that
-- such bindings agree is left to the caller.
--
-- The 'ModuleName' is only threaded through the recursive calls.
typeHeadsAreEqual :: ModuleName -> Type -> Type -> Maybe [(String, Type)]
typeHeadsAreEqual _ (TUnknown u1) (TUnknown u2) | u1 == u2 = Just []
typeHeadsAreEqual _ (Skolem _ s1 _ _) (Skolem _ s2 _ _) | s1 == s2 = Just []
typeHeadsAreEqual _ t (TypeVar v) = Just [(v, t)]
typeHeadsAreEqual _ (TypeConstructor c1) (TypeConstructor c2) | c1 == c2 = Just []
typeHeadsAreEqual m (TypeApp h1 t1) (TypeApp h2 t2) =
  (++) <$> typeHeadsAreEqual m h1 h2
       <*> typeHeadsAreEqual m t1 t2
typeHeadsAreEqual _ REmpty REmpty = Just []
typeHeadsAreEqual m r1@RCons{} r2@RCons{} =
  let (s1, r1') = rowToList r1
      (s2, r2') = rowToList r2
      -- Types of the labels that appear in both rows.
      int = [ (t1, t2) | (name, t1) <- s1, (name', t2) <- s2, name == name' ]
      -- Labels that appear in only one of the rows.
      sd1 = [ (name, t1) | (name, t1) <- s1, name `notElem` map fst s2 ]
      sd2 = [ (name, t2) | (name, t2) <- s2, name `notElem` map fst s1 ]
      -- Every shared label has to match. 'traverse' fails as soon as one
      -- label fails. Folding the results with the 'Maybe' monoid would not:
      -- 'Nothing' is that monoid's identity, so a failed label would be
      -- silently dropped whenever some other label succeeded.
      common = concat <$> traverse (uncurry (typeHeadsAreEqual m)) int
  in (++) <$> common
          <*> go sd1 r1' sd2 r2'
  where
    -- Match what is left of each row once the shared labels are removed:
    -- the labels unique to each side, and each side's tail.
    go :: [(String, Type)] -> Type -> [(String, Type)] -> Type -> Maybe [(String, Type)]
    -- Both rows are closed and neither has extra labels.
    go [] REmpty [] REmpty = Just []
    -- No extra labels on the left and its tail is a unification variable:
    -- accepted without bindings, whatever remains on the right.
    go [] (TUnknown _) _ _ = Just []
    -- No extra labels on either side and the tails are the same type
    -- variable or the same skolem.
    go [] (TypeVar v1) [] (TypeVar v2) | v1 == v2 = Just []
    go [] (Skolem _ s1 _ _) [] (Skolem _ s2 _ _) | s1 == s2 = Just []
    -- The right side has no extra labels and ends in a type variable: bind
    -- it to a row built from everything left over on the left.
    go sd r [] (TypeVar v) = Just [(v, rowFromList (sd, r))]
    go _ _ _ _ = Nothing
typeHeadsAreEqual _ _ _ = Nothing
-- | Check that a relation holds between every element of a list and every
-- element that comes after it.
--
-- Lists with fewer than two elements trivially satisfy the check. The
-- relation is always applied with the earlier element as its first argument.
pairwise :: (a -> a -> Bool) -> [a] -> Bool
pairwise related = check
  where
    check (x : rest@(_ : _)) = all (related x) rest && check rest
    check _ = True