module Type.Unify (unify) where
import Control.Applicative ((<|>))
import Control.Monad.State
import qualified Data.List as List
import qualified Data.Map as Map
import qualified Data.Maybe as Maybe
import qualified Data.UnionFind.IO as UF
import qualified AST.Annotation as A
import qualified AST.Variable as Var
import qualified Type.State as TS
import Type.Type
import Type.PrettyPrint
import Text.PrettyPrint (render)
import Elm.Utils ((|>))
-- | Unify two type variables, reporting any mismatch as a solver error
-- attributed to @region@. When the two variables are already in the same
-- union-find equivalence class there is nothing to do.
unify :: A.Region -> Variable -> Variable -> StateT TS.SolverState IO ()
unify region variable1 variable2 =
    do  sameClass <- liftIO (UF.equivalent variable1 variable2)
        case sameClass of
          True ->
              return ()
          False ->
              actuallyUnify region variable1 variable2
-- | Unify two variables that are not yet in the same equivalence class.
--
-- Both descriptors are read up front. Then one of three things happens:
--
-- * If one side has no structure and is 'Flexible', the two classes are
--   simply merged.
-- * If a side has no structure but carries a constraint (a super type such
--   as @number@/@comparable@/@appendable@, or 'Rigid'), @superUnify@ checks
--   that constraint against the other side.
-- * If both sides have structure, the structures are unified component-wise.
--
-- Mismatches are recorded with 'TS.addError'; nothing is thrown.
actuallyUnify :: A.Region -> Variable -> Variable -> StateT TS.SolverState IO ()
actuallyUnify region variable1 variable2 = do
    desc1 <- liftIO $ UF.descriptor variable1
    desc2 <- liftIO $ UF.descriptor variable2
    let unify' = unify region

        (name', flex', rank', alias') = combinedDescriptors desc1 desc2

        -- Union the two classes and overwrite the merged descriptor with the
        -- structure taken from @structureSource@ plus the combined
        -- flex/name/alias. The union direction depends only on the ranks:
        -- per Data.UnionFind.IO, the class keeps the descriptor of union's
        -- second argument, so the lower-ranked descriptor survives.
        mergeUsing :: Variable -> Descriptor -> StateT TS.SolverState IO ()
        mergeUsing var structureSource = liftIO $ do
            if rank desc1 < rank desc2
                then UF.union variable2 variable1
                else UF.union variable1 variable2
            UF.modifyDescriptor var $ \desc ->
                desc { structure = structure structureSource
                     , flex = flex'
                     , name = name'
                     , alias = alias'
                     }

        -- Merge, keeping the structure of the first variable.
        merge1 :: StateT TS.SolverState IO ()
        merge1 = mergeUsing variable1 desc1

        -- Merge, keeping the structure of the second variable.
        merge2 :: StateT TS.SolverState IO ()
        merge2 = mergeUsing variable2 desc2

        -- Merge, keeping the structure of whichever variable has lower rank.
        merge = if rank desc1 < rank desc2 then merge1 else merge2

        -- Allocate and register a new variable that shares the combined
        -- rank/flex/name/alias of the pair being unified.
        fresh :: Maybe (Term1 Variable) -> StateT TS.SolverState IO Variable
        fresh maybeStructure = do
            v <- liftIO . UF.fresh $ Descriptor
                { structure = maybeStructure
                , rank = rank'
                , flex = flex'
                , name = name'
                , copy = Nothing
                , mark = noMark
                , alias = alias'
                }
            TS.register v

        -- Drop the constraint on @v@ (make it Flexible) and retry unifying
        -- the original pair; the retry then takes the plain merge path.
        flexAndUnify v = do
            liftIO $ UF.modifyDescriptor v $ \desc -> desc { flex = Flexible }
            unify' variable1 variable2

        unifyNumber svar (Var.Canonical home ctorName) =
            case home of
              Var.BuiltIn | ctorName `elem` ["Int","Float"] -> flexAndUnify svar
              Var.Local | List.isPrefixOf "number" ctorName -> flexAndUnify svar
              _ ->
                  let hint = "Looks like something besides an Int or Float is being used as a number."
                  in  TS.addError region (Just hint) variable1 variable2

        comparableHint =
            "Looks like you want something comparable, but the only valid comparable\n\
            \types are Int, Float, Char, String, lists, or tuples."

        comparableError customHint =
            TS.addError region (Just $ Maybe.fromMaybe comparableHint customHint) variable1 variable2

        appendableHint =
            "Looks like you want something appendable, but only Strings, Lists,\n\
            \and Text can be appended with the (++) operator."

        appendableError customHint =
            TS.addError region (Just $ Maybe.fromMaybe appendableHint customHint) variable1 variable2

        unifyComparable v (Var.Canonical home ctorName) =
            case home of
              Var.BuiltIn | ctorName `elem` ["Int","Float","Char","String"] -> flexAndUnify v
              Var.Local | List.isPrefixOf "comparable" ctorName -> flexAndUnify v
              _ -> comparableError Nothing

        -- A comparable may also be a list or tuple whose elements are all
        -- comparable; constrain each element with a fresh comparable var.
        unifyComparableStructure varSuper varFlex =
            do  struct <- liftIO $ collectApps varFlex
                case struct of
                  Other ->
                      comparableError Nothing
                  List v ->
                      do  flexAndUnify varSuper
                          unify' v =<< liftIO (variable $ Is Comparable)
                  Tuple vs
                      | length vs > 6 ->
                          comparableError $ Just "Cannot compare a tuple with more than 6 elements."
                      | otherwise ->
                          do  flexAndUnify varSuper
                              cmpVars <- liftIO $ replicateM (length vs) (variable (Is Comparable))
                              zipWithM_ unify' vs cmpVars

        -- Strings and Text are handled by name in superUnify; here only
        -- list applications are accepted as appendable.
        unifyAppendable varSuper varFlex =
            do  struct <- liftIO $ collectApps varFlex
                case struct of
                  List _ -> flexAndUnify varSuper
                  _ -> appendableError Nothing

        rigidError var =
            let hint =
                    "Could not unify rigid type variable '" ++ render (pretty Never var) ++ "'.\n" ++
                    "The problem probably relates to the type variable being shared between a\n\
                    \top-level type annotation and a related let-bound type annotation."
            in  TS.addError region (Just hint) variable1 variable2

        -- At least one side has no structure but is constrained; check the
        -- constraint against the other side. Guarded alternatives fall
        -- through to later ones when their guard fails.
        superUnify =
            case (flex desc1, flex desc2, name desc1, name desc2) of
              (Is super1, Is super2, _, _)
                  | super1 == super2 -> merge
              (Is Number, Is Comparable, _, _) -> merge1
              (Is Comparable, Is Number, _, _) -> merge2

              (Is Number, _, _, Just ctor) -> unifyNumber variable1 ctor
              (_, Is Number, Just ctor, _) -> unifyNumber variable2 ctor

              (Is Comparable, _, _, Just ctor) -> unifyComparable variable1 ctor
              (_, Is Comparable, Just ctor, _) -> unifyComparable variable2 ctor
              (Is Comparable, _, _, _) -> unifyComparableStructure variable1 variable2
              (_, Is Comparable, _, _) -> unifyComparableStructure variable2 variable1

              (Is Appendable, _, _, Just ctor)
                  | Var.isText ctor || Var.isPrim "String" ctor -> flexAndUnify variable1
              (_, Is Appendable, Just ctor, _)
                  | Var.isText ctor || Var.isPrim "String" ctor -> flexAndUnify variable2
              (Is Appendable, _, _, _) -> unifyAppendable variable1 variable2
              (_, Is Appendable, _, _) -> unifyAppendable variable2 variable1

              (Rigid, _, _, _) -> rigidError variable1
              (_, Rigid, _, _) -> rigidError variable2
              _ -> TS.addError region Nothing variable1 variable2

    case (structure desc1, structure desc2) of
      (Nothing, Nothing) | flex desc1 == Flexible && flex desc2 == Flexible -> merge
      (Nothing, _) | flex desc1 == Flexible -> merge2
      (_, Nothing) | flex desc2 == Flexible -> merge1

      -- A Var1 structure just forwards to another variable; follow it.
      (Just (Var1 v), _) -> unify' v variable2
      (_, Just (Var1 v)) -> unify' v variable1

      (Nothing, _) -> superUnify
      (_, Nothing) -> superUnify

      (Just type1, Just type2) ->
          case (type1, type2) of
            -- Merge before recursing; presumably so that recursive calls
            -- revisiting this pair see it as already equivalent.
            (App1 term1 term2, App1 term1' term2') ->
                do  merge
                    unify' term1 term1'
                    unify' term2 term2'

            (Fun1 term1 term2, Fun1 term1' term2') ->
                do  merge
                    unify' term1 term1'
                    unify' term2 term2'

            (EmptyRecord1, EmptyRecord1) ->
                return ()

            (Record1 fields ext, EmptyRecord1) | Map.null fields -> unify' ext variable2
            (EmptyRecord1, Record1 fields ext) | Map.null fields -> unify' ext variable1

            (Record1 _ _, Record1 _ _) ->
                recordUnify region fresh variable1 variable2

            _ -> TS.addError region Nothing variable1 variable2
-- | Unify two record types.
--
-- Both records are first flattened with 'gatherFields'. Labels present on
-- both sides are unified pairwise. The labels unique to each side must then
-- be absorbed by the other side's extension:
--
-- * closed / closed: there must be no unique labels at all;
-- * closed / open: the open side may not have labels the closed side lacks,
--   and its extension is bound to a record of the closed side's extra
--   labels, terminated by the closed record's empty-record variable;
-- * open / open: each extension is bound to a record holding the other
--   side's extra labels. When both sides have extras, the two new records
--   share a SINGLE fresh tail, so both records end in the same row.
--
-- @fresh@ allocates and registers a new variable with the given structure.
recordUnify
    :: A.Region
    -> (Maybe (Term1 Variable) -> StateT TS.SolverState IO Variable)
    -> Variable
    -> Variable
    -> StateT TS.SolverState IO ()
recordUnify region fresh variable1 variable2 =
  do  (ExpandedRecord fields1 ext1) <- liftIO (gatherFields variable1)
      (ExpandedRecord fields2 ext2) <- liftIO (gatherFields variable2)

      unifyOverlappingFields region fields1 fields2

      let freshRecord fields ext =
            fresh (Just (Record1 fields ext))

      let uniqueFields1 = diffFields fields1 fields2
      let uniqueFields2 = diffFields fields2 fields1

      let addFieldMismatchError missingFields =
            TS.addError region (Just (fieldMismatchError missingFields)) variable1 variable2

      case (ext1, ext2) of
        (Empty _, Empty _) ->
            case Map.null uniqueFields1 && Map.null uniqueFields2 of
              True -> return ()
              False -> TS.addError region Nothing variable1 variable2

        (Empty var1, Extension var2) ->
            case (Map.null uniqueFields1, Map.null uniqueFields2) of
              (_, False) -> addFieldMismatchError uniqueFields2
              (True, True) -> unify region var1 var2
              (False, True) ->
                  do  subRecord <- freshRecord uniqueFields1 var1
                      unify region subRecord var2

        (Extension var1, Empty var2) ->
            case (Map.null uniqueFields1, Map.null uniqueFields2) of
              (False, _) -> addFieldMismatchError uniqueFields1
              (True, True) -> unify region var1 var2
              (True, False) ->
                  do  subRecord <- freshRecord uniqueFields2 var2
                      unify region var1 subRecord

        (Extension var1, Extension var2) ->
            case (Map.null uniqueFields1, Map.null uniqueFields2) of
              (True, True) ->
                  unify region var1 var2

              (True, False) ->
                  do  subRecord <- freshRecord uniqueFields2 var2
                      unify region var1 subRecord

              (False, True) ->
                  do  subRecord <- freshRecord uniqueFields1 var1
                      unify region subRecord var2

              (False, False) ->
                  -- {fields1 | var1} ~ {fields2 | var2}:
                  --   var2 := {uniqueFields1 | rest}, var1 := {uniqueFields2 | rest}
                  -- with one shared 'rest'. Two independent fresh tails
                  -- would leave the remaining rows of the records unrelated.
                  do  sharedRest <- fresh Nothing
                      record1' <- freshRecord uniqueFields1 sharedRest
                      record2' <- freshRecord uniqueFields2 sharedRest
                      unify region record1' var2
                      unify region var1 record2'
-- | Unify the types of every label present in both field maps.
--
-- Labels are visited in key order. Within a label, entries are paired
-- positionally, and entries beyond the shorter list are ignored.
unifyOverlappingFields
    :: A.Region
    -> Map.Map String [Variable]
    -> Map.Map String [Variable]
    -> StateT TS.SolverState IO ()
unifyOverlappingFields region fields1 fields2 =
    sequence_
      [ unify region left right
      | (lefts, rights) <- Map.elems (Map.intersectionWith (,) fields1 fields2)
      , (left, right) <- zip lefts rights
      ]
-- | The entries of @fields1@ not matched by entries of @fields2@.
--
-- For a label in both maps, drop from the front of the first list as many
-- entries as the second list has. Labels only in @fields1@ are kept as-is.
-- Labels whose remaining list is empty are removed from the result.
diffFields :: Map.Map String [a] -> Map.Map String [a] -> Map.Map String [a]
diffFields fields1 fields2 =
    Map.filter (not . null) (Map.differenceWith dropMatched fields1 fields2)
  where
    dropMatched mine theirs = Just (drop (length theirs) mine)
-- | A record type flattened by 'gatherFields': every field found along the
-- chain of nested 'Record1' extensions, plus the variable that ended the
-- chain. Each label maps to a list because the same label can occur at more
-- than one level of the chain. 'gatherFields' joins those lists with '++',
-- with outer levels first.
data ExpandedRecord = ExpandedRecord
    { _fields :: Map.Map String [Variable]
    , _extension :: Extension
    }

-- | How the flattened record chain ends.
data Extension
    = Empty Variable      -- ^ the chain ends in an 'EmptyRecord1' (closed record); holds that variable
    | Extension Variable  -- ^ the chain ends in any other variable (e.g. an unsolved tail); holds that variable
-- | Flatten a record type: follow the chain of 'Record1' extensions starting
-- at @var@ and collect all of their fields. The chain ends at an
-- 'EmptyRecord1', reported as 'Empty', or at any other variable, reported as
-- 'Extension'.
gatherFields :: Variable -> IO ExpandedRecord
gatherFields var =
    UF.descriptor var >>= expand . structure
  where
    expand (Just (Record1 fields ext)) =
        fmap (prependFields fields) (gatherFields ext)
    expand (Just EmptyRecord1) =
        return (ExpandedRecord Map.empty (Empty var))
    expand _ =
        return (ExpandedRecord Map.empty (Extension var))

    -- Outer fields go in front of the deeper ones for repeated labels.
    prependFields outer (ExpandedRecord inner root) =
        ExpandedRecord (Map.unionWith (++) outer inner) root
-- | Build the hint shown when one record lacks fields the other requires.
-- Only the keys of @missingFields@ are used; each is quoted. An empty map
-- yields the empty string. Two keys are joined with "and". Three or more
-- keys are joined with commas and a final ", and".
fieldMismatchError :: Map.Map String a -> String
fieldMismatchError missingFields =
    case map quote (Map.keys missingFields) of
      [] ->
          ""
      [key] ->
          "Looks like a record is missing the field " ++ key ++ ".\n" ++
          "Maybe there is a misspelling in a record access or record update?"
      keys ->
          "Looks like one record is missing fields " ++ englishList keys ++ "."
  where
    quote key = "'" ++ key ++ "'"

    -- Total replacement for the old init/last split.
    englishList [first, second] = first ++ " and " ++ second
    englishList keys =
        case reverse keys of
          lastKey : restReversed ->
              List.intercalate ", " (reverse restReversed) ++ ", and " ++ lastKey
          [] ->
              ""
-- | Combine the metadata of two descriptors being merged, returning
-- @(name, flex, rank, alias)@:
--
-- * rank is the smaller of the two;
-- * alias is the first one available, checking @desc1@ first;
-- * flex: a 'Flexible' side defers to the other side; @Is Number@ wins over
--   any other super type; otherwise the first super type is kept; any other
--   pairing collapses to 'Flexible';
-- * name: if only one side has a name, that name is used. If both do, the
--   choice follows the same flex preferences, and combinations that are
--   neither Flexible nor @Is@ give 'Nothing'.
combinedDescriptors
    :: Descriptor
    -> Descriptor
    -> (Maybe Var.Canonical, Flex, Int, Maybe Var.Canonical)
combinedDescriptors desc1 desc2 =
    ( mergedName
    , mergeFlex (flex desc1) (flex desc2)
    , min (rank desc1) (rank desc2)
    , alias desc1 <|> alias desc2
    )
  where
    mergedName :: Maybe Var.Canonical
    mergedName =
        case (name desc1, name desc2) of
          (Just name1, Just name2) -> pickName (flex desc1) (flex desc2) name1 name2
          (maybeName1, maybeName2) -> maybeName1 <|> maybeName2

    pickName :: Flex -> Flex -> Var.Canonical -> Var.Canonical -> Maybe Var.Canonical
    pickName _ Flexible first _ = Just first
    pickName Flexible _ _ second = Just second
    pickName (Is Number) (Is _) first _ = Just first
    pickName (Is _) (Is Number) _ second = Just second
    pickName (Is _) (Is _) first _ = Just first
    pickName _ _ _ _ = Nothing

    mergeFlex :: Flex -> Flex -> Flex
    mergeFlex f Flexible = f
    mergeFlex Flexible f = f
    mergeFlex (Is Number) (Is _) = Is Number
    mergeFlex (Is _) (Is Number) = Is Number
    mergeFlex (Is super) (Is _) = Is super
    mergeFlex _ _ = Flexible