module Language.PureScript.TypeChecker.Entailment (
entails
) where
import Data.Function (on)
import Data.List
import Data.Maybe (maybeToList)
import Data.Foldable (foldMap)
import qualified Data.Map as M
import Control.Applicative
import Control.Arrow (Arrow(..))
import Control.Monad.Except
import Language.PureScript.AST
import Language.PureScript.Errors
import Language.PureScript.Environment
import Language.PureScript.Names
import Language.PureScript.Pretty
import Language.PureScript.TypeChecker.Monad
import Language.PureScript.TypeChecker.Synonyms
import Language.PureScript.TypeChecker.Unify
import Language.PureScript.TypeClassDictionaries
import Language.PureScript.Types
import qualified Language.PureScript.Constants as C
-- |
-- A candidate dictionary found while solving a type class constraint.
-- 'entails' converts the chosen candidate into an 'Expr'.
--
data DictionaryValue
  -- | Built from a matching dictionary that carries no dependency list;
  --   rendered as a plain variable reference.
  = LocalDictionaryValue (Qualified Ident)
  -- | Built from a matching dictionary whose dependency list is empty;
  --   rendered as a plain variable reference.
  | GlobalDictionaryValue (Qualified Ident)
  -- | A dictionary function applied, in order, to the dictionaries solving
  --   its own constraints.
  | DependentDictionaryValue (Qualified Ident) [DictionaryValue]
  -- | A superclass dictionary projected out of a subclass dictionary.
  --   Fields: the subclass dictionary, the superclass name, and the index of
  --   that superclass among the subclass's declared superclasses.
  | SubclassDictionaryValue DictionaryValue (Qualified ProperName) Integer
  deriving (Eq, Ord, Show)
-- |
-- Check that the type class dictionaries in scope entail the given constraint
-- and, if so, return an expression for a dictionary which satisfies it.
--
-- Only dictionaries whose names are qualified by the current module, or are
-- unqualified, are considered; duplicates (by canonical name) are removed.
--
-- The final 'Bool' says whether the top-level goal may also be satisfied by
-- projecting a superclass dictionary out of a subclass dictionary. Subgoals
-- are always allowed to do so.
--
-- Fails in 'Check' when overlapping instances are found, or when no instance
-- is found at all.
--
entails :: Environment -> ModuleName -> [TypeClassDictionaryInScope] -> Constraint -> Bool -> Check Expr
entails env moduleName context = solve (sortedNubBy canonicalizeDictionary (filter filterModule context))
  where
    -- Deduplicate by key. Data.Map.fromList keeps the last value for a
    -- repeated key, and the result comes back ordered by key.
    sortedNubBy :: (Ord k) => (v -> k) -> [v] -> [v]
    sortedNubBy f vs = M.elems (M.fromList (map (f &&& id) vs))

    -- Keep dictionaries named in the current module, or with no qualifier.
    filterModule :: TypeClassDictionaryInScope -> Bool
    filterModule (TypeClassDictionaryInScope { tcdName = Qualified (Just mn) _ }) | mn == moduleName = True
    filterModule (TypeClassDictionaryInScope { tcdName = Qualified Nothing _ }) = True
    filterModule _ = False

    solve :: [TypeClassDictionaryInScope] -> Constraint -> Bool -> Check Expr
    solve context' (className, tys) trySuperclasses =
      checkOverlaps $ go trySuperclasses className tys
      where
        -- Enumerate every dictionary which satisfies the goal.
        go trySuperclasses' className' tys' =
          -- 1. Instances whose heads match the goal types, with their own
          --    constraints solved recursively.
          [ mkDictionary (canonicalizeDictionary tcd) args
          | tcd <- context'
          , className' == tcdClassName tcd
          , subst <- maybeToList . (>>= verifySubstitution) . fmap concat $ zipWithM (typeHeadsAreEqual moduleName env) tys' (tcdInstanceTypes tcd)
          , args <- solveSubgoals subst (tcdDependencies tcd) ] ++
          -- 2. Superclass dictionaries projected out of a dictionary for some
          --    class which lists className' among its superclasses.
          [ SubclassDictionaryValue suDict superclass index
          | trySuperclasses'
          , (subclassName, (args, _, implies)) <- M.toList (typeClasses env)
          , (index, (superclass, suTyArgs)) <- zip [0..] implies
          , className' == superclass
          , subst <- maybeToList . (>>= verifySubstitution) . fmap concat $ zipWithM (typeHeadsAreEqual moduleName env) tys' suTyArgs
          , args' <- maybeToList $ mapM ((`lookup` subst) . fst) args
          , suDict <- go True subclassName args' ]

        -- Solve an instance's constraints after applying the substitution.
        -- A missing dependency list yields a single 'Nothing' result.
        solveSubgoals :: [(String, Type)] -> Maybe [Constraint] -> [Maybe [DictionaryValue]]
        solveSubgoals _ Nothing = return Nothing
        solveSubgoals subst (Just subgoals) = do
          dict <- mapM (uncurry (go True) . second (map (replaceAllTypeVars subst))) subgoals
          return $ Just dict

        -- Classify a solved dictionary by the shape of its dependency list.
        mkDictionary :: Qualified Ident -> Maybe [DictionaryValue] -> DictionaryValue
        mkDictionary fnName Nothing = LocalDictionaryValue fnName
        mkDictionary fnName (Just []) = GlobalDictionaryValue fnName
        mkDictionary fnName (Just dicts) = DependentDictionaryValue fnName dicts

        -- Turn a DictionaryValue into an Expr.
        dictionaryValueToValue :: DictionaryValue -> Expr
        dictionaryValueToValue (LocalDictionaryValue fnName) = Var fnName
        dictionaryValueToValue (GlobalDictionaryValue fnName) = Var fnName
        dictionaryValueToValue (DependentDictionaryValue fnName dicts) = foldl App (Var fnName) (map dictionaryValueToValue dicts)
        dictionaryValueToValue (SubclassDictionaryValue dict superclassName index) =
          App (Accessor (C.__superclass_ ++ show superclassName ++ "_" ++ show index)
                        (dictionaryValueToValue dict))
              valUndefined

        -- Check that all bindings of each type variable unify with one
        -- another, and keep one binding per variable.
        verifySubstitution :: [(String, Type)] -> Maybe [(String, Type)]
        verifySubstitution subst = do
          -- groupBy only merges *adjacent* elements, so sort by variable
          -- first; otherwise non-adjacent bindings of the same variable
          -- (e.g. from an instance head like `T a b a`) escape the check.
          -- sortBy is stable, so the first binding of each variable is kept.
          let grps = groupBy ((==) `on` fst) (sortBy (compare `on` fst) subst)
          guard (all (pairwise (unifiesWith env) . map snd) grps)
          return [ binding | (binding : _) <- grps ]

        -- Report overlapping candidates, otherwise pick the simplest one.
        checkOverlaps :: [DictionaryValue] -> Check Expr
        checkOverlaps dicts =
          -- 'overlapping' is symmetric and irreflexive, so testing each
          -- unordered pair once finds the same first pair as testing all.
          case [ (d1, d2) | (d1 : rest) <- tails dicts, d2 <- rest, d1 `overlapping` d2 ] of
            (d1, d2) : _ -> throwError . strMsg $ unlines
              [ "Overlapping instances found for " ++ show className ++ " " ++ unwords (map prettyPrintTypeAtom tys) ++ "."
              , "For example:"
              , prettyPrintDictionaryValue d1
              , "and:"
              , prettyPrintDictionaryValue d2
              ]
            _ -> case chooseSimplestDictionaries dicts of
              [] -> throwError . strMsg $
                "No instance found for " ++ show className ++ " " ++ unwords (map prettyPrintTypeAtom tys)
              d : _ -> return $ dictionaryValueToValue d

        -- Prefer candidates which do not go through a superclass projection;
        -- fall back to all candidates when none qualify.
        chooseSimplestDictionaries :: [DictionaryValue] -> [DictionaryValue]
        chooseSimplestDictionaries ds = case filter isSimpleDictionaryValue ds of
          [] -> ds
          simple -> simple

        isSimpleDictionaryValue :: DictionaryValue -> Bool
        isSimpleDictionaryValue SubclassDictionaryValue{} = False
        isSimpleDictionaryValue (DependentDictionaryValue _ ds) = all isSimpleDictionaryValue ds
        isSimpleDictionaryValue _ = True

        -- Two candidates overlap unless they are built from the same names
        -- all the way down. Superclass projections never count as overlaps.
        overlapping :: DictionaryValue -> DictionaryValue -> Bool
        overlapping (LocalDictionaryValue nm1) (LocalDictionaryValue nm2) | nm1 == nm2 = False
        overlapping (GlobalDictionaryValue nm1) (GlobalDictionaryValue nm2) | nm1 == nm2 = False
        overlapping (DependentDictionaryValue nm1 ds1) (DependentDictionaryValue nm2 ds2)
          | nm1 == nm2 = or $ zipWith overlapping ds1 ds2
        overlapping SubclassDictionaryValue{} _ = False
        overlapping _ SubclassDictionaryValue{} = False
        overlapping _ _ = True

        -- Render a DictionaryValue as an indented tree for error messages.
        prettyPrintDictionaryValue :: DictionaryValue -> String
        prettyPrintDictionaryValue = unlines . indented 0
          where
            indented n (LocalDictionaryValue _) = [spaces n ++ "Dictionary in scope"]
            indented n (GlobalDictionaryValue nm) = [spaces n ++ show nm]
            indented n (DependentDictionaryValue nm args) = (spaces n ++ show nm ++ " via") : concatMap (indented (n + 2)) args
            indented n (SubclassDictionaryValue sup nm _) = (spaces n ++ show nm ++ " via superclass") : indented (n + 2) sup
            spaces n = replicate n ' ' ++ "- "

    -- Placeholder argument passed to superclass accessor functions.
    valUndefined :: Expr
    valUndefined = Var (Qualified (Just (ModuleName [ProperName C.prim])) (Ident C.undefined))
-- |
-- Check whether the heads of two types match for the purposes of type class
-- instance lookup. Type variables in the second type match anything and are
-- recorded in the returned substitution. Type synonyms in the first type are
-- expanded before matching.
--
-- Returns 'Nothing' when the heads cannot be matched. The substitution may
-- bind the same variable more than once; callers must check that such
-- bindings agree.
--
typeHeadsAreEqual :: ModuleName -> Environment -> Type -> Type -> Maybe [(String, Type)]
typeHeadsAreEqual _ _ (Skolem _ s1 _) (Skolem _ s2 _) | s1 == s2 = Just []
typeHeadsAreEqual _ _ t (TypeVar v) = Just [(v, t)]
typeHeadsAreEqual _ _ (TypeConstructor c1) (TypeConstructor c2) | c1 == c2 = Just []
typeHeadsAreEqual m e (TypeApp h1 t1) (TypeApp h2 t2) =
  (++) <$> typeHeadsAreEqual m e h1 h2
       <*> typeHeadsAreEqual m e t1 t2
typeHeadsAreEqual m e (SaturatedTypeSynonym name args) t2 = case expandTypeSynonym' e name args of
  Left _ -> Nothing
  Right t1 -> typeHeadsAreEqual m e t1 t2
typeHeadsAreEqual _ _ REmpty REmpty = Just []
typeHeadsAreEqual m e r1@(RCons _ _ _) r2@(RCons _ _ _) =
  let (s1, r1') = rowToList r1
      (s2, r2') = rowToList r2
      -- Field types sharing a label, paired up
      int = [ (t1, t2) | (name, t1) <- s1, (name', t2) <- s2, name == name' ]
      -- Fields present on only one side
      sd1 = [ (name, t1) | (name, t1) <- s1, name `notElem` map fst s2 ]
      sd2 = [ (name, t2) | (name, t2) <- s2, name `notElem` map fst s1 ]
      -- Every shared field must match. (foldMap over the Maybe monoid would
      -- treat a mismatch, Nothing, as the identity and silently skip it.)
      shared = concat <$> mapM (uncurry (typeHeadsAreEqual m e)) int
  in (++) <$> shared <*> go sd1 r1' sd2 r2'
  where
    -- Match the leftover fields and row tails of both sides.
    go :: [(String, Type)] -> Type -> [(String, Type)] -> Type -> Maybe [(String, Type)]
    go [] REmpty [] REmpty = Just []
    go [] (TUnknown _) _ _ = Just []
    go [] (TypeVar v1) [] (TypeVar v2) | v1 == v2 = Just []
    go [] (Skolem _ s1 _) [] (Skolem _ s2 _) | s1 == s2 = Just []
    -- A row variable on the right absorbs the left side's leftover fields and tail
    go sd r [] (TypeVar v) = Just [(v, rowFromList (sd, r))]
    go _ _ _ _ = Nothing
typeHeadsAreEqual _ _ _ _ = Nothing
-- |
-- Check that a binary predicate holds for every pair of list elements, taking
-- each pair in list order (earlier element as the first argument). Lists with
-- fewer than two elements trivially pass.
--
pairwise :: (a -> a -> Bool) -> [a] -> Bool
pairwise p xs = and [ p x y | (x : later) <- tails xs, y <- later ]