module Language.PureScript.TypeChecker.Entailment (
entails
) where
import Data.Function (on)
import Data.List
import Data.Maybe (maybeToList)
import Data.Foldable (foldMap)
import qualified Data.Map as M
import Control.Applicative
import Control.Arrow (Arrow(..))
import Control.Monad.Except
import Language.PureScript.AST
import Language.PureScript.Errors
import Language.PureScript.Environment
import Language.PureScript.Names
import Language.PureScript.TypeChecker.Monad
import Language.PureScript.TypeChecker.Synonyms
import Language.PureScript.TypeChecker.Unify
import Language.PureScript.TypeClassDictionaries
import Language.PureScript.Types
import qualified Language.PureScript.Constants as C
-- |
-- Check that the type class dictionaries in scope entail the given constraint and,
-- if so, return an expression which references (or builds) a suitable dictionary.
--
-- Arguments, in order: the environment, the name of the current module, the
-- dictionaries in scope, the constraint (a class name paired with its type
-- arguments), and a flag which enables searching via superclass relationships.
--
-- Fails in 'Check' with 'OverlappingInstances' when two candidate dictionaries
-- overlap, and with 'NoInstanceFound' when there is no candidate at all.
--
entails :: Environment -> ModuleName -> [TypeClassDictionaryInScope] -> Constraint -> Bool -> Check Expr
entails env moduleName context = solve (sortedNubBy canonicalizeDictionary (filter filterModule context))
  where
    -- Remove duplicates by key. The result is ordered by key and, for equal keys,
    -- the value occurring last in the input wins (the behaviour of 'M.fromList').
    sortedNubBy :: (Ord k) => (v -> k) -> [v] -> [v]
    sortedNubBy f vs = M.elems (M.fromList (map (f &&& id) vs))

    -- Keep dictionaries whose name is qualified with the current module, together
    -- with dictionaries whose name is unqualified. Everything else is dropped.
    filterModule :: TypeClassDictionaryInScope -> Bool
    filterModule (TypeClassDictionaryInScope { tcdName = Qualified (Just mn) _ }) | mn == moduleName = True
    filterModule (TypeClassDictionaryInScope { tcdName = Qualified Nothing _ }) = True
    filterModule _ = False

    -- Collect every candidate dictionary for the goal, then choose one or report an error.
    solve context' (className, tys) trySuperclasses =
      checkOverlaps $ go trySuperclasses className tys
      where
        -- Enumerate (in the list monad) all dictionaries which solve the goal.
        go trySuperclasses' className' tys' =
          -- Candidates coming directly from dictionaries in scope: the class must
          -- match, the goal types must match the dictionary's instance types, and
          -- all of the dictionary's dependencies must be solvable.
          [ mkDictionary (canonicalizeDictionary tcd) args
          | tcd <- context'
          , className' == tcdClassName tcd
          , subst <- maybeToList . (>>= verifySubstitution) . fmap concat $ zipWithM (typeHeadsAreEqual moduleName env) tys' (tcdInstanceTypes tcd)
          , args <- solveSubgoals subst (tcdDependencies tcd) ] ++
          -- Candidates obtained through a superclass relationship: find a class
          -- which lists the goal class among its implied classes, solve for that
          -- class instead, and record which implied-class slot to project out.
          [ SubclassDictionaryValue suDict superclass index
          | trySuperclasses'
          , (subclassName, (args, _, implies)) <- M.toList (typeClasses env)
          , (index, (superclass, suTyArgs)) <- zip [0..] implies
          , className' == superclass
          , subst <- maybeToList . (>>= verifySubstitution) . fmap concat $ zipWithM (typeHeadsAreEqual moduleName env) tys' suTyArgs
            -- Every argument of the subclass must be determined by the substitution.
          , args' <- maybeToList $ mapM ((`lookup` subst) . fst) args
          , suDict <- go True subclassName args' ]

        -- Solve the dependencies of a dictionary after applying the substitution to
        -- their type arguments. 'Nothing' (no dependency list) yields the single
        -- result 'Nothing'; otherwise every combination of solutions is produced.
        solveSubgoals :: [(String, Type)] -> Maybe [Constraint] -> [Maybe [DictionaryValue]]
        solveSubgoals _ Nothing = return Nothing
        solveSubgoals subst (Just subgoals) =
          Just <$> mapM (uncurry (go True) . second (map (replaceAllTypeVars subst))) subgoals

        -- Pick the dictionary representation from the solved dependencies:
        -- no dependency list -> local, empty list -> global, otherwise dependent.
        mkDictionary :: Qualified Ident -> Maybe [DictionaryValue] -> DictionaryValue
        mkDictionary fnName Nothing = LocalDictionaryValue fnName
        mkDictionary fnName (Just []) = GlobalDictionaryValue fnName
        mkDictionary fnName (Just dicts) = DependentDictionaryValue fnName dicts

        -- Turn a dictionary value into an expression.
        dictionaryValueToValue :: DictionaryValue -> Expr
        dictionaryValueToValue (LocalDictionaryValue fnName) = Var fnName
        dictionaryValueToValue (GlobalDictionaryValue fnName) = Var fnName
        -- A dependent dictionary is the dictionary function applied to each dependency in turn.
        dictionaryValueToValue (DependentDictionaryValue fnName dicts) = foldl App (Var fnName) (map dictionaryValueToValue dicts)
        -- A superclass dictionary is read from an accessor on the subclass
        -- dictionary, and the accessed value is applied to 'valUndefined'.
        dictionaryValueToValue (SubclassDictionaryValue dict superclassName index) =
          App (Accessor (C.__superclass_ ++ show superclassName ++ "_" ++ show index)
                        (dictionaryValueToValue dict))
              valUndefined

        -- Ensure that a substitution is consistent: whenever a type variable is
        -- bound more than once, all of its types must pairwise unify. On success,
        -- one binding (the first occurrence) per variable is returned.
        --
        -- The bindings are sorted by variable name before grouping because
        -- 'groupBy' only merges adjacent elements; without the sort, repeated
        -- bindings separated by another variable would escape the check.
        -- 'sortBy' is stable, so the first occurrence stays at the head of its group.
        verifySubstitution :: [(String, Type)] -> Maybe [(String, Type)]
        verifySubstitution subst = do
          let grps = groupBy ((==) `on` fst) . sortBy (compare `on` fst) $ subst
          guard (all (pairwise (unifiesWith env) . map snd) grps)
          return [ binding | (binding : _) <- grps ]

        -- Report an error if any two candidates overlap or if there are no
        -- candidates; otherwise return the expression for the preferred candidate.
        checkOverlaps :: [DictionaryValue] -> Check Expr
        checkOverlaps dicts =
          case [ (d1, d2) | d1 <- dicts, d2 <- dicts, d1 `overlapping` d2 ] of
            ds@(_ : _) -> throwError . errorMessage $ OverlappingInstances className tys (map fst ds)
            _ -> case chooseSimplestDictionaries dicts of
                   [] -> throwError . errorMessage $ NoInstanceFound className tys
                   d : _ -> return $ dictionaryValueToValue d

        -- Prefer dictionaries which do not involve a superclass projection;
        -- if there are none, fall back to the full candidate list.
        chooseSimplestDictionaries :: [DictionaryValue] -> [DictionaryValue]
        chooseSimplestDictionaries ds = case filter isSimpleDictionaryValue ds of
                                          [] -> ds
                                          simple -> simple

        -- A dictionary is simple when no superclass projection occurs in it.
        isSimpleDictionaryValue :: DictionaryValue -> Bool
        isSimpleDictionaryValue SubclassDictionaryValue{} = False
        isSimpleDictionaryValue (DependentDictionaryValue _ ds) = all isSimpleDictionaryValue ds
        isSimpleDictionaryValue _ = True

        -- Decide whether two candidate dictionaries conflict. The same local or
        -- global dictionary never conflicts with itself; two uses of the same
        -- dependent dictionary conflict only if some pair of dependencies does;
        -- superclass projections never conflict; any other pair conflicts.
        -- Note that an equation whose guard fails falls through to the later ones.
        overlapping :: DictionaryValue -> DictionaryValue -> Bool
        overlapping (LocalDictionaryValue nm1) (LocalDictionaryValue nm2) | nm1 == nm2 = False
        overlapping (GlobalDictionaryValue nm1) (GlobalDictionaryValue nm2) | nm1 == nm2 = False
        overlapping (DependentDictionaryValue nm1 ds1) (DependentDictionaryValue nm2 ds2)
          | nm1 == nm2 = or $ zipWith overlapping ds1 ds2
        overlapping SubclassDictionaryValue{} _ = False
        overlapping _ SubclassDictionaryValue{} = False
        overlapping _ _ = True

    -- The argument supplied to the value read from a superclass accessor.
    valUndefined :: Expr
    valUndefined = Var (Qualified (Just (ModuleName [ProperName C.prim])) (Ident C.undefined))
-- |
-- Check whether the heads of two types match and, if they do, return the bindings
-- for the type variables of the second type which were collected along the way.
-- 'Nothing' means that the types do not match.
--
-- Type variables are only bound when they occur in the second type. The returned
-- list may mention a variable more than once; no consistency check is done here.
--
typeHeadsAreEqual :: ModuleName -> Environment -> Type -> Type -> Maybe [(String, Type)]
-- Skolems match only when they carry the same identifier.
typeHeadsAreEqual _ _ (Skolem _ s1 _) (Skolem _ s2 _) | s1 == s2 = Just []
-- A type variable on the right matches anything and is bound to it.
typeHeadsAreEqual _ _ t (TypeVar v) = Just [(v, t)]
typeHeadsAreEqual _ _ (TypeConstructor c1) (TypeConstructor c2) | c1 == c2 = Just []
-- Applications match when both the function part and the argument part match.
typeHeadsAreEqual m e (TypeApp h1 t1) (TypeApp h2 t2) = (++) <$> typeHeadsAreEqual m e h1 h2
                                                             <*> typeHeadsAreEqual m e t1 t2
-- A saturated synonym on the left is expanded before comparing;
-- a failed expansion counts as a mismatch.
typeHeadsAreEqual m e (SaturatedTypeSynonym name args) t2 = case expandTypeSynonym' e name args of
  Left _ -> Nothing
  Right t1 -> typeHeadsAreEqual m e t1 t2
typeHeadsAreEqual _ _ REmpty REmpty = Just []
typeHeadsAreEqual m e r1@(RCons _ _ _) r2@(RCons _ _ _) =
  let (s1, r1') = rowToList r1
      (s2, r2') = rowToList r2
      -- Types of the labels which appear in both rows.
      int = [ (t1, t2) | (name, t1) <- s1, (name', t2) <- s2, name == name' ]
      -- Labels which appear in only one of the two rows.
      sd1 = [ (name, t1) | (name, t1) <- s1, name `notElem` map fst s2 ]
      sd2 = [ (name, t2) | (name, t2) <- s2, name `notElem` map fst s1 ]
  -- Every shared label must match. 'mapM' in 'Maybe' fails as soon as one label
  -- fails; combining the results with the 'Maybe' monoid instead would treat
  -- 'Nothing' as the identity element and silently discard the mismatch.
  in (++) <$> (concat <$> mapM (uncurry (typeHeadsAreEqual m e)) int)
          <*> go sd1 r1' sd2 r2'
  where
    -- Compare what is left of each row: its unshared labels and its tail.
    go :: [(String, Type)] -> Type -> [(String, Type)] -> Type -> Maybe [(String, Type)]
    -- Both rows are closed and have no unshared labels.
    go [] REmpty [] REmpty = Just []
    -- The first row has no unshared labels and an unknown tail: accepted
    -- whatever remains of the second row.
    go [] (TUnknown _) _ _ = Just []
    -- Identical type variable or skolem tails with no unshared labels.
    go [] (TypeVar v1) [] (TypeVar v2) | v1 == v2 = Just []
    go [] (Skolem _ s1 _) [] (Skolem _ s2 _) | s1 == s2 = Just []
    -- The second row ends in a type variable and has no unshared labels:
    -- bind the variable to whatever remains of the first row.
    go sd r [] (TypeVar v) = Just [(v, rowFromList (sd, r))]
    go _ _ _ _ = Nothing
typeHeadsAreEqual _ _ _ _ = Nothing
-- |
-- Check that a binary relation holds between every element of a list and every
-- element which comes after it. Lists with fewer than two elements always pass.
--
pairwise :: (a -> a -> Bool) -> [a] -> Bool
pairwise rel = check
  where
    -- Compare the first element against the rest, then recurse on the rest.
    check (y : rest@(_ : _)) = all (rel y) rest && check rest
    -- Zero or one element: nothing left to compare.
    check _ = True