module Language.PureScript.TypeChecker.Entailment (
entails
) where
import Data.Function (on)
import Data.List
import Data.Maybe (maybeToList)
import Data.Foldable (foldMap)
import qualified Data.Map as M
import Control.Applicative
import Control.Arrow (Arrow(..))
import Control.Monad.State
import Control.Monad.Error.Class (MonadError(..))
import Language.PureScript.AST
import Language.PureScript.Errors
import Language.PureScript.Environment
import Language.PureScript.Names
import Language.PureScript.TypeChecker.Monad
import Language.PureScript.TypeChecker.Synonyms
import Language.PureScript.TypeChecker.Unify
import Language.PureScript.TypeClassDictionaries
import Language.PureScript.Types
import qualified Language.PureScript.Constants as C
-- | Count of solver steps taken along a single branch of the instance
-- search in 'entails'. 'Num' is derived so the counter can be incremented
-- with @(1 +)@ and compared against a numeric literal.
newtype Work = Work Integer
  deriving (Eq, Ord, Show, Num)
-- |
-- Check that the type class dictionaries in @context@ entail the given
-- constraint and, if so, return an expression for a dictionary which
-- satisfies it.
--
-- @context@ is indexed first by a @Maybe ModuleName@ key and then by class
-- name. Only the dictionaries under the @Nothing@ key and under the current
-- module's key are searched. When the 'Bool' argument is 'True', dictionaries
-- may also be obtained by projecting a superclass out of a subclass
-- dictionary.
--
-- Fails with 'OverlappingInstances' when overlapping dictionaries satisfy the
-- constraint, and with 'NoInstanceFound' when none does.
entails :: Environment -> ModuleName -> M.Map (Maybe ModuleName) (M.Map (Qualified ProperName) (M.Map (Qualified Ident) TypeClassDictionaryInScope)) -> Constraint -> Bool -> Check Expr
entails env moduleName context = solve
  where
    -- Every dictionary in scope for a class: those under the 'Nothing' key
    -- first, followed by those under the current module.
    forClassName :: Qualified ProperName -> [TypeClassDictionaryInScope]
    forClassName cn = findDicts cn Nothing ++ findDicts cn (Just moduleName)

    -- The dictionaries for a class under one key of the context. The result
    -- is empty if either map lookup misses.
    findDicts :: Qualified ProperName -> Maybe ModuleName -> [TypeClassDictionaryInScope]
    findDicts cn = maybe [] M.elems . (>>= M.lookup cn) . flip M.lookup context

    solve :: Constraint -> Bool -> Check Expr
    solve (className, tys) trySuperclasses = do
        -- The search runs in the list monad, so this enumerates every
        -- dictionary that can be built, each branch starting at zero work.
        let dicts = flip evalStateT (Work 0) $ go trySuperclasses className tys
        checkOverlaps dicts
      where
        -- Search for dictionaries of className' at tys'. Direct instances are
        -- tried before superclass projections. Each call increments the work
        -- counter, and a branch is abandoned once it has taken 1000 steps.
        -- This bounds the depth of every branch of the search.
        go :: Bool -> Qualified ProperName -> [Type] -> StateT Work [] DictionaryValue
        go trySuperclasses' className' tys' = do
            workDone <- get
            guard $ workDone < 1000
            modify (1 +)
            directInstances <|> superclassInstances
          where
            -- Use an instance of the class itself.
            directInstances :: StateT Work [] DictionaryValue
            directInstances = do
              tcd <- lift $ forClassName className'
              -- The goal's types must match the instance head, consistently
              subst <- lift . maybeToList . (>>= verifySubstitution) . fmap concat $ zipWithM (typeHeadsAreEqual moduleName env) tys' (tcdInstanceTypes tcd)
              -- Recursively solve the instance's own constraints, if any
              args <- solveSubgoals subst (tcdDependencies tcd)
              return $ mkDictionary (canonicalizeDictionary tcd) args

            -- Project the class out of a dictionary for one of its subclasses.
            superclassInstances :: StateT Work [] DictionaryValue
            superclassInstances = do
              guard trySuperclasses'
              -- Try each declared superclass of every known class, keeping
              -- only those which are the class being solved
              (subclassName, (args, _, implies)) <- lift $ M.toList (typeClasses env)
              (index, (superclass, suTyArgs)) <- lift $ zip [0..] implies
              guard $ className' == superclass
              -- The goal's types must match the superclass arguments
              subst <- lift . maybeToList . (>>= verifySubstitution) . fmap concat $ zipWithM (typeHeadsAreEqual moduleName env) tys' suTyArgs
              -- Every subclass parameter must be bound by that match
              args' <- lift . maybeToList $ mapM ((`lookup` subst) . fst) args
              suDict <- go True subclassName args'
              return $ SubclassDictionaryValue suDict superclass index

        -- Solve an instance's constraints after applying the substitution
        -- found for its head. 'Nothing' (no dependency list) is passed through.
        solveSubgoals :: [(String, Type)] -> Maybe [Constraint] -> StateT Work [] (Maybe [DictionaryValue])
        solveSubgoals _ Nothing = return Nothing
        solveSubgoals subst (Just subgoals) = do
          dict <- mapM (uncurry (go True) . second (map (replaceAllTypeVars subst))) subgoals
          return $ Just dict

        -- Build a dictionary reference from the solved subgoals, if any.
        mkDictionary :: Qualified Ident -> Maybe [DictionaryValue] -> DictionaryValue
        mkDictionary fnName Nothing = LocalDictionaryValue fnName
        mkDictionary fnName (Just []) = GlobalDictionaryValue fnName
        mkDictionary fnName (Just dicts) = DependentDictionaryValue fnName dicts

        -- Translate a dictionary reference into an expression. A dependent
        -- dictionary is applied to its argument dictionaries. A superclass
        -- projection reads the superclass accessor and applies the result to
        -- 'valUndefined'.
        dictionaryValueToValue :: DictionaryValue -> Expr
        dictionaryValueToValue (LocalDictionaryValue fnName) = Var fnName
        dictionaryValueToValue (GlobalDictionaryValue fnName) = Var fnName
        dictionaryValueToValue (DependentDictionaryValue fnName dicts) = foldl App (Var fnName) (map dictionaryValueToValue dicts)
        dictionaryValueToValue (SubclassDictionaryValue dict superclassName index) =
          App (Accessor (C.__superclass_ ++ show superclassName ++ "_" ++ show index)
                        (dictionaryValueToValue dict))
              valUndefined

        -- Require all bindings of the same type variable to pairwise unify,
        -- and keep one binding per variable.
        --
        -- 'groupBy' only groups *adjacent* elements, so the bindings are
        -- first sorted by variable name. Without the sort, repeated bindings
        -- separated by other bindings were never compared, and conflicting
        -- bindings were all kept. 'sortBy' is stable, so the binding kept for
        -- each variable is still its first occurrence.
        verifySubstitution :: [(String, Type)] -> Maybe [(String, Type)]
        verifySubstitution subst = do
          let grps = groupBy ((==) `on` fst) (sortBy (compare `on` fst) subst)
          guard (all (pairwise (unifiesWith env) . map snd) grps)
          -- groups produced by groupBy are never empty
          return [ binding | binding : _ <- grps ]

        -- Report overlapping candidates. Otherwise return the first of the
        -- simplest candidates, or fail if there are none.
        checkOverlaps :: [DictionaryValue] -> Check Expr
        checkOverlaps dicts =
          case [ (d1, d2) | d1 <- dicts, d2 <- dicts, d1 `overlapping` d2 ] of
            ds@(_ : _) -> throwError . errorMessage $ OverlappingInstances className tys $ nub (map fst ds)
            _ -> case chooseSimplestDictionaries dicts of
              [] -> throwError . errorMessage $ NoInstanceFound className tys
              d : _ -> return $ dictionaryValueToValue d

        -- Prefer candidates which involve no superclass projection anywhere.
        -- Fall back to all candidates when there are none.
        chooseSimplestDictionaries :: [DictionaryValue] -> [DictionaryValue]
        chooseSimplestDictionaries ds = case filter isSimpleDictionaryValue ds of
          [] -> ds
          simple -> simple

        isSimpleDictionaryValue :: DictionaryValue -> Bool
        isSimpleDictionaryValue SubclassDictionaryValue{} = False
        isSimpleDictionaryValue (DependentDictionaryValue _ ds) = all isSimpleDictionaryValue ds
        isSimpleDictionaryValue _ = True

        -- Decide whether two candidate dictionaries overlap:
        --   * superclass projections never count as overlapping;
        --   * identically named local or global dictionaries do not overlap;
        --   * identically named dependent dictionaries overlap iff some pair
        --     of corresponding arguments does;
        --   * anything else overlaps.
        overlapping :: DictionaryValue -> DictionaryValue -> Bool
        overlapping (LocalDictionaryValue nm1) (LocalDictionaryValue nm2) | nm1 == nm2 = False
        overlapping (GlobalDictionaryValue nm1) (GlobalDictionaryValue nm2) | nm1 == nm2 = False
        overlapping (DependentDictionaryValue nm1 ds1) (DependentDictionaryValue nm2 ds2)
          | nm1 == nm2 = or $ zipWith overlapping ds1 ds2
        overlapping SubclassDictionaryValue{} _ = False
        overlapping _ SubclassDictionaryValue{} = False
        overlapping _ _ = True

    -- Reference to the @undefined@ value in the @Prim@-named module
    -- (@C.prim@), used as the argument to superclass accessors.
    valUndefined :: Expr
    valUndefined = Var (Qualified (Just (ModuleName [ProperName C.prim])) (Ident C.undefined))
-- |
-- Check whether the type heads of two types are equal for the purposes of
-- type class dictionary lookup. On success, return bindings for the type
-- variables occurring in the second type which make the heads agree. The same
-- variable may be bound more than once (e.g. via both sides of a 'TypeApp').
--
-- Type synonyms are only expanded when they appear in the first type; a
-- synonym that fails to expand makes the match fail.
typeHeadsAreEqual :: ModuleName -> Environment -> Type -> Type -> Maybe [(String, Type)]
typeHeadsAreEqual _ _ (Skolem _ s1 _) (Skolem _ s2 _) | s1 == s2 = Just []
typeHeadsAreEqual _ _ t (TypeVar v) = Just [(v, t)]
typeHeadsAreEqual _ _ (TypeConstructor c1) (TypeConstructor c2) | c1 == c2 = Just []
typeHeadsAreEqual m e (TypeApp h1 t1) (TypeApp h2 t2) =
  (++) <$> typeHeadsAreEqual m e h1 h2
       <*> typeHeadsAreEqual m e t1 t2
typeHeadsAreEqual m e (SaturatedTypeSynonym name args) t2 =
  case expandTypeSynonym' e name args of
    Left _ -> Nothing
    Right t1 -> typeHeadsAreEqual m e t1 t2
typeHeadsAreEqual _ _ REmpty REmpty = Just []
typeHeadsAreEqual m e r1@(RCons _ _ _) r2@(RCons _ _ _) =
  let (s1, r1') = rowToList r1
      (s2, r2') = rowToList r2
      -- field types paired up for labels present in both rows
      int = [ (t1, t2) | (name, t1) <- s1, (name', t2) <- s2, name == name' ]
      -- fields whose label appears in only one of the rows
      sd1 = [ (name, t1) | (name, t1) <- s1, name `notElem` map fst s2 ]
      sd2 = [ (name, t2) | (name, t2) <- s2, name `notElem` map fst s1 ]
      -- Every shared field must match. Combining the per-field results with
      -- 'foldMap' would use the 'Monoid' instance of 'Maybe', in which
      -- 'Nothing' is the identity, so a mismatching field would be silently
      -- ignored whenever another shared field matched.
      common = concat <$> mapM (uncurry (typeHeadsAreEqual m e)) int
  in (++) <$> common <*> go sd1 r1' sd2 r2'
  where
    -- Match the leftover fields and the tails of the two rows. A type
    -- variable tail in the second row binds to the first row's leftover
    -- fields and tail. An unknown tail in the first row, with no leftover
    -- fields there, matches anything.
    go :: [(String, Type)] -> Type -> [(String, Type)] -> Type -> Maybe [(String, Type)]
    go [] REmpty [] REmpty = Just []
    go [] (TUnknown _) _ _ = Just []
    go [] (TypeVar v1) [] (TypeVar v2) | v1 == v2 = Just []
    go [] (Skolem _ s1 _) [] (Skolem _ s2 _) | s1 == s2 = Just []
    go sd r [] (TypeVar v) = Just [(v, rowFromList (sd, r))]
    go _ _ _ _ = Nothing
typeHeadsAreEqual _ _ _ _ = Nothing
-- |
-- Check that a relation holds between every pair of list elements, taking
-- each pair in list order (every element is related to each element after
-- it). Empty and singleton lists satisfy any relation trivially.
pairwise :: (a -> a -> Bool) -> [a] -> Bool
pairwise rel xs = and [ rel earlier later | earlier : rest <- tails xs, later <- rest ]