{-# OPTIONS_GHC -W #-}
module Type.Unify (unify) where

import Control.Monad.State
import qualified Data.Map as Map
import qualified Data.Maybe as Maybe
import qualified Data.UnionFind.IO as UF
import qualified SourceSyntax.Annotation as A
import qualified Type.State as TS
import Type.Type
import Type.PrettyPrint
import Text.PrettyPrint (render)

-- | Unify two type variables, recording any mismatch against @region@.
--
-- Variables that the union-find structure already reports as equivalent
-- are left untouched; otherwise the work is handed to 'actuallyUnify'.
unify :: A.Region -> Variable -> Variable -> StateT TS.SolverState IO ()
unify region variable1 variable2 = do
  alreadyJoined <- liftIO (UF.equivalent variable1 variable2)
  unless alreadyJoined $
    actuallyUnify region variable1 variable2

-- | Unify two variables by inspecting their descriptors.
--
-- Called by 'unify' once the variables are known not to be equivalent.
-- Both descriptors are read up front; the final @case@ then dispatches on
-- the pair of 'structure' fields:
--
--   * a structureless 'Flexible' variable is simply merged into the other;
--
--   * a 'Var1' indirection is followed and unification retried;
--
--   * any other structureless variable goes through @superUnify@;
--
--   * two structured types are unified component-wise, with dedicated
--     handling for records.
--
-- Mismatches are reported through 'TS.addError' at @region@ rather than
-- by throwing.
actuallyUnify :: A.Region -> Variable -> Variable -> StateT TS.SolverState IO ()
actuallyUnify region variable1 variable2 = do
  desc1 <- liftIO $ UF.descriptor variable1
  desc2 <- liftIO $ UF.descriptor variable2
  let unify' = unify region

      -- Name carried by the merged descriptor. When both sides are named,
      -- the choice follows the same priorities as flex' below.
      name' :: Maybe String
      name' = case (name desc1, name desc2) of
                (Just name1, Just name2) ->
                    case (flex desc1, flex desc2) of
                      (_, Flexible) -> Just name1
                      (Flexible, _) -> Just name2
                      (Is Number, Is _) -> Just name1
                      (Is _, Is Number) -> Just name2
                      (Is _, Is _) -> Just name1
                      (_, _) -> Nothing
                (Just name1, _) -> Just name1
                (_, Just name2) -> Just name2
                _ -> Nothing

      -- Flex carried by the merged descriptor: a Flexible side yields to
      -- the other, and Number wins over any other 'Is' constraint.
      flex' :: Flex
      flex' = case (flex desc1, flex desc2) of
                (f, Flexible) -> f
                (Flexible, f) -> f
                (Is Number, Is _) -> Is Number
                (Is _, Is Number) -> Is Number
                (Is super, Is _) -> Is super
                (_, _) -> Flexible

      -- Rank given to variables created by 'fresh'.
      rank' :: Int
      rank' = min (rank desc1) (rank desc2)

      -- Join the two equivalence classes, then overwrite the structure,
      -- flex and name of the class that @variable@ now belongs to.
      -- Per the Data.UnionFind.IO documentation, 'UF.union' keeps the
      -- descriptor of its second argument, so the lower-ranked variable's
      -- descriptor survives (variable2's when the ranks are equal).
      mergeWith :: Variable -> Maybe (Term1 Variable) -> StateT TS.SolverState IO ()
      mergeWith variable struct = liftIO $ do
        if rank desc1 < rank desc2 then UF.union variable2 variable1
                                   else UF.union variable1 variable2
        UF.modifyDescriptor variable $ \desc ->
            desc { structure = struct, flex = flex', name = name' }

      -- Merge, keeping the structure of variable1.
      merge1 :: StateT TS.SolverState IO ()
      merge1 = mergeWith variable1 (structure desc1)

      -- Merge, keeping the structure of variable2.
      merge2 :: StateT TS.SolverState IO ()
      merge2 = mergeWith variable2 (structure desc2)

      -- Merge, keeping the structure of the lower-ranked variable
      -- (variable2 when the ranks are equal).
      merge = if rank desc1 < rank desc2 then merge1 else merge2

      -- Create and register a new variable with the given structure and
      -- the merged rank, flex and name.
      fresh :: Maybe (Term1 Variable) -> StateT TS.SolverState IO Variable
      fresh structure = do
        var <- liftIO . UF.fresh $ Descriptor {
                 structure = structure, rank = rank', flex = flex',
                 name = name', copy = Nothing, mark = noMark
               }
        TS.register var

      -- Drop the constraint on @var@ by making it Flexible, then retry
      -- unifying the original pair.
      flexAndUnify var = do
        liftIO $ UF.modifyDescriptor var $ \desc -> desc { flex = Flexible }
        unify' variable1 variable2

      unifyNumber svar name
          | name `elem` ["Int","Float","number"] = flexAndUnify svar
          | otherwise = TS.addError region (Just hint) variable1 variable2
          where hint = "A number must be an Int or Float."

      -- Report a comparable mismatch, using the given hint when present.
      comparableError maybe =
          TS.addError region (Just $ Maybe.fromMaybe msg maybe) variable1 variable2
          where msg = "A comparable must be an Int, Float, Char, String, list, or tuple."

      unifyComparable var name
          | name `elem` ["Int","Float","Char","String","comparable"] = flexAndUnify var
          | otherwise = comparableError Nothing

      -- The other side has no name: lists and tuples of up to six
      -- elements are accepted, and each element must itself unify with a
      -- fresh comparable variable.
      unifyComparableStructure varSuper varFlex =
          do struct <- liftIO $ collectApps varFlex
             case struct of
               Other -> comparableError Nothing
               List v -> do flexAndUnify varSuper
                            unify' v =<< liftIO (var $ Is Comparable)
               Tuple vs
                   | length vs > 6 ->
                       comparableError $ Just "Cannot compare a tuple with more than 6 elements."
                   | otherwise ->
                       do flexAndUnify varSuper
                          -- One fresh comparable variable per tuple element.
                          cmpVars <- liftIO $ forM vs $ \_ -> var (Is Comparable)
                          zipWithM_ unify' vs cmpVars

      -- The other side has no matching name: only lists are accepted.
      -- Failures get their own hint; this branch previously reused the
      -- comparable message, which misdescribed the constraint.
      unifyAppendable varSuper varFlex =
          do struct <- liftIO $ collectApps varFlex
             case struct of
               List _ -> flexAndUnify varSuper
               _ -> TS.addError region (Just hint) variable1 variable2
          where hint = "An appendable must be a String, Text, or list."

      rigidError variable = TS.addError region (Just hint) variable1 variable2
          where
            var = "'" ++ render (pretty Never variable) ++ "'"
            hint = "Cannot unify rigid type variable " ++ var ++
                   ".\nThe problem probably relates to a type annotation. Note that rigid type\n\
                   \variables are not shared between a top-level and let-bound type annotations."

      -- Handles the cases where at least one side has no structure and is
      -- not simply Flexible. Alternatives are tried top to bottom, so the
      -- order below is significant.
      superUnify =
          case (flex desc1, flex desc2, name desc1, name desc2) of
            (Is super1, Is super2, _, _)
                | super1 == super2 -> merge
            (Is Number, Is Comparable, _, _) -> merge1
            (Is Comparable, Is Number, _, _) -> merge2

            (Is Number, _, _, Just name) -> unifyNumber variable1 name
            (_, Is Number, Just name, _) -> unifyNumber variable2 name

            (Is Comparable, _, _, Just name) -> unifyComparable variable1 name
            (_, Is Comparable, Just name, _) -> unifyComparable variable2 name
            (Is Comparable, _, _, _) -> unifyComparableStructure variable1 variable2
            (_, Is Comparable, _, _) -> unifyComparableStructure variable2 variable1

            (Is Appendable, _, _, Just ctor)
                | ctor `elem` ["Text.Text","String"] -> flexAndUnify variable1
            (_, Is Appendable, Just ctor, _)
                | ctor `elem` ["Text.Text","String"] -> flexAndUnify variable2
            (Is Appendable, _, _, _) -> unifyAppendable variable1 variable2
            (_, Is Appendable, _, _) -> unifyAppendable variable2 variable1

            (Rigid, _, _, _) -> rigidError variable1
            (_, Rigid, _, _) -> rigidError variable2
            _ -> TS.addError region Nothing variable1 variable2

  case (structure desc1, structure desc2) of
    -- Both sides must be Flexible here; the guard used to test desc1
    -- twice. With both structures Nothing the next two alternatives give
    -- the same result, so this only makes the intent explicit.
    (Nothing, Nothing) | flex desc1 == Flexible && flex desc2 == Flexible -> merge
    (Nothing, _) | flex desc1 == Flexible -> merge2
    (_, Nothing) | flex desc2 == Flexible -> merge1

    -- Follow Var1 indirections before comparing anything else.
    (Just (Var1 v), _) -> unify' v variable2
    (_, Just (Var1 v)) -> unify' v variable1

    (Nothing, _) -> superUnify
    (_, Nothing) -> superUnify

    (Just type1, Just type2) ->
        case (type1,type2) of
          (App1 term1 term2, App1 term1' term2') ->
              do merge
                 unify' term1 term1'
                 unify' term2 term2'
          (Fun1 term1 term2, Fun1 term1' term2') ->
              do merge
                 unify' term1 term1'
                 unify' term2 term2'

          (EmptyRecord1, EmptyRecord1) ->
              return ()

          (Record1 fields ext, EmptyRecord1) | Map.null fields -> unify' ext variable2
          (EmptyRecord1, Record1 fields ext) | Map.null fields -> unify' ext variable1

          (Record1 fields1 ext1, Record1 fields2 ext2) ->
              do -- Unify the entries of shared fields pairwise; only the
                 -- effects matter, so the results are discarded.
                 sequence_ . concat . Map.elems $ Map.intersectionWith (zipWith unify') fields1 fields2
                 let mkRecord fs ext = fresh . Just $ Record1 fs ext
                 -- Fields left over on one side must be supplied by the
                 -- other side's extension variable.
                 case (Map.null fields1', Map.null fields2') of
                   (True , True ) -> unify' ext1 ext2
                   (True , False) -> do
                      record2' <- mkRecord fields2' ext2
                      unify' ext1 record2'
                   (False, True ) -> do
                      record1' <- mkRecord fields1' ext1
                      unify' record1' ext2
                   (False, False) -> do
                      record1' <- mkRecord fields1' =<< fresh Nothing
                      record2' <- mkRecord fields2' =<< fresh Nothing
                      unify' record1' ext2
                      unify' ext1 record2'
              where
                fields1' = unmerged fields1 fields2
                fields2' = unmerged fields2 fields1

                -- Entries of @a@ not consumed by a same-named entry of @b@;
                -- fields with nothing left over are dropped.
                unmerged a b = Map.filter (not . null) $ Map.union (Map.intersectionWith eat a b) a

                -- Drop one element of the first list per element of the second.
                eat (_:xs) (_:ys) = eat xs ys
                eat xs _ = xs

          _ -> TS.addError region Nothing variable1 variable2