module Type.Unify (unify) where import Type.Type import qualified Data.UnionFind.IO as UF import qualified Data.Map as Map import qualified Type.State as TS import Control.Arrow (first,second) import Control.Monad.State import SourceSyntax.Location import Type.PrettyPrint import Text.PrettyPrint (render) unify :: SrcSpan -> Variable -> Variable -> StateT TS.SolverState IO () unify span variable1 variable2 = do equivalent <- liftIO $ UF.equivalent variable1 variable2 if equivalent then return () else actuallyUnify span variable1 variable2 actuallyUnify :: SrcSpan -> Variable -> Variable -> StateT TS.SolverState IO () actuallyUnify span variable1 variable2 = do desc1 <- liftIO $ UF.descriptor variable1 desc2 <- liftIO $ UF.descriptor variable2 let unify' = unify span name' :: Maybe String name' = case (name desc1, name desc2) of (Just name1, Just name2) -> case (flex desc1, flex desc2) of (_, Flexible) -> Just name1 (Flexible, _) -> Just name2 (Is Number, Is _) -> Just name1 (Is _, Is Number) -> Just name2 (Is _, Is _) -> Just name1 (_, _) -> Nothing (Just name1, _) -> Just name1 (_, Just name2) -> Just name2 _ -> Nothing flex' :: Flex flex' = case (flex desc1, flex desc2) of (f, Flexible) -> f (Flexible, f) -> f (Is Number, Is _) -> Is Number (Is _, Is Number) -> Is Number (Is super, Is _) -> Is super (_, _) -> Flexible rank' :: Int rank' = min (rank desc1) (rank desc2) merge1 :: StateT TS.SolverState IO () merge1 = liftIO $ do if rank desc1 < rank desc2 then UF.union variable2 variable1 else UF.union variable1 variable2 UF.modifyDescriptor variable1 $ \desc -> desc { structure = structure desc1, flex = flex', name = name' } merge2 :: StateT TS.SolverState IO () merge2 = liftIO $ do if rank desc1 < rank desc2 then UF.union variable2 variable1 else UF.union variable1 variable2 UF.modifyDescriptor variable2 $ \desc -> desc { structure = structure desc2, flex = flex', name = name' } merge = if rank desc1 < rank desc2 then merge1 else merge2 fresh :: Maybe (Term1 Variable) -> StateT TS.SolverState IO Variable fresh structure = do var <- liftIO . UF.fresh $ Descriptor { structure = structure, rank = rank', flex = flex', name = name', copy = Nothing, mark = noMark } TS.register var flexAndUnify var = do liftIO $ UF.modifyDescriptor var $ \desc -> desc { flex = Flexible } unify' variable1 variable2 unifyNumber svar name | name `elem` ["Int","Float","number"] = flexAndUnify svar | otherwise = TS.addError span "Expecting a number (Int or Float)" variable1 variable2 comparableError str = TS.addError span (str ++ msg) variable1 variable2 where msg = "Expecting something comparable such as an\n" ++ "Int, Float, Char, String, or a list or tuple of comparables." unifyComparable var name | name `elem` ["Int","Float","Char","String","comparable"] = flexAndUnify var | otherwise = comparableError "" unifyComparableStructure varSuper varFlex = do struct <- liftIO $ collectApps varFlex case struct of Other -> comparableError "" List v -> do flexAndUnify varSuper unify' v =<< liftIO (var $ Is Comparable) Tuple vs | length vs > 6 -> comparableError "Cannot compare a tuple with more than 6 elements.\n" | otherwise -> do flexAndUnify varSuper cmpVars <- liftIO $ forM [1..length vs] $ \_ -> var (Is Comparable) zipWithM_ unify' vs cmpVars unifyAppendable varSuper varFlex = do struct <- liftIO $ collectApps varFlex case struct of List _ -> flexAndUnify varSuper _ -> comparableError "" rigidError variable = TS.addError span msg variable1 variable2 where msg = concat [ "Cannot unify rigid type variable '", render (pretty Never variable), "'.\n" , "It is likely that a type annotation is not general enough." ] superUnify = case (flex desc1, flex desc2, name desc1, name desc2) of (Is super1, Is super2, _, _) | super1 == super2 -> merge (Is Number, Is Comparable, _, _) -> merge1 (Is Comparable, Is Number, _, _) -> merge2 (Is Number, _, _, Just name) -> unifyNumber variable1 name (_, Is Number, Just name, _) -> unifyNumber variable2 name (Is Comparable, _, _, Just name) -> unifyComparable variable1 name (_, Is Comparable, Just name, _) -> unifyComparable variable2 name (Is Comparable, _, _, _) -> unifyComparableStructure variable1 variable2 (_, Is Comparable, _, _) -> unifyComparableStructure variable2 variable1 (Is Appendable, _, _, Just ctor) | ctor `elem` ["Text.Text","String"] -> flexAndUnify variable1 (_, Is Appendable, Just ctor, _) | ctor `elem` ["Text.Text","String"] -> flexAndUnify variable2 (Is Appendable, _, _, _) -> unifyAppendable variable1 variable2 (_, Is Appendable, _, _) -> unifyAppendable variable2 variable1 (Rigid, _, _, _) -> rigidError variable1 (_, Rigid, _, _) -> rigidError variable2 _ -> TS.addError span "" variable1 variable2 case (structure desc1, structure desc2) of (Nothing, Nothing) | flex desc1 == Flexible && flex desc1 == Flexible -> merge (Nothing, _) | flex desc1 == Flexible -> merge2 (_, Nothing) | flex desc2 == Flexible -> merge1 (Just (Var1 v), _) -> unify' v variable2 (_, Just (Var1 v)) -> unify' v variable1 (Nothing, _) -> superUnify (_, Nothing) -> superUnify (Just type1, Just type2) -> case (type1,type2) of (App1 term1 term2, App1 term1' term2') -> do merge unify' term1 term1' unify' term2 term2' (Fun1 term1 term2, Fun1 term1' term2') -> do merge unify' term1 term1' unify' term2 term2' (EmptyRecord1, EmptyRecord1) -> return () (Record1 fields ext, EmptyRecord1) | Map.null fields -> unify' ext variable2 (EmptyRecord1, Record1 fields ext) | Map.null fields -> unify' ext variable1 (Record1 fields1 ext1, Record1 fields2 ext2) -> do sequence . concat . Map.elems $ Map.intersectionWith (zipWith unify') fields1 fields2 let mkRecord fs ext = fresh . Just $ Record1 fs ext case (Map.null fields1', Map.null fields2') of (True , True ) -> unify' ext1 ext2 (True , False) -> do record2' <- mkRecord fields2' ext2 unify' ext1 record2' (False, True ) -> do record1' <- mkRecord fields1' ext1 unify' record1' ext2 (False, False) -> do record1' <- mkRecord fields1' =<< fresh Nothing record2' <- mkRecord fields2' =<< fresh Nothing unify' record1' ext2 unify' ext1 record2' where fields1' = unmerged fields1 fields2 fields2' = unmerged fields2 fields1 unmerged a b = Map.filter (not . null) $ Map.union (Map.intersectionWith eat a b) a eat (x:xs) (y:ys) = eat xs ys eat xs ys = xs _ -> TS.addError span "" variable1 variable2