module Agda.TypeChecking.LevelConstraints ( simplifyLevelConstraint ) where
import qualified Data.List as List
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NonEmpty
import Data.Maybe
import Agda.Syntax.Internal
import Agda.TypeChecking.Monad.Base
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Free
import Agda.TypeChecking.Level
import Agda.Utils.Impossible
import Agda.Utils.List (nubOn)
import Agda.Utils.Update
-- | @simplifyLevelConstraint c others@ tries to strengthen the level
--   constraint @c@ with the help of the constraints in @others@.
--
--   @c@ is first decomposed into single-level inequalities @a :=< b@
--   (see 'inequalities'); if that is impossible the result is 'Nothing'.
--   Each inequality whose reverse @b :=< a@ occurs among the inequalities
--   extracted from @others@ (up to a renaming of free variables, see
--   'matchLeq') is turned into the equation @a = b@.  All other inequalities
--   are kept as @a <= b@ constraints.
simplifyLevelConstraint
  :: Constraint          -- ^ The constraint @c@ to simplify.
  -> [Constraint]        -- ^ Other constraints that may enable simplification.
  -> Maybe [Constraint]  -- ^ @Just cs@: the constraints @c@ was rewritten into,
                         --   returned only if at least one inequality became
                         --   an equation.
                         --   @Nothing@: no simplification was possible.
simplifyLevelConstraint c others = do
  ineqs <- inequalities c
  let (simplified, anyChanged) = runChange (traverse tighten ineqs)
  -- Only report a result when something was actually turned into an equation.
  if anyChanged then Just simplified else Nothing
  where
    -- Inequalities derivable from the other constraints, computed once.
    knownLeqs = concatMap (fromMaybe [] . inequalities) others

    -- Turn @lo <= hi@ into @lo = hi@ (marking the change) when @hi <= lo@
    -- is already known; otherwise keep it as an inequality.
    tighten :: Leq -> Change Constraint
    tighten (lo :=< hi)
      | reverseKnown = dirty  $ LevelCmp CmpEq  lo' hi'
      | otherwise    = return $ LevelCmp CmpLeq lo' hi'
      where
        reverseKnown = any (matchLeq (hi :=< lo)) knownLeqs
        lo'          = unSingleLevel lo
        hi'          = unSingleLevel hi
-- | A single-level inequality: @a :=< b@ stands for the constraint @a <= b@
--   between the two single levels @a@ and @b@.
data Leq
  = SingleLevel :=< SingleLevel
  deriving (Eq, Show)
-- | Check whether two level inequalities are the same up to a renaming of
--   their free variables.
--
--   The distinct free variables of each inequality are listed in order of
--   first occurrence.  The @i@-th variable of the second inequality is
--   renamed to the @i@-th variable of the first.  The renamed second
--   inequality is then compared with the first using '=='.  Inequalities
--   with different numbers of free variables never match.
matchLeq :: Leq -> Leq -> Bool
matchLeq (a :=< b) (c :=< d)
  | length mine /= length theirs = False
  | otherwise                    = (a, b) == applySubst varMap (c, d)
  where
    -- Distinct free variables in order of first occurrence.  A list is used
    -- rather than a set so that this order is kept.
    occurring :: Free t => t -> [Int]
    occurring = nubOn id . runFree (: []) IgnoreNot

    mine   = occurring (a, b)
    theirs = occurring (c, d)

    -- Substitution sending each free variable of @(c, d)@ to its
    -- counterpart in @(a, b)@.
    varMap = fromPairs 0 (List.sort (zip theirs mine))

    -- @fromPairs i ps@ builds the substitution for indices @i, i+1, ...@
    -- from pairs sorted on their first component.  An index with no pair is
    -- not free in @(c, d)@, so it is strengthened away.  Looking it up should
    -- never happen, which is why the strengthening carries an impossible
    -- marker.
    fromPairs _ [] = IdS
    fromPairs i ps@((j, k) : rest)
      | i == j    = Var k [] :# fromPairs (i + 1) rest
      | otherwise = Strengthen __IMPOSSIBLE__ (fromPairs (i + 1) ps)
-- | Turn a level constraint into a list of single-level inequalities, if
--   possible.
--
--   * @u <= v@ where @v@ is a single level yields @s :=< v@ for every
--     component @s@ of the maximum view of @u@.
--
--   * @u = v@ where one side is a single level @s@ that occurs in the
--     maximum view of the other side yields @t :=< s@ for every other
--     component @t@ of that view.  The left-hand side is tried first.  If it
--     is a single level but does not occur on the right, the result is
--     'Nothing' and the right-hand side is not tried.
--
--   These are special cases only.  Every other constraint gives 'Nothing'.
inequalities :: Constraint -> Maybe [Leq]
inequalities constraint = case constraint of
  LevelCmp CmpLeq u v ->
    fmap (\ sv -> [ su :=< sv | su <- NonEmpty.toList (levelMaxView u) ])
         (singleLevelView v)
  LevelCmp CmpEq u v
    | Just su <- singleLevelView u -> othersBelow su v
    | Just sv <- singleLevelView v -> othersBelow sv u
  _ -> Nothing
  where
    -- @othersBelow s w@: if @s@ occurs in the maximum view of @w@, produce
    -- @t :=< s@ for every component @t@ other than that first occurrence of
    -- @s@.  If @s@ does not occur there, produce 'Nothing'.
    othersBelow s w =
      case break (== s) (NonEmpty.toList (levelMaxView w)) of
        (before, _ : after) -> Just [ t :=< s | t <- before ++ after ]
        _                   -> Nothing