module Language.PureScript.TypeChecker.Entailment (entails) where
import Prelude ()
import Prelude.Compat
import Data.Function (on)
import Data.List (minimumBy, sortBy, groupBy)
import Data.Maybe (maybeToList, mapMaybe)
import qualified Data.Map as M
import Control.Arrow (Arrow(..))
import Control.Monad.State
import Control.Monad.Error.Class (MonadError(..))
import Control.Monad.Writer.Class (MonadWriter(..))
import Language.PureScript.Crash
import Language.PureScript.AST
import Language.PureScript.Errors
import Language.PureScript.Names
import Language.PureScript.TypeChecker.Unify
import Language.PureScript.TypeClassDictionaries
import Language.PureScript.Types
import qualified Language.PureScript.Constants as C
entails
:: forall m
. (Functor m, Applicative m, MonadError MultipleErrors m, MonadWriter MultipleErrors m)
=> ModuleName
-> M.Map (Maybe ModuleName) (M.Map (Qualified (ProperName 'ClassName)) (M.Map (Qualified Ident) TypeClassDictionaryInScope))
-> Constraint
-> m Expr
entails moduleName context = solve
where
forClassName :: Qualified (ProperName 'ClassName) -> [Type] -> [TypeClassDictionaryInScope]
forClassName cn@(Qualified (Just mn) _) tys = concatMap (findDicts cn) (Nothing : Just mn : map Just (mapMaybe ctorModules tys))
forClassName _ _ = internalError "forClassName: expected qualified class name"
ctorModules :: Type -> Maybe ModuleName
ctorModules (TypeConstructor (Qualified (Just mn) _)) = Just mn
ctorModules (TypeConstructor (Qualified Nothing _)) = internalError "ctorModules: unqualified type name"
ctorModules (TypeApp ty _) = ctorModules ty
ctorModules _ = Nothing
findDicts :: Qualified (ProperName 'ClassName) -> Maybe ModuleName -> [TypeClassDictionaryInScope]
findDicts cn = maybe [] M.elems . (>>= M.lookup cn) . flip M.lookup context
solve :: Constraint -> m Expr
solve (className, tys) = do
dict <- go 0 className tys
return $ dictionaryValueToValue dict
where
go :: Int -> Qualified (ProperName 'ClassName) -> [Type] -> m DictionaryValue
go work className' tys' | work > 1000 = throwError . errorMessage $ PossiblyInfiniteInstance className' tys'
go work className' tys' = do
let instances = do
tcd <- forClassName className' tys'
subst <- maybeToList . (>>= verifySubstitution) . fmap concat $ zipWithM (typeHeadsAreEqual moduleName) tys' (tcdInstanceTypes tcd)
return (subst, tcd)
(subst, tcd) <- unique instances
args <- solveSubgoals subst (tcdDependencies tcd)
return $ foldr (\(superclassName, index) dict -> SubclassDictionaryValue dict superclassName index)
(mkDictionary (tcdName tcd) args)
(tcdPath tcd)
where
unique :: [(a, TypeClassDictionaryInScope)] -> m (a, TypeClassDictionaryInScope)
unique [] = throwError . errorMessage $ NoInstanceFound className' tys'
unique [a] = return a
unique tcds | pairwise overlapping (map snd tcds) = do
tell . errorMessage $ OverlappingInstances className' tys' (map (tcdName . snd) tcds)
return (head tcds)
| otherwise = return (minimumBy (compare `on` length . tcdPath . snd) tcds)
overlapping :: TypeClassDictionaryInScope -> TypeClassDictionaryInScope -> Bool
overlapping TypeClassDictionaryInScope{ tcdPath = _ : _ } _ = False
overlapping _ TypeClassDictionaryInScope{ tcdPath = _ : _ } = False
overlapping TypeClassDictionaryInScope{ tcdDependencies = Nothing } _ = False
overlapping _ TypeClassDictionaryInScope{ tcdDependencies = Nothing } = False
overlapping tcd1 tcd2 = tcdName tcd1 /= tcdName tcd2
solveSubgoals :: [(String, Type)] -> Maybe [Constraint] -> m (Maybe [DictionaryValue])
solveSubgoals _ Nothing = return Nothing
solveSubgoals subst (Just subgoals) = do
dict <- traverse (uncurry (go (work + 1)) . second (map (replaceAllTypeVars subst))) subgoals
return $ Just dict
mkDictionary :: Qualified Ident -> Maybe [DictionaryValue] -> DictionaryValue
mkDictionary fnName Nothing = LocalDictionaryValue fnName
mkDictionary fnName (Just []) = GlobalDictionaryValue fnName
mkDictionary fnName (Just dicts) = DependentDictionaryValue fnName dicts
dictionaryValueToValue :: DictionaryValue -> Expr
dictionaryValueToValue (LocalDictionaryValue fnName) = Var fnName
dictionaryValueToValue (GlobalDictionaryValue fnName) = Var fnName
dictionaryValueToValue (DependentDictionaryValue fnName dicts) = foldl App (Var fnName) (map dictionaryValueToValue dicts)
dictionaryValueToValue (SubclassDictionaryValue dict superclassName index) =
App (Accessor (C.__superclass_ ++ showQualified runProperName superclassName ++ "_" ++ show index)
(dictionaryValueToValue dict))
valUndefined
verifySubstitution :: [(String, Type)] -> Maybe [(String, Type)]
verifySubstitution subst = do
let grps = groupBy ((==) `on` fst) . sortBy (compare `on` fst) $ subst
guard (all (pairwise unifiesWith . map snd) grps)
return $ map head grps
valUndefined :: Expr
valUndefined = Var (Qualified (Just (ModuleName [ProperName C.prim])) (Ident C.undefined))
typeHeadsAreEqual :: ModuleName -> Type -> Type -> Maybe [(String, Type)]
typeHeadsAreEqual _ (Skolem _ s1 _ _) (Skolem _ s2 _ _) | s1 == s2 = Just []
typeHeadsAreEqual _ t (TypeVar v) = Just [(v, t)]
typeHeadsAreEqual _ (TypeConstructor c1) (TypeConstructor c2) | c1 == c2 = Just []
typeHeadsAreEqual m (TypeApp h1 t1) (TypeApp h2 t2) = (++) <$> typeHeadsAreEqual m h1 h2
<*> typeHeadsAreEqual m t1 t2
typeHeadsAreEqual _ REmpty REmpty = Just []
typeHeadsAreEqual m r1@RCons{} r2@RCons{} =
let (s1, r1') = rowToList r1
(s2, r2') = rowToList r2
int = [ (t1, t2) | (name, t1) <- s1, (name', t2) <- s2, name == name' ]
sd1 = [ (name, t1) | (name, t1) <- s1, name `notElem` map fst s2 ]
sd2 = [ (name, t2) | (name, t2) <- s2, name `notElem` map fst s1 ]
in (++) <$> foldMap (uncurry (typeHeadsAreEqual m)) int
<*> go sd1 r1' sd2 r2'
where
go :: [(String, Type)] -> Type -> [(String, Type)] -> Type -> Maybe [(String, Type)]
go [] REmpty [] REmpty = Just []
go [] (TUnknown _) _ _ = Just []
go [] (TypeVar v1) [] (TypeVar v2) | v1 == v2 = Just []
go [] (Skolem _ s1 _ _) [] (Skolem _ s2 _ _) | s1 == s2 = Just []
go sd r [] (TypeVar v) = Just [(v, rowFromList (sd, r))]
go _ _ _ _ = Nothing
typeHeadsAreEqual _ _ _ = Nothing
pairwise :: (a -> a -> Bool) -> [a] -> Bool
pairwise _ [] = True
pairwise _ [_] = True
pairwise p (x : xs) = all (p x) xs && pairwise p xs