module Language.PureScript.TypeChecker.Entailment (
entails
) where
import Data.Function (on)
import Data.List
import Data.Maybe (maybeToList)
import qualified Data.Map as M
import Control.Applicative
import Control.Arrow (Arrow(..))
import Control.Monad.Error
import Language.PureScript.AST
import Language.PureScript.Environment
import Language.PureScript.Names
import Language.PureScript.Pretty
import Language.PureScript.TypeChecker.Monad
import Language.PureScript.TypeChecker.Synonyms
import Language.PureScript.TypeChecker.Unify
import Language.PureScript.TypeClassDictionaries
import Language.PureScript.Types
import qualified Language.PureScript.Constants as C
-- |
-- A candidate dictionary found while solving a type class constraint.
-- Constructor order matters: the derived 'Ord' instance depends on it.
--
data DictionaryValue
  = LocalDictionaryValue (Qualified Ident)
    -- ^ A dictionary with no dependency information ('tcdDependencies' is 'Nothing');
    --   rendered as "Dictionary in scope" in error messages
  | GlobalDictionaryValue (Qualified Ident)
    -- ^ An instance dictionary whose list of dependencies is empty
  | DependentDictionaryValue (Qualified Ident) [DictionaryValue]
    -- ^ An instance dictionary function applied to the dictionaries which solve its subgoals
  | SubclassDictionaryValue DictionaryValue (Qualified ProperName) Integer
    -- ^ A superclass dictionary read from a subclass dictionary: the subclass dictionary,
    --   the superclass name, and the position of that superclass in the subclass's
    --   list of superclasses
  deriving (Eq, Ord, Show)
-- |
-- Solve the type class constraint @(className, tys)@ and return an expression which
-- evaluates to a dictionary for it.
--
-- Only dictionaries from the given context whose names are unqualified, or qualified
-- with the given module name, are considered; candidates are deduplicated by
-- 'canonicalizeDictionary'. When the final 'Bool' argument is 'True', dictionaries
-- obtained through a superclass of another class are also considered at the top
-- level (subgoals always consider them).
--
-- Fails in 'Check' if overlapping instances are found, or if no instance is found.
--
entails :: Environment -> ModuleName -> [TypeClassDictionaryInScope] -> (Qualified ProperName, [Type]) -> Bool -> Check Expr
entails env moduleName context = solve (sortedNubBy canonicalizeDictionary (filter filterModule context))
  where
    -- Deduplicate by key. 'M.fromList' keeps the last value for a repeated key and
    -- 'M.elems' returns the survivors in ascending key order.
    sortedNubBy :: (Ord k) => (v -> k) -> [v] -> [v]
    sortedNubBy f vs = M.elems (M.fromList (map (f &&& id) vs))

    -- Keep dictionaries whose name is unqualified or qualified with the current module.
    filterModule :: TypeClassDictionaryInScope -> Bool
    filterModule (TypeClassDictionaryInScope { tcdName = Qualified (Just mn) _ }) | mn == moduleName = True
    filterModule (TypeClassDictionaryInScope { tcdName = Qualified Nothing _ }) = True
    filterModule _ = False

    -- Collect every candidate dictionary for the constraint, then choose one (or
    -- report an error) with 'checkOverlaps'.
    solve :: [TypeClassDictionaryInScope] -> (Qualified ProperName, [Type]) -> Bool -> Check Expr
    solve context' (className, tys) trySuperclasses =
      checkOverlaps $ go trySuperclasses className tys
      where
        -- All candidate dictionaries for the class className' at the types tys'.
        go :: Bool -> Qualified ProperName -> [Type] -> [DictionaryValue]
        go trySuperclasses' className' tys' =
          -- Instances declared for the class itself
          [ mkDictionary (canonicalizeDictionary tcd) args
          | tcd <- context'
            -- the instance must be for the class being solved
          , className' == tcdClassName tcd
            -- the constraint types must match the instance head, with consistent bindings
          , subst <- maybeToList . (>>= verifySubstitution) . fmap concat $ zipWithM (typeHeadsAreEqual moduleName env) tys' (tcdInstanceTypes tcd)
            -- every subgoal of the instance must be solvable
          , args <- solveSubgoals subst (tcdDependencies tcd) ] ++
          -- Dictionaries read from a subclass dictionary through a superclass accessor
          [ SubclassDictionaryValue suDict superclass index
          | trySuperclasses'
          , (subclassName, (args, _, implies)) <- M.toList (typeClasses env)
            -- try each superclass of each known class, remembering its position
          , (index, (superclass, suTyArgs)) <- zip [0..] implies
          , className' == superclass
          , subst <- maybeToList . (>>= verifySubstitution) . fmap concat $ zipWithM (typeHeadsAreEqual moduleName env) tys' suTyArgs
            -- every type argument of the subclass must be determined by the match
          , args' <- maybeToList $ mapM ((`lookup` subst) . fst) args
          , suDict <- go True subclassName args' ]

        -- Solve the subgoals of an instance after applying the substitution obtained
        -- from matching its head. When there is no dependency information ('Nothing')
        -- the result is a single 'Nothing'; otherwise the list monad yields one result
        -- for every way of solving all of the subgoals.
        solveSubgoals :: [(String, Type)] -> Maybe [(Qualified ProperName, [Type])] -> [Maybe [DictionaryValue]]
        solveSubgoals _ Nothing = return Nothing
        solveSubgoals subst (Just subgoals) = do
          dict <- mapM (uncurry (go True) . second (map (replaceAllTypeVars subst))) subgoals
          return $ Just dict

        -- Build a dictionary value from the (optional) subgoal dictionaries.
        mkDictionary :: Qualified Ident -> Maybe [DictionaryValue] -> DictionaryValue
        mkDictionary fnName Nothing = LocalDictionaryValue fnName
        mkDictionary fnName (Just []) = GlobalDictionaryValue fnName
        mkDictionary fnName (Just dicts) = DependentDictionaryValue fnName dicts

        -- Turn a dictionary value into an expression. Dependent dictionaries apply the
        -- instance function to their sub-dictionaries; superclass dictionaries read the
        -- accessor field "<C.__superclass_><class>_<index>" of the subclass dictionary
        -- and apply it to 'valUndefined'.
        dictionaryValueToValue :: DictionaryValue -> Expr
        dictionaryValueToValue (LocalDictionaryValue fnName) = Var fnName
        dictionaryValueToValue (GlobalDictionaryValue fnName) = Var fnName
        dictionaryValueToValue (DependentDictionaryValue fnName dicts) = foldl App (Var fnName) (map dictionaryValueToValue dicts)
        dictionaryValueToValue (SubclassDictionaryValue dict superclassName index) =
          App (Accessor (C.__superclass_ ++ show superclassName ++ "_" ++ show index)
                        (dictionaryValueToValue dict))
              valUndefined

        -- Check that every type variable bound more than once is bound to types which
        -- pairwise unify, and keep a single binding (the first) per variable.
        -- The bindings are sorted by variable name first: 'groupBy' only groups
        -- adjacent elements, so without sorting, repeated bindings which are not next
        -- to each other (e.g. from an instance head of the form @C a b a@) would never
        -- be compared. 'sortBy' is stable, so the first binding of each variable is kept.
        verifySubstitution :: [(String, Type)] -> Maybe [(String, Type)]
        verifySubstitution subst = do
          let grps = groupBy ((==) `on` fst) (sortBy (compare `on` fst) subst)
          guard (all (pairwise (unifiesWith env) . map snd) grps)
          return [ binding | binding : _ <- grps ]

        -- Fail if any two candidates overlap; otherwise return the first of the
        -- simplest candidates, failing if there are no candidates at all.
        checkOverlaps :: [DictionaryValue] -> Check Expr
        checkOverlaps dicts =
          case [ (d1, d2) | d1 <- dicts, d2 <- dicts, d1 `overlapping` d2 ] of
            (d1, d2) : _ -> throwError . strMsg $ unlines
              [ "Overlapping instances found for " ++ show className ++ " " ++ unwords (map prettyPrintTypeAtom tys) ++ "."
              , "For example:"
              , prettyPrintDictionaryValue d1
              , "and:"
              , prettyPrintDictionaryValue d2
              ]
            _ ->
              case chooseSimplestDictionaries dicts of
                [] -> throwError . strMsg $
                  "No instance found for " ++ show className ++ " " ++ unwords (map prettyPrintTypeAtom tys)
                d : _ -> return $ dictionaryValueToValue d

        -- Prefer candidates which do not go through a superclass accessor; if there are
        -- none, keep every candidate.
        chooseSimplestDictionaries :: [DictionaryValue] -> [DictionaryValue]
        chooseSimplestDictionaries ds = case filter isSimpleDictionaryValue ds of
          [] -> ds
          simple -> simple

        -- A dictionary is simple if no superclass accessor occurs anywhere inside it.
        isSimpleDictionaryValue :: DictionaryValue -> Bool
        isSimpleDictionaryValue SubclassDictionaryValue{} = False
        isSimpleDictionaryValue (DependentDictionaryValue _ ds) = all isSimpleDictionaryValue ds
        isSimpleDictionaryValue _ = True

        -- Two candidates overlap unless they are the same local or global dictionary,
        -- or the same dependent dictionary whose corresponding sub-dictionaries do not
        -- overlap. Candidates obtained through a superclass never count as overlapping.
        overlapping :: DictionaryValue -> DictionaryValue -> Bool
        overlapping (LocalDictionaryValue nm1) (LocalDictionaryValue nm2) | nm1 == nm2 = False
        overlapping (GlobalDictionaryValue nm1) (GlobalDictionaryValue nm2) | nm1 == nm2 = False
        overlapping (DependentDictionaryValue nm1 ds1) (DependentDictionaryValue nm2 ds2)
          | nm1 == nm2 = or $ zipWith overlapping ds1 ds2
        overlapping SubclassDictionaryValue{} _ = False
        overlapping _ SubclassDictionaryValue{} = False
        overlapping _ _ = True

        -- Render a dictionary value as an indented tree for error messages.
        prettyPrintDictionaryValue :: DictionaryValue -> String
        prettyPrintDictionaryValue = unlines . indented 0
          where
            indented :: Int -> DictionaryValue -> [String]
            indented n (LocalDictionaryValue _) = [spaces n ++ "Dictionary in scope"]
            indented n (GlobalDictionaryValue nm) = [spaces n ++ show nm]
            indented n (DependentDictionaryValue nm args) = (spaces n ++ show nm ++ " via") : concatMap (indented (n + 2)) args
            indented n (SubclassDictionaryValue sup nm _) = (spaces n ++ show nm ++ " via superclass") : indented (n + 2) sup

            spaces :: Int -> String
            spaces n = replicate n ' ' ++ "- "
-- |
-- A reference to the value named by 'C.undefined' in the module named by 'C.prim'.
-- It is used as the dummy argument when applying a superclass accessor to a dictionary.
--
valUndefined :: Expr
valUndefined = Var undefinedIdent
  where
    primModule = ModuleName [ProperName C.prim]
    undefinedIdent = Qualified (Just primModule) (Ident C.undefined)
-- |
-- Match a constraint type (first argument) against an instance or superclass head
-- (second argument). Returns the bindings for the type variables of the head, or
-- 'Nothing' if the two do not match. A variable may be bound more than once; the
-- consistency of repeated bindings is not checked here.
--
-- Saturated type synonyms in the first argument are expanded before matching; a
-- synonym which fails to expand never matches. The 'ModuleName' is not used.
--
typeHeadsAreEqual :: ModuleName -> Environment -> Type -> Type -> Maybe [(String, Type)]
typeHeadsAreEqual _ e = match
  where
    -- A variable in the head matches anything. This case is tried before synonym
    -- expansion, so a synonym matched against a variable is bound unexpanded.
    match t (TypeVar v) = Just [(v, t)]
    match (Skolem _ s1 _) (Skolem _ s2 _) | s1 == s2 = Just []
    match (TypeConstructor c1) (TypeConstructor c2) | c1 == c2 = Just []
    match (TypeApp h1 a1) (TypeApp h2 a2) = liftA2 (++) (match h1 h2) (match a1 a2)
    match (SaturatedTypeSynonym name args) t2 =
      either (const Nothing) (`match` t2) (expandTypeSynonym' e name args)
    match _ _ = Nothing
-- |
-- Check that a relation holds for every pair @(x, y)@ where @x@ occurs before @y@ in
-- the list. Pairs are tested in list order and the check stops at the first failure.
-- Lists with fewer than two elements always pass.
--
pairwise :: (a -> a -> Bool) -> [a] -> Bool
pairwise p xs = and [ p x y | x : rest <- tails xs, y <- rest ]