module Infernu.Unify
(unify, unifyAll, unifyl, unifyRowPropertyBiased, unifyPredsL, unifyPending)
where
import Control.Monad (forM, forM_, when, unless)
import Data.List (intercalate)
import Data.Either (rights)
import Data.Map.Lazy (Map)
import qualified Data.Map.Lazy as Map
import Data.Maybe (catMaybes, mapMaybe)
import Data.Set (Set)
import qualified Data.Set as Set
import Prelude hiding (foldl, foldr, mapM, sequence)
import Infernu.Prelude
import Infernu.Builtins.Array (arrayRowType)
import Infernu.Builtins.Regex (regexRowType)
import Infernu.Builtins.String (stringRowType)
import Infernu.Decycle
import Infernu.InferState
import Infernu.Lib (matchZip)
import Infernu.Log
import Infernu.Pretty
import Infernu.Types
-- | Gives a row (object-like) view of the few built-in types that have one:
-- arrays (parameterised by their element type), regexes and strings.
-- Every other type head has no row view and yields 'Nothing'.
tryMakeRow :: FType Type -> Infer (Maybe (TRowList Type))
tryMakeRow t =
  case t of
    TCons TArray [elemT] -> fmap Just (arrayRowType elemT)
    TBody TRegex -> fmap Just regexRowType
    TBody TString -> fmap Just stringRowType
    _ -> return Nothing
-- | Shape of a unification function: given a 'Source' (passed through to
-- 'throwError' when reporting failures) and two types, unify them inside
-- the 'Infer' monad.
type UnifyF = Source -> Type -> Type -> Infer ()
-- | Public entry point for unifying two types; delegates to the
-- cycle-protected 'decycledUnify'.
unify :: UnifyF
unify src lhs rhs = decycledUnify src lhs rhs
-- | 'unify''' wrapped by 'decycle3'. Per the 'Nothing' case of 'unify''',
-- this wrapper is what detects and breaks infinite recursion cycles
-- (the exact detection strategy lives in "Infernu.Decycle").
decycledUnify :: UnifyF
decycledUnify src lhs rhs = decycle3 unify'' src lhs rhs
-- | Runs the given action only when the two values differ; when they are
-- equal it is a no-op.
unlessEq :: (Monad m, Eq a) => a -> a -> m () -> m ()
unlessEq x y action
  | x == y = return ()
  | otherwise = action
-- | Builds the multi-line "Failed unifying / With" error text for two
-- pretty-printable types. When an underlying 'TypeError' is supplied, its
-- message is appended on a "Because:" line.
mkTypeErrorMessage :: Pretty a => a -> a -> Maybe TypeError -> [Char]
mkTypeErrorMessage t1 t2 mte =
  "\n"
    ++ " Failed unifying: " ++ prettyTab 6 t1
    ++ "\n"
    ++ " With: " ++ prettyTab 6 t2
    ++ maybe "" becauseLine mte
  where
    becauseLine te = "\n Because: " ++ prettyTab 2 (message te)
-- | Open-recursive core of unification, meant to be wrapped by 'decycle3'.
--
-- * @Nothing@: the wrapper is breaking an infinite recursion cycle; the pair
--   is only logged and treated as successfully unified.
-- * @Just recurse@: applies the current main substitution to both types,
--   then unifies their heads with 'unify'', using @recurse@ for sub-terms.
--   Any 'TypeError' raised underneath keeps its 'source', but its message is
--   wrapped in a "Failed unifying ... With ... Because" explanation built from
--   the types as they were passed in (before substitution).
unify'' :: Maybe UnifyF -> UnifyF
unify'' Nothing _ t1 t2 = traceLog $ "breaking infinite recursion cycle, when unifying: " ++ pretty t1 ++ " ~ " ++ pretty t2
unify'' (Just recurse) a t1 t2 =
  do traceLog $ "unifying: " ++ pretty t1 ++ " ~ " ++ pretty t2
     s <- getMainSubst
     let t1Substed = applySubst s t1
         t2Substed = applySubst s t2
     -- Log the substituted types, i.e. what is actually unified below.
     -- (Previously this line printed the unsubstituted t1/t2.)
     traceLog $ "unifying (substed): " ++ pretty t1Substed ++ " ~ " ++ pretty t2Substed
     let wrap' te = TypeError { source = source te
                              , message = mkTypeErrorMessage t1 t2 (Just te)
                              }
     mapError wrap' $ unify' recurse a (unFix t1Substed) (unFix t2Substed)
-- | Throws a unification failure at the given source position. Type
-- variable names in both types are minified together ('minifyVars') so the
-- message refers to them consistently.
--
-- The result of 'minifyVars' is matched totally: if it ever returned a list
-- that is not exactly two elements long, the error is still reported, using
-- the original (un-minified) types, instead of crashing on an irrefutable
-- pattern-binding failure.
unificationError :: (VarNames x, Pretty x) => Source -> x -> x -> Infer b
unificationError pos x y =
  case minifyVars [x, y] of
    [a, b] -> throwError pos $ mkTypeErrorMessage a b Nothing
    _ -> throwError pos $ mkTypeErrorMessage x y Nothing
-- | Unwraps a qualified type that is expected to carry no type-class
-- predicates. A non-empty predicate list is reported via 'fail'.
assertNoPred :: QualType -> Infer Type
assertNoPred q
  | null (qualPred q) = return (qualType q)
  | otherwise = fail $ "Assertion failed: pred in " ++ pretty q
-- | Structural unification of two type heads. Both heads have already been
-- unwrapped from 'Fix' and had the main substitution applied by 'unify'''.
-- The first argument is the recursive unifier used for all sub-terms
-- (supplied through 'unify''' by 'decycle3').
--
-- Equations are tried top to bottom, so their order matters. Type variables
-- are handled first. Named types ('TName') are unrolled before the generic
-- 'TCons' and row equations are reached.
unify' :: UnifyF -> Source -> FType (Fix FType) -> FType (Fix FType) -> Infer ()
-- A type variable on either side is bound to the other type.
unify' _ a (TBody (TVar n)) t = varBind a n (Fix t)
unify' _ a t (TBody (TVar n)) = varBind a n (Fix t)
-- Two primitive bodies unify only when they are equal.
unify' _ a (TBody x) (TBody y) = unlessEq x y $ unificationError a x y
-- Two named types. With the same name, type arguments are unified pairwise
-- (an arity mismatch is an error). With different names, both definitions
-- are unrolled (they must carry no predicates) and the results are unified.
unify' recurse a t1@(TCons (TName n1) targs1) t2@(TCons (TName n2) targs2) =
  if n1 == n2
    then case matchZip targs1 targs2 of
           Nothing -> unificationError a t1 t2
           Just targs -> unifyl recurse a targs
    else
      do let unroll' = unrollName a
         t1' <- unroll' n1 targs1
         t2' <- unroll' n2 targs2
         mapM_ assertNoPred [t1', t2']
         recurse a (qualType t1') (qualType t2')
-- A named type against anything else: unroll the named side and retry.
unify' recurse a (TCons (TName n1) targs1) t2 =
  unrollName a n1 targs1
  >>= assertNoPred
  >>= flip (recurse a) (Fix t2)
unify' recurse a t1 (TCons (TName n2) targs2) =
  unrollName a n2 targs2
  >>= assertNoPred
  >>= recurse a (Fix t1)
-- Head mismatches between primitive bodies, constructors and functions.
unify' _ a t1@(TBody _) t2@(TCons _ _) = unificationError a t1 t2
unify' _ a t1@(TCons _ _) t2@(TBody _) = unificationError a t1 t2
unify' _ a t1@(TBody _) t2@(TFunc _ _) = unificationError a t1 t2
unify' _ a t1@(TFunc _ _) t2@(TBody _) = unificationError a t1 t2
unify' _ a t1@(TFunc _ _) t2@(TCons _ _) = unificationError a t1 t2
unify' _ a t1@(TCons _ _) t2@(TFunc _ _) = unificationError a t1 t2
-- Two (non-named, since TName is handled above) constructors: the
-- constructors must be equal and the argument lists must unify pairwise.
unify' recurse a t1@(TCons n1 ts1) t2@(TCons n2 ts2) =
  do when (n1 /= n2) $ unificationError a t1 t2
     case matchZip ts1 ts2 of
       Nothing -> unificationError a t1 t2
       Just ts -> unifyl recurse a ts
-- Two functions: argument lists must have equal length and are unified
-- pairwise. A pair whose left-hand argument is exactly 'TUndefined' is
-- skipped. The results are then unified with the sides swapped
-- (tres2 against tres1).
-- NOTE(review): the swap presumably relates to the directional/biased
-- unification used for row properties; confirm before changing it.
unify' recurse a t1@(TFunc ts1 tres1) t2@(TFunc ts2 tres2) =
  case matchZip ts1 ts2 of
    Nothing -> unificationError a t1 t2
    Just ts -> do loop' ts
                  recurse a tres2 tres1
  where loop' [] = return ()
        loop' ((Fix (TBody TUndefined), _):ts') = loop' ts'
        loop' ((x,y):ts') =
          do recurse a x y
             loop' ts'
-- A row against a non-row. The non-row side is viewed as a row through
-- 'tryMakeRow' (only arrays, regexes and strings have such a view). The Bool
-- records which side the row was on, so argument order is preserved.
unify' r a (TRow tRowList) t2@(TCons _ _) = unifyTryMakeRow r a True tRowList t2
unify' r a t1@(TCons _ _) (TRow tRowList) = unifyTryMakeRow r a False tRowList t1
unify' r a (TRow tRowList) t2@(TBody _) = unifyTryMakeRow r a True tRowList t2
unify' r a t1@(TBody _) (TRow tRowList) = unifyTryMakeRow r a False tRowList t1
unify' r a (TRow tRowList) t2@(TFunc _ _) = unifyTryMakeRow r a True tRowList t2
unify' r a t1@(TFunc _ _) (TRow tRowList) = unifyTryMakeRow r a False tRowList t1
-- Two rows (skipped entirely when structurally equal):
--   1. properties present in both rows have their type schemes unified via
--      unifyRowPropertyBiased' (a unification error on the two schemes is
--      the fallback action);
--   2. a single fresh row variable r is shared as the tail, and each row's
--      extra properties are unified against the other row's tail (see
--      'unifyRows'), once in each direction.
unify' recurse a t1@(TRow row1) t2@(TRow row2) =
  unlessEq t1 t2 $ do
    let (m2, r2) = flattenRow row2
        names2 = Set.fromList $ Map.keys m2
        (m1, r1) = flattenRow row1
        names1 = Set.fromList $ Map.keys m1
        commonNames = Set.toList $ names1 `Set.intersection` names2
        -- commonNames are keys of both maps, so neither lookup drops an
        -- entry and the two lists below stay aligned for 'zip'.
        namesToTypes m = mapMaybe $ flip Map.lookup m
        commonTypes = zip (namesToTypes m1 commonNames) (namesToTypes m2 commonNames)
    forM_ commonTypes $ \(ts1, ts2) -> unifyRowPropertyBiased' recurse a (unificationError a ts1 ts2) (ts1, ts2)
    r <- RowTVar <$> fresh
    unifyRows recurse a r (t1, names1, m1) (t2, names2, r2)
    unifyRows recurse a r (t2, names2, m2) (t1, names1, r1)
-- | Unifies a row type with a non-row type by first asking 'tryMakeRow' for a
-- row view of the non-row type. If there is no such view, a unification
-- error is raised. @leftBiased@ says whether the row was originally the
-- left-hand operand; the original operand order is kept in the recursive call.
unifyTryMakeRow :: UnifyF -> Source -> Bool -> TRowList Type -> FType Type -> Infer ()
unifyTryMakeRow r a leftBiased tRowList t2 =
  tryMakeRow t2 >>= maybe (unificationError a tRow t2) unifyWithRowView
  where
    tRow = TRow tRowList
    unifyWithRowView rowType
      | leftBiased = r a (Fix tRow) (Fix $ TRow rowType)
      | otherwise = r a (Fix $ TRow rowType) (Fix tRow)
-- | 'unifyRowPropertyBiased'' specialised to the top-level 'unify' as the
-- recursive unifier.
unifyRowPropertyBiased :: Source -> Infer () -> (TypeScheme, TypeScheme) -> Infer ()
unifyRowPropertyBiased src onError schemes = unifyRowPropertyBiased' unify src onError schemes
-- | Unifies the type schemes of a row property that appears on both sides.
--
-- Nothing is done when 'areEquivalentNamedTypes' already considers the two
-- schemes equivalent. Otherwise, when the first scheme has no quantified
-- variables or both schemes quantify the same number of variables, both are
-- instantiated, their predicates are resolved together via 'unifyPredsL',
-- and the instantiated types are unified. In any other case @errorAction@
-- runs instead.
unifyRowPropertyBiased' :: UnifyF -> Source -> Infer () -> (TypeScheme, TypeScheme) -> Infer ()
unifyRowPropertyBiased' recurse a errorAction (scheme1s, scheme2s) =
  do traceLog ("Unifying type schemes: " ++ pretty scheme1s ++ " ~ " ++ pretty scheme2s)
     if areEquivalentNamedTypes (placeholder, scheme1s) (placeholder, scheme2s)
       then return ()
       else if firstIsMonomorphic || sameVarCount
              then unifyInstantiated
              else errorAction
  where
    -- Dummy companion type; only the schemes are being compared here.
    placeholder = Fix $ TBody TUndefined
    firstIsMonomorphic =
      case scheme1s of
        TScheme [] _ -> True
        _ -> False
    sameVarCount = length (schemeVars scheme1s) == length (schemeVars scheme2s)
    unifyInstantiated =
      do traceLog ("Unifying schemes: " ++ pretty scheme1s ++ " ~~ " ++ pretty scheme2s)
         qual1 <- instantiate scheme1s
         qual2 <- instantiate scheme2s
         unifyPredsL a $ (qualPred qual1) ++ (qualPred qual2)
         recurse a (qualType qual1) (qualType qual2)
-- | One direction of row-tail unification. It handles the properties that
-- appear in the first row but not in the second.
--
-- Arguments: the recursive unifier, the source position, the fresh row
-- variable @r@ shared by both directions, then for side 1 its displayable
-- type, property names and property map, and for side 2 its displayable
-- type, property names and flattened row end.
--
-- A row made of exactly side 1's extra properties is built. Its tail is
-- @r@ when side 2 ends in a row variable, and side 2's own end otherwise.
-- Then:
--
-- * side 2 closed ('FlatRowEndTVar' 'Nothing'): succeeds only when side 1
--   has no extra properties, in which case @r@ is bound to a closed row end;
--   otherwise it is a unification error;
-- * side 2 open with tail variable @r2'@: the extras row is unified with
--   @r2'@;
-- * side 2 ending in a recursive named type: the extras row is unified with
--   that named type.
unifyRows :: (VarNames x, Pretty x) => UnifyF -> Source -> RowTVar
             -> (x, Set EPropName, Map EPropName TypeScheme)
             -> (x, Set EPropName, FlatRowEnd Type)
             -> Infer ()
unifyRows recurse a r (t1, names1, m1) (t2, names2, r2) =
  do let in1NotIn2 = names1 `Set.difference` names2
         -- An open tail on side 2 is replaced by the shared fresh variable r.
         rowTail = case r2 of
                     FlatRowEndTVar (Just _) -> FlatRowEndTVar $ Just r
                     _ -> r2
         -- Presumably unflattenRow keeps only the keys of m1 accepted by the
         -- predicate (here: the extra properties) -- confirm in Infernu.Types.
         in1NotIn2row = tracePretty "in1NotIn2row" $ Fix . TRow . unflattenRow m1 rowTail $ flip Set.member in1NotIn2
     case r2 of
       FlatRowEndTVar Nothing -> if Set.null in1NotIn2
                                 then varBind a (getRowTVar r) (Fix $ TRow $ TRowEnd Nothing)
                                 else unificationError a t1 t2
       FlatRowEndTVar (Just r2') -> recurse a in1NotIn2row (Fix . TBody . TVar $ getRowTVar r2')
       FlatRowEndRec tid ts -> recurse a in1NotIn2row (Fix $ TCons (TName tid) ts)
-- | Unifies each pair in the list, left to right, with the given unifier.
-- Processing stops at the first failure.
unifyl :: UnifyF -> Source -> [(Type, Type)] -> Infer ()
unifyl r a pairs = forM_ pairs $ \(lhs, rhs) -> r a lhs rhs
-- | Collects the outermost row types inside a type whose free type variables
-- include @n@.
--
-- A row node is returned whole when @n@ is free in it and is never searched
-- further, whether or not it matches. Every other node is searched through
-- all of its children, and the results are unioned.
isInsideRowType :: TVarName -> Type -> Set Type
isInsideRowType n (Fix t) =
  case t of
    TRow rowList
      | n `Set.member` freeTypeVars rowList -> Set.singleton (Fix t)
      | otherwise -> Set.empty
    _ -> foldr (Set.union . isInsideRowType n) Set.empty t
-- | The sole element of a set, or 'Nothing' when the set is empty or has
-- more than one element.
getSingleton :: Set a -> Maybe a
getSingleton s =
  case Set.toList s of
    [only] -> Just only
    _ -> Nothing
-- | Binds type variable @n@ to type @t@: computes the substitution with
-- 'varBind'' and applies it to the inference state.
varBind :: Source -> TVarName -> Type -> Infer ()
varBind a n t = varBind' a n t >>= applySubstInfer
-- | Computes the substitution that binds type variable @n@ to type @t@.
--
-- * Binding a variable to itself yields 'nullSubst'.
-- * If @n@ occurs free in exactly one row type inside @t@ (as found by
--   'isInsideRowType'), a recursive ("mu") type is built instead of failing:
--   a fresh variable @recVar@ takes the place of that row in @t@, the row
--   (with @n@ replaced by the resulting type) is passed to 'getNamedType',
--   and the result maps @recVar@ to that named type and @n@ to the rewritten
--   @t@.
-- * Otherwise, if @n@ is free in @t@ (zero or several enclosing rows), the
--   occurs check fails. The error message uses minified variable names.
-- * Otherwise @n@ is mapped to @t@.
varBind' :: Source -> TVarName -> Type -> Infer TSubst
varBind' a n t | t == Fix (TBody (TVar n)) = return nullSubst
               | Just rowT <- getSingleton $ isInsideRowType n t =
                   -- NOTE(review): this trace concatenates ": " and " = ", producing a
                   -- doubled separator in the log line.
                   do traceLog ("===> Generalizing mu-type: " ++ pretty n ++ " recursive in: " ++ pretty t ++ ", found enclosing row type: " ++ " = " ++ pretty rowT)
                      recVar <- fresh
                      -- Assumes replaceFix old new x replaces occurrences of old with
                      -- new inside x -- confirm against its definition.
                      let withRecVar = replaceFix (unFix rowT) (TBody (TVar recVar)) t
                          recT = replaceFix (TBody (TVar n)) (unFix withRecVar) rowT
                      namedType <- getNamedType recVar recT
                      traceLog $ "===> Resulting mu type: " ++ pretty n ++ " = " ++ pretty withRecVar
                      return $ singletonSubst recVar namedType `composeSubst` singletonSubst n withRecVar
               | n `Set.member` freeTypeVars t = let f = minifyVarsFunc t
                                                 in throwError a $ "Occurs check failed: " ++ pretty (f n) ++ " in " ++ pretty (mapVarNames f t)
               | otherwise = return $ singletonSubst n t
-- | Unifies every type in the list with its successor, so all of them end up
-- unified with each other. Empty and single-element lists are no-ops.
unifyAll :: Source -> [Type] -> Infer ()
unifyAll a ts = unifyl decycledUnify a neighbours
  where
    neighbours = zip ts (drop 1 ts)
-- | Resolves a list of type-class predicates against the known instances.
--
-- For each predicate, the class's instances are looked up; an unknown class
-- is an error. 'unifyAmbiguousEntry' then narrows the instance set. A
-- predicate that is fully resolved is dropped. One that is still ambiguous
-- is registered as a pending unification and kept in the returned list.
unifyPredsL :: Source -> [TPred Type] -> Infer [TPred Type]
unifyPredsL a ps = catMaybes <$> mapM resolvePred ps
  where
    resolvePred p@(TPredIsIn className t) =
      do cls <- lookupClass className
                  `failWithM` throwError a ("Unknown class: " ++ pretty className ++ " in pred list: " ++ pretty ps)
         let entry = (a, t, (className, Set.fromList (classInstances cls)))
         remaining <- unifyAmbiguousEntry entry
         case remaining of
           Nothing -> return Nothing
           Just ambig -> addPendingUnification ambig >> return (Just p)
-- | 'True' exactly for 'Right' values.
isRight :: Either a b -> Bool
isRight = either (const False) (const True)
-- | The payloads of all 'Left' values, in their original order.
catLefts :: [Either a b] -> [a]
catLefts xs = [l | Left l <- xs]
-- | Tries to resolve which instance of a type class a type belongs to.
--
-- Each candidate instance scheme is instantiated (it must carry no
-- predicates) and unified with the type inside 'runSubInfer', which yields
-- either an error or the resulting inference state. Then:
--
-- * no candidate unifies: an error is thrown that lists every candidate's
--   failure message, the class name and the (substituted) type;
-- * exactly one candidate unifies: its resulting state is adopted via
--   'setState', and 'Nothing' is returned (resolved);
-- * several candidates unify: the state is left unchanged and the entry is
--   returned again, narrowed to the surviving instance schemes.
unifyAmbiguousEntry :: (Source, Type, (ClassName, Set TypeScheme)) -> Infer (Maybe (Source, Type, (ClassName, Set TypeScheme)))
unifyAmbiguousEntry (a, t, (ClassName className, tss)) =
  do let unifAction ts =
           do inst <- instantiateScheme False ts >>= assertNoPred
              unify a inst t
     -- Each attempt runs in a sub-inference; only a unique success is
     -- committed below via setState.
     unifyResults <- forM (Set.toList tss) $ \instScheme -> (instScheme, ) <$> runSubInfer (unifAction instScheme >> getState)
     let survivors = filter (isRight . snd) unifyResults
     case rights $ map snd survivors of
       [] -> do t' <- applyMainSubst t
                throwError a $ concat [ intercalate "\n\n" $ "" : (map (prettyTab 2 . message) . catLefts $ map snd $ unifyResults)
                                      , "\n\n"
                                      , "While trying to find matching instance of typeclass "
                                      , "\n "
                                      , prettyTab 1 className
                                      , "\nfor type:\n "
                                      , prettyTab 1 t'
                                      ]
       [newState] -> setState newState >> return Nothing
       _ -> return . Just . (\x -> (a, t, (ClassName className, x))) . Set.fromList . map fst $ survivors
-- | Repeatedly re-attempts every pending (ambiguous) type-class resolution
-- until the pending set stops changing.
--
-- Each pass runs 'unifyAmbiguousEntry' on every pending entry. The entries
-- that are still ambiguous, possibly with a narrowed instance set, become
-- the new pending set. Entries that resolve uniquely drop out. An entry with
-- no matching instance raises an error from 'unifyAmbiguousEntry'.
unifyPending :: Infer ()
unifyPending = getPendingUnifications >>= go
  where
    go pending =
      do stillAmbiguous <- mapM unifyAmbiguousEntry (Set.toList pending)
         let pending' = Set.fromList (catMaybes stillAmbiguous)
         -- NOTE(review): this replaces the whole pending set. unify can reach
         -- unifyPredsL -> addPendingUnification (via the row-property path),
         -- so entries registered during this pass may be overwritten here --
         -- confirm whether that is intended.
         setPendingUnifications pending'
         unless (pending' == pending) $ go pending'