module Type.Solve (solve) where

import Control.Monad
import Control.Monad.State
import qualified Data.List as List
import qualified Data.Map as Map
import qualified Data.Traversable as Traversable
import qualified Data.UnionFind.IO as UF
import Type.Type
import Type.Unify
import qualified Type.ExtraChecks as Check
import qualified Type.Hint as Hint
import qualified Type.State as TS
import qualified AST.Annotation as A


-- | Every variable has rank less than or equal to the maxRank of the pool.
--   This sorts variables into the young and old pools accordingly.
--
--   The work happens in three passes over the inhabitants of @youngPool@:
--
--     1. mark each variable with a fresh mark and group it by its current rank,
--     2. run 'adjustRank' over every group, lowest rank first,
--     3. pass every non-redundant variable whose rank is below the pool's
--        maxRank to 'TS.register', and give the remaining non-redundant
--        variables rank 'noRank' (turning 'Flexible' into 'Rigid').
generalize :: TS.Pool -> StateT TS.SolverState IO ()
generalize youngPool =
  do  youngMark <- TS.uniqueMark
      let youngRank = TS.maxRank youngPool

          -- Tag the variable as young and file it under the rank it had
          -- when it was read.
          insert dict var = do
            descriptor <- liftIO $ UF.descriptor var
            liftIO $ UF.modifyDescriptor var (\desc -> desc { mark = youngMark })
            return $ Map.insertWith (++) (rank descriptor) [var] dict

      -- Sort the youngPool variables by rank.
      rankDict <- foldM insert Map.empty (TS.inhabitants youngPool)

      -- Get the ranks right for each entry.
      -- Start at low ranks so that we only have to pass
      -- over the information once.
      -- The adjusted ranks are written into the descriptors by adjustRank,
      -- so the returned values are not needed here.
      visitedMark <- TS.uniqueMark
      forM_ (Map.toList rankDict) $ \(poolRank, vars) ->
          mapM_ (adjustRank youngMark visitedMark poolRank) vars

      -- TS.register hands the variable back; only its effect matters here.
      let register var = TS.register var >> return ()

      -- For variables that have rank lower than youngRank, register them in
      -- the old pool if they are not redundant.
      let registerIfNotRedundant var = do
            isRedundant <- liftIO $ UF.redundant var
            unless isRedundant (register var)

      mapM_ (mapM_ registerIfNotRedundant) (Map.elems (Map.delete youngRank rankDict))

      -- For variables that were grouped under youngRank, re-read the
      -- descriptor because adjustRank may have lowered the rank since
      -- rankDict was built.
      --   If rank < youngRank: register in oldPool
      --   otherwise generalize
      let registerIfLowerRank var = do
            isRedundant <- liftIO $ UF.redundant var
            unless isRedundant $ do
              desc <- liftIO $ UF.descriptor var
              if rank desc < youngRank
                then register var
                else do
                  let flex' = case flex desc of { Flexible -> Rigid ; other -> other }
                  liftIO $ UF.setDescriptor var (desc { rank = noRank, flex = flex' })

      mapM_ registerIfLowerRank (Map.findWithDefault [] youngRank rankDict)


-- Adjust the ranks of variables such that ranks never increase as you
-- move deeper into a variable.
--
-- The behaviour depends on the mark currently stored on the variable:
--
--   * marked with youngMark: the variable is re-marked as visited and its rank
--     becomes the rank computed from its structure (or groupRank when it has
--     no structure);
--   * marked with anything other than visitedMark: the variable is marked as
--     visited and its rank is capped at groupRank;
--   * already marked with visitedMark: its stored rank is returned untouched.
--
-- The result is the rank the variable ends up with.
adjustRank :: Int -> Int -> Int -> Variable -> StateT TS.SolverState IO Int
adjustRank youngMark visitedMark groupRank var =
    liftIO (UF.descriptor var) >>= visit
  where
    recur = adjustRank youngMark visitedMark groupRank

    visit descriptor
      | mark descriptor == youngMark =
          do  -- Set the variable as marked first because it may be cyclic.
              liftIO $ UF.modifyDescriptor var (\d -> d { mark = visitedMark })
              newRank <-
                  case structure descriptor of
                    Nothing -> return groupRank
                    Just term -> rankOfTerm term
              -- Modify rather than overwrite: the recursion above may have
              -- touched this variable's descriptor in the meantime.
              liftIO $ UF.modifyDescriptor var (\d -> d { rank = newRank })
              return newRank

      | mark descriptor /= visitedMark =
          do  let newRank = min groupRank (rank descriptor)
              liftIO $ UF.setDescriptor var (descriptor { mark = visitedMark, rank = newRank })
              return newRank

      | otherwise =
          return (rank descriptor)

    -- The rank of a term is the largest rank among its immediate children.
    rankOfTerm term =
        case term of
          App1 a b -> liftM2 max (recur a) (recur b)
          Fun1 a b -> liftM2 max (recur a) (recur b)
          Var1 x -> recur x
          EmptyRecord1 -> return outermostRank
          Record1 fields extension ->
              do  fieldRanks <- mapM recur (concat (Map.elems fields))
                  extensionRank <- recur extension
                  return (maximum (extensionRank : fieldRanks))


-- | Solve a located type constraint inside the solver state.
--
--   The region attached to the constraint is passed on to 'unify' and
--   'solveScheme'.
solve :: TypeConstraint -> StateT TS.SolverState IO ()
solve (A.A region constraint) =
  case constraint of
    CTrue ->
      return ()

    CSaveEnv ->
      TS.saveLocalEnv

    -- Flatten both terms, then unify them.
    CEqual term1 term2 ->
      do  t1 <- TS.flatten term1
          t2 <- TS.flatten term2
          unify region t1 t2

    CAnd cs ->
      mapM_ solve cs

    -- A single scheme with no rigid quantifiers and a trivial body: introduce
    -- the flexible quantifiers, solve the scheme's own constraint, and restore
    -- the environment afterwards.
    CLet [Scheme [] fqs constraint' _] (A.A _ CTrue) ->
      do  oldEnv <- TS.getEnv
          mapM_ TS.introduce fqs
          solve constraint'
          TS.modifyEnv (\_ -> oldEnv)

    -- General let: solve every scheme, bring their headers into scope for the
    -- body, run the occurs check on each header, then restore the environment.
    CLet schemes constraint' ->
      do  oldEnv <- TS.getEnv
          headers <- Map.unions `fmap` mapM (solveScheme region) schemes
          TS.modifyEnv $ \env -> Map.union headers env
          solve constraint'
          mapM_ Check.occurs $ Map.toList headers
          TS.modifyEnv (\_ -> oldEnv)

    -- Instantiate the named type from the environment and unify it with the
    -- flattened term. Names starting with "Native." that are missing from the
    -- environment get a fresh flexible variable; any other missing name is a
    -- fatal error.
    CInstance name term ->
      do  env <- TS.getEnv
          freshCopy <-
              case Map.lookup name env of
                Just tipe -> TS.makeInstance tipe
                Nothing
                  | List.isPrefixOf "Native." name -> liftIO (variable Flexible)
                  | otherwise ->
                      error ("Could not find '" ++ name ++ "' when solving type constraints.")

          t <- TS.flatten term
          unify region freshCopy t


-- | Solve the constraint of a type scheme and return its header with every
--   entry flattened to a variable.
--
--   A scheme without quantifiers is solved in the current pool. Otherwise the
--   scheme is solved in a fresh pool obtained from 'TS.nextRankPool', which is
--   generalized after switching back to the previous pool.
solveScheme :: A.Region -> TypeScheme -> StateT TS.SolverState IO (Map.Map String Variable)
solveScheme region scheme =
    case scheme of
      Scheme [] [] constraint header ->
        do  solve constraint
            Traversable.traverse TS.flatten header

      Scheme rigidQuantifiers flexibleQuantifiers constraint header ->
        do  let quantifiers = rigidQuantifiers ++ flexibleQuantifiers
            oldPool <- TS.getPool

            -- fill in a new pool when working on this scheme's constraints
            freshPool <- TS.nextRankPool
            TS.switchToPool freshPool
            mapM_ TS.introduce quantifiers
            header' <- Traversable.traverse TS.flatten header
            solve constraint

            -- The rigid quantifiers must still be distinct, structure-free
            -- variables once the constraint has been solved.
            allDistinct region rigidQuantifiers

            -- Read the pool again instead of reusing freshPool, then
            -- generalize it from within the old pool.
            youngPool <- TS.getPool
            TS.switchToPool oldPool
            generalize youngPool

            -- Every rigid quantifier must have ended up unranked.
            mapM_ (isGeneric region) rigidQuantifiers
            return header'


-- Create a hint with Hint.create from the region, the optional message and the
-- two points, then pass the result to TS.addHint.
addHint
    :: A.Region
    -> Maybe String
    -> UF.Point Descriptor
    -> UF.Point Descriptor
    -> StateT TS.SolverState IO ()
addHint region hint t1 t2 =
    liftIO (Hint.create region hint t1 t2) >>= TS.addHint


-- Checks that all of the given variables belong to distinct equivalence classes.
-- Also checks that their structure is Nothing, so they represent a variable, not
-- a more complex term.
--
-- Each structure-free variable is stamped with a fresh mark; meeting that mark
-- again means two of the given variables share a descriptor. Problems are
-- reported through addHint rather than by aborting.
allDistinct :: A.Region -> [Variable] -> StateT TS.SolverState IO ()
allDistinct region vars =
    TS.uniqueMark >>= \seen -> mapM_ (check seen) vars
  where
    report msg var =
        addHint region (Just msg) var var

    check seen var =
      do  desc <- liftIO (UF.descriptor var)
          case structure desc of
            Nothing ->
              do  when (mark desc == seen) $
                      report "Duplicate variable during generalization." var
                  liftIO (UF.setDescriptor var (desc { mark = seen }))

            Just _ ->
              report "Cannot generalize something that is not a type variable." var


-- Check that a variable has rank == noRank, meaning that it can be generalized.
-- A variable with any other rank is reported through addHint.
isGeneric :: A.Region -> Variable -> StateT TS.SolverState IO ()
isGeneric region var =
  do  desc <- liftIO (UF.descriptor var)
      unless (rank desc == noRank) $
          addHint region (Just notUnranked) var var
  where
    notUnranked = "Unable to generalize a type variable. It is not unranked."