-- Knuth-Bendix completion, with lots of exciting tricks for
-- unorientable equations.

{-# LANGUAGE CPP, TypeFamilies, FlexibleContexts, RecordWildCards, ScopedTypeVariables, UndecidableInstances, StandaloneDeriving, PatternGuards, BangPatterns #-}
module Twee where

#include "errors.h"
import Twee.Base hiding (empty, lookup)
import Twee.Constraints hiding (funs)
import Twee.Rule
import qualified Twee.Indexes as Indexes
import Twee.Indexes(Indexes, Rated(..))
import qualified Twee.Index as Index
import Twee.Index(Index, Frozen)
import Twee.Queue hiding (queue)
import Twee.Utils
import Control.Monad
import Data.Maybe
import Data.Ord
import qualified Debug.Trace
import Control.Monad.Trans.State.Strict
import Data.List
import Text.Printf
import qualified Data.Set as Set
import Data.Set(Set)
import Data.Either
import qualified Data.Map.Strict as Map
import Data.Map.Strict(Map)

--------------------------------------------------------------------------------
-- Completion engine state.
--------------------------------------------------------------------------------

-- State of the completion engine: the rule sets, the queue of passive
-- critical pairs, counters, and the configuration switches read by the
-- functions in this module.
data Twee f =
  Twee {
    -- Critical pairs with a side larger than this are rejected by 'consider'
    -- (Nothing means no limit).
    maxSize           :: Maybe Int,
    -- The main rule set. Each rule carries its label, the model it was added
    -- under and its critical-pair information. Critical pairs are generated
    -- only from these rules (see 'criticalPairs').
    labelledRules     :: {-# UNPACK #-} !(Indexes (Labelled (Modelled (Critical (Rule f))))),
    -- Rules added through 'addExtraRule'; used for rewriting ('rules',
    -- 'allRules') but not for generating critical pairs.
    extraRules        :: {-# UNPACK #-} !(Indexes (Rule f)),
    -- Cancellation rules derived from labelledRules by 'toCancellationRule'.
    cancellationRules :: !(Index (Labelled (CancellationRule f))),
    -- Goal term sets, rewritten by 'renormaliseGoals'.
    goals             :: [Set (Term f)],
    -- Number of critical pairs enqueued (incremented by 'enqueueM').
    totalCPs          :: Int,
    -- Number of critical pairs passed to 'consider'.
    processedCPs      :: Int,
    -- When the queue's next label reaches this value, 'complete1' calls
    -- 'normaliseCPs' and multiplies the threshold by 3/2.
    renormaliseAt     :: Int,
    -- Groups of at most this many critical pairs are queued individually
    -- instead of as one ManyCPs family (see 'queueCPs').
    minimumCPSetSize  :: Int,
    -- Scale factor used by 'queueCPsSplit' to bucket critical pairs by
    -- partner label.
    cpSplits          :: Int,
    -- Passive critical pairs waiting to be considered.
    queue             :: !(Queue (Mix (Either1 FIFO Heap)) (Passive f)),
    -- Feature switches; see the functions that test each flag.
    useGeneralSuperpositions :: Bool,
    useGroundJoining  :: Bool,
    useConnectedness  :: Bool,
    useSetJoining     :: Bool,
    useSetJoiningForGoals :: Bool,
    useCancellation :: Bool,
    -- Rules whose equation size exceeds this get no cancellation rule
    -- (see 'addCancellationRule').
    maxCancellationSize :: Maybe Int,
    atomicCancellation :: Bool,
    unifyConstantsInCancellation :: Bool,
    useInterreduction :: Bool,
    useUnsafeInterreduction :: Bool,
    skipCompositeSuperpositions :: Bool,
    -- Trace levels: 'tracing' covers rule-level events, 'moreTracing' adds
    -- per-critical-pair events (see 'trace').
    tracing :: Bool,
    moreTracing :: Bool,
    -- Multipliers applied to the two sides' sizes in 'weight''.
    lhsWeight         :: Int,
    rhsWeight         :: Int,
    -- Number of critical pairs discarded for each JoinReason.
    joinStatistics    :: Map JoinReason Int }
  deriving Show

-- Default engine state. The two arguments are passed unchanged to
-- 'emptyMix' as the mixing parameters of the FIFO and priority halves of
-- the passive queue.
initialState :: Int -> Int -> Twee f
initialState fifoShare prioShare =
  Twee
    { -- Rule sets start out empty.
      labelledRules     = Indexes.empty
    , extraRules        = Indexes.empty
    , cancellationRules = Index.Nil
    , goals             = []
      -- Passive queue and its bookkeeping.
    , queue             = empty (emptyMix fifoShare prioShare (Left1 emptyFIFO) (Right1 emptyHeap))
    , totalCPs          = 0
    , processedCPs      = 0
    , renormaliseAt     = 50
    , minimumCPSetSize  = 20
    , cpSplits          = 20
      -- Size limits.
    , maxSize           = Nothing
    , maxCancellationSize = Nothing
      -- Feature switches.
    , useGeneralSuperpositions = True
    , useGroundJoining  = True
    , useConnectedness  = True
    , useSetJoining     = False
    , useSetJoiningForGoals = True
    , useInterreduction = False
    , useUnsafeInterreduction = True
    , useCancellation   = True
    , atomicCancellation = True
    , unifyConstantsInCancellation = False
    , skipCompositeSuperpositions = True
      -- Tracing.
    , tracing           = True
    , moreTracing       = False
      -- Weighting and statistics.
    , lhsWeight         = 2
    , rhsWeight         = 1
    , joinStatistics    = Map.empty
    }

-- Put the given goal sets in front of the goals already recorded.
addGoals :: [Set (Term f)] -> Twee f -> Twee f
addGoals extra st = st { goals = extra ++ goals st }

-- Human-readable statistics.
-- Contents: rule counts by orientation, the number of extra rules, the
-- queue's next label ("historical"), critical-pair counters, and the
-- per-reason join statistics.
report :: Function f => Twee f -> String
report Twee{..} =
  concat
    [ printf "Rules: %d total, %d oriented, %d unoriented, %d permutative, %d weakly oriented. "
        (length activeRules) nOriented nUnoriented nPermutative nWeak
    , printf "%d extra. %d historical.\n"
        (length (Indexes.elems extraRules)) historical
    , printf "Critical pairs: %d total, %d processed, %d queued compressed into %d.\n\n"
        totalCPs processedCPs queuedCPs (length (toList queue))
    , printf "Critical pairs joined:\n"
    , concatMap joinLine (Map.toList joinStatistics)
    ]
  where
    activeRules  = map (critical . modelled . peel) (Indexes.elems labelledRules)
    orientations = [ o | Rule o _ _ <- activeRules ]
    nOriented    = length [ () | Oriented <- orientations ]
    nUnoriented  = length [ () | Unoriented <- orientations ]
    nPermutative = length [ () | Permutative{} <- orientations ]
    nWeak        = length [ () | WeaklyOriented{} <- orientations ]
    Label historical = nextLabel queue
    -- Number of individual critical pairs represented by the queue entries.
    queuedCPs = sum (map passiveCount (toList queue))
    joinLine (reason, k) = printf "%6d %s.\n" k (prettyShow reason)

-- Trace a new passive entry, then push it onto the queue.
-- Every critical pair the entry stands for is counted in 'totalCPs'.
enqueueM :: Function f => Passive f -> State (Twee f) ()
enqueueM cps = traceM (NewCP cps) >> modify' push
  where
    push st = st { queue    = enqueue cps (queue st)
                 , totalCPs = totalCPs st + passiveCount cps }

-- Put an entry back on the queue with 'reenqueue'. Unlike 'enqueueM',
-- this neither traces nor touches 'totalCPs'.
reenqueueM :: Function f => Passive f -> State (Twee f) ()
reenqueueM cps = modify' (\st -> st { queue = reenqueue cps (queue st) })

-- Pop the next passive entry, or return Nothing if the queue is empty
-- (the state is then left unchanged).
dequeueM :: Function f => State (Twee f) (Maybe (Passive f))
dequeueM = do
  st <- get
  case dequeue (queue st) of
    Nothing -> return Nothing
    Just (next, rest) -> do
      put st { queue = rest }
      return (Just next)

-- Allocate a fresh label from the queue and store the updated queue.
newLabelM :: State (Twee f) Label
newLabelM = do
  st <- get
  case newLabel (queue st) of
    (fresh, q') -> do
      put st { queue = q' }
      return fresh

-- A value annotated with a model and a list of positions.
-- In this module the model comes from 'groundJoin' (the model in which a
-- new rule could not be ground-joined), and the positions come from
-- 'ruleOverlaps'. 'criticalPairs' passes them to 'criticalPairs1' as the
-- overlap positions of the rule.
data Modelled a =
  Modelled {
    model     :: Model (ConstantOf a),
    positions :: [Int],
    modelled  :: a }

-- Equality and ordering look only at the wrapped value.
instance Eq a => Eq (Modelled a) where x == y = modelled x == modelled y
instance Ord a => Ord (Modelled a) where compare = comparing modelled

instance (PrettyTerm (ConstantOf a), Pretty a) => Pretty (Modelled a) where
  pPrint Modelled{..} = pPrint modelled

deriving instance (Show a, Show (ConstantOf a)) => Show (Modelled a)

-- Substitution is applied to the wrapped value only; the model and the
-- positions are carried over unchanged.
instance Symbolic a => Symbolic (Modelled a) where
  type ConstantOf (Modelled a) = ConstantOf a

  term = term . modelled
  termsDL = termsDL . modelled
  replace f Modelled{..} = Modelled model positions (replace f modelled)

--------------------------------------------------------------------------------
-- Rewriting.
--------------------------------------------------------------------------------

-- Wrappers delegate their rating to the value they contain.
instance Rated a => Rated (Labelled a) where
  rating x = rating (peel x)
  maxRating x = maxRating (peel x)
instance Rated a => Rated (Modelled a) where
  rating x = rating (modelled x)
  maxRating x = maxRating (modelled x)
instance Rated a => Rated (Critical a) where
  rating x = rating (critical x)
  maxRating x = maxRating (critical x)
-- Oriented and weakly oriented rules get rating 0; all others get 1.
instance Rated (Rule f) where
  rating (Rule o _ _) =
    case o of
      Oriented         -> 0
      WeaklyOriented{} -> 0
      _                -> 1
  maxRating = const 1

{-# INLINE rulesFor #-}
-- The labelled rules selected by 'Indexes.freeze' at the given level.
-- The result is stripped of label, model and critical-pair annotations.
rulesFor :: Function f => Int -> Twee f -> Frozen (Rule f)
rulesFor level st = Index.map unwrap (Indexes.freeze level (labelledRules st))
  where
    unwrap = critical . modelled . peel

-- Rule sets used for rewriting:
--   easyRules: labelled rules at freeze level 0;
--   rules:     labelled rules at level 1 plus extra rules at level 0;
--   allRules:  labelled rules and extra rules, both at level 1.
easyRules, rules, allRules :: Function f => Twee f -> Frozen (Rule f)
easyRules = rulesFor 0
rules st = Index.union (rulesFor 1 st) (Indexes.freeze 0 (extraRules st))
allRules st = Index.union (rulesFor 1 st) (Indexes.freeze 1 (extraRules st))

-- Normalisation entry points. Each builds a rewriting strategy from one of
-- the rule sets and a reduction test, and runs it with normaliseWith.

-- Easy rules only, using the 'simplifies' test.
normaliseQuickly :: Function f => Twee f -> Term f -> Reduction f
normaliseQuickly s = normaliseWith (rewrite "simplify" simplifies (easyRules s))

-- 'rules' with the 'reduces' test.
normalise :: Function f => Twee f -> Term f -> Reduction f
normalise s = normaliseWith (rewrite "reduce" reduces (rules s))

-- 'rules', allowing only steps valid in the given model.
normaliseIn :: Function f => Twee f -> Model f -> Term f -> Reduction f
normaliseIn s model = normaliseWith (rewrite "model" (reducesInModel model) (rules s))

-- Rewrite t with steps bounded by 'limit' (connectedness). Only done when
-- useConnectedness is on and t is below limit without unifying with it;
-- otherwise t is returned with no steps.
normaliseSub :: Function f => Twee f -> Term f -> Term f -> Reduction f
normaliseSub s limit t
  | strictlyBelow = normaliseWith (rewrite "sub" (reducesSub limit) (rules s)) t
  | otherwise     = Parallel [] t
  where
    strictlyBelow = useConnectedness s && lessEq t limit && isNothing (unify t limit)

-- 'rules' with the 'reducesSkolem' test.
normaliseSkolem :: Function f => Twee f -> Term f -> Reduction f
normaliseSkolem s = normaliseWith (rewrite "skolem" reducesSkolem (rules s))

-- Rewrite both sides of a critical pair with f and check whether the
-- result is already known:
--   * Left (Trivial stage) if both sides became identical;
--   * Left (Subsumed stage) if some rule in allRules rewrites one side to
--     the other, either at the top or in corresponding argument positions;
--   * otherwise Right the rewritten pair.
reduceCP ::
  Function f =>
  Twee f -> JoinStage -> (Term f -> Term f) ->
  Critical (Equation f) -> Either JoinReason (Critical (Equation f))
reduceCP s stage f (Critical info (t :=: u))
  | t' == u'       = Left (Trivial stage)
  | subsumed t' u' = Left (Subsumed stage)
  | otherwise      = Right (Critical info (t' :=: u'))
  where
    t' = f t
    u' = f u
    known = allRules s

    subsumed a b = any ((== b) . rhs) (Index.lookup a known) || congruent a b

    congruent (Var x) (Var y) = x == y
    congruent (Fun g xs) (Fun h ys)
      | g == h = and (zipWith subsumed (fromTermList xs) (fromTermList ys))
    congruent _ _ = False

-- The check at which a critical pair was found joinable. The constructors
-- appear in the order the checks are applied by normaliseCPQuickly,
-- normaliseCPReducing and normaliseCP.
data JoinStage = Initial | Simplification | Reducing | Subjoining deriving (Eq, Ord, Show)
-- Why a critical pair was discarded. Used as the key of joinStatistics, so
-- the derived Ord decides the order in which 'report' lists them.
data JoinReason = Trivial JoinStage | Subsumed JoinStage | SetJoining | GroundJoined deriving (Eq, Ord, Show)

-- Short English descriptions used in trace output and 'report'.
instance Pretty JoinStage where
  pPrint stage =
    text $ case stage of
      Initial        -> "no rewriting"
      Simplification -> "simplification"
      Reducing       -> "reduction"
      Subjoining     -> "connectedness testing"

instance Pretty JoinReason where
  pPrint reason =
    case reason of
      Trivial stage  -> text "joined after" <+> pPrint stage
      Subsumed stage -> text "subsumed after" <+> pPrint stage
      SetJoining     -> text "joined with set of normal forms"
      GroundJoined   -> text "ground joined"

-- Three increasingly expensive attempts to join a critical pair. Each
-- returns Left with the reason if the pair is joined or subsumed, and Right
-- with the rewritten pair otherwise.
normaliseCPQuickly, normaliseCPReducing, normaliseCP ::
  Function f =>
  Twee f -> Critical (Equation f) -> Either JoinReason (Critical (Equation f))

-- Check the pair as it is, then after normalising with the easy rules.
normaliseCPQuickly s cp =
  reduceCP s Initial id cp >>=
  reduceCP s Simplification (result . normaliseQuickly s)

-- As normaliseCPQuickly, then also normalise with 'rules'.
normaliseCPReducing s cp =
  normaliseCPQuickly s cp >>=
  reduceCP s Reducing (result . normalise s)

-- Full check. Reduce the pair, then try connectedness (normaliseSub below
-- the pair's top term) on the pair and on its variable-flipped copy, then
-- set joining on both. The first Left found, in that order, is returned.
normaliseCP s cp@(Critical info _) =
  case (cp1, cp2, cp3, cp4) of
    (Right cp, Right _, Right _, Right _) -> Right cp
    (Right _, Right _, Right _, Left x) -> Left x
    (Right _, Right _, Left x, _) -> Left x
    (Right _, Left x, _, _) -> Left x
    (Left x, _, _, _) -> Left x
  where
    -- Shared by cp1 and cp2 so the reduction is computed only once.
    reduced = normaliseCPReducing s cp

    cp1 =
      reduced >>=
      reduceCP s Subjoining (result . normaliseSub s (top info))

    cp2 =
      reduced >>=
      reduceCP s Subjoining (result . normaliseSub s (flipCP (top info))) . flipCP

    cp3 = setJoin cp
    cp4 = setJoin (flipCP cp)

    -- Rename variables x to n - x, where n is the largest variable number.
    flipCP :: Symbolic a => a -> a
    flipCP t = replace (substList sub) t
      where
        n = maximum (0:map fromEnum (vars t))
        sub (MkVar x) = var (MkVar (n - x))

    -- XXX shouldn't this also check subsumption?
    -- If set joining is enabled, compare the sets of normal forms of both
    -- sides (taken below the pair's top term). A common element means the
    -- pair is joinable.
    setJoin c@(Critical info (t :=: u))
      | not (useSetJoining s) || Set.null common = Right c
      | otherwise = traceJoin (Left SetJoining)
      where
        common = norm t `Set.intersection` norm u
        norm x
          | lessEq x (top info) && isNothing (unify x (top info)) =
            normalForms (rewrite "setjoin" (reducesSub (top info)) (rules s)) [x]
          | otherwise = Set.singleton x
        -- Diagnostic output, gated like the other per-pair traces.
        traceJoin
          | moreTracing s =
            Debug.Trace.traceShow (sep [text "Joined", nest 2 (pPrint c), text "to", nest 2 (pPrint (Set.findMin common))])
          | otherwise = id

--------------------------------------------------------------------------------
-- Completion loop.
--------------------------------------------------------------------------------

-- Run completion steps until complete1 reports that no work is left.
complete :: Function f => State (Twee f) ()
complete = complete1 >>= \more -> when more complete

-- One completion step. Returns False when the queue is empty and True
-- otherwise.
complete1 :: Function f => State (Twee f) Bool
complete1 = do
  Twee{..} <- get
  -- Periodically re-normalise the whole queue; the threshold grows by 3/2.
  let Label n = nextLabel queue
  when (n >= renormaliseAt) $ do
    normaliseCPs
    modify (\s -> s { renormaliseAt = renormaliseAt * 3 `div` 2 })

  res <- dequeueM
  case res of
    -- A single critical pair: consider it; if that added a rule,
    -- renormalise the goals.
    Just (SingleCP (CP info cp l1 l2)) -> do
      res <- consider (cpWeight info) l1 l2 cp
      when res renormaliseGoals
      return True
    -- A family of critical pairs. Split it around the label of its best
    -- member: the partner ranges [lower, l-1] and [l+1, upper] are queued
    -- again as (possibly grouped) families, and the pairs with partner l
    -- are queued individually. Then take another step.
    -- NOTE(review): totalCPs is reduced by size here, but reenqueueM does
    -- not add the re-queued pairs back (unlike enqueueM) - confirm intended.
    Just (ManyCPs (CPs _ l lower upper size rule)) -> do
      s <- get
      modify (\s@Twee{..} -> s { totalCPs = totalCPs - size })

      queueCPsSplit reenqueueM lower (l-1) rule
      mapM_ (reenqueueM . SingleCP) (toCPs s l l rule)
      queueCPsSplit reenqueueM (l+1) upper rule
      complete1
    Nothing ->
      return False

-- Rewrite every goal set with the current 'rules'. With
-- useSetJoiningForGoals, a set is replaced by all normal forms of its
-- members. Otherwise each member is replaced by its single normal form.
renormaliseGoals :: Function f => State (Twee f) ()
renormaliseGoals = do
  Twee{useSetJoiningForGoals = asSets} <- get
  modify $ \st ->
    let strat = rewrite "goal" reduces (rules st)
        renorm
          | asSets    = normalForms strat . Set.toList
          | otherwise = Set.fromList . map (result . normaliseWith strat) . Set.toList
    in st { goals = map renorm (goals st) }

-- Empty the queue and queue every entry again through enqueueM. This
-- re-simplifies single pairs and regenerates families from their rule.
-- totalCPs is restored afterwards, because enqueueM counts every
-- re-queued pair again.
normaliseCPs :: forall f. Function f => State (Twee f) ()
normaliseCPs = do
  before <- get
  traceM (NormaliseCPs before)
  put before { queue = emptyFrom (queue before) }
  mapM_ requeue (toList (queue before))
  modify (\st -> st { totalCPs = totalCPs before })
  where
    requeue :: Passive f -> State (Twee f) ()
    requeue (SingleCP (CP _ c l1 l2)) = queueCP enqueueM trivial l1 l2 c
    requeue (ManyCPs (CPs _ _ lo hi _ r)) = queueCPs enqueueM lo hi (const ()) r

-- Process one critical pair. Returns True iff a new rule was added to
-- labelledRules.
--
-- w is a weight budget: if the pair's cancelled weight after normalisation
-- exceeds w, the pair is put back on the queue instead (the Delay case).
-- l1 and l2 are the labels used when the pair is put back on the queue.
consider ::
  Function f =>
  Int -> Label -> Label -> Critical (Equation f) -> State (Twee f) Bool
consider w l1 l2 pair = do
  traceM (Consider pair)
  modify' (\s -> s { processedCPs = processedCPs s + 1 })
  s <- get
  let record reason = modify' (\s -> s { joinStatistics = Map.insertWith (+) reason 1 (joinStatistics s) })
      -- Join reasons found only by the connectedness / set-joining checks.
      -- For these, the joined equation is also kept as extra rules.
      hard (Trivial Subjoining) = True
      hard (Subsumed Subjoining) = True
      hard SetJoining = True
      hard _ = False
      -- Whether either side is larger than maxSize.
      tooBig (Critical _ (t :=: u)) =
        case maxSize s of
          Nothing -> False
          Just sz -> size t > sz || size u > sz
  if tooBig pair then return False else
    case normaliseCP s pair of
      -- Joinable: record why. For "hard" joins, keep the oriented versions
      -- (rhs sub-normalised) as extra rules.
      Left reason -> do
        record reason
        when (hard reason) $ forM_ (map canonicalise (orient (critical pair))) $ \(Rule _ t u0) -> do
          s <- get
          let u = result (normaliseSub s t u0)
              r = rule t u
          addExtraRule r
        traceM (Joined pair reason)
        return False
      -- Still too big after normalisation.
      Right pair | tooBig pair ->
        return False
      -- Over the weight budget: put it back on the queue.
      Right pair@(Critical _ eq)
        | cancelledWeight s (groundJoinableEq s) eq > w -> do
          traceM (Delay pair)
          queueCP enqueueM (groundJoinableEq s) l1 l2 pair
          return False
      -- A cancellation changes the equation: consider the cancelled
      -- equation now, and put the original pair back on the queue.
      Right pair@(Critical _ eq)
        | (_, eq') <- bestCancellation s (groundJoinableEq s) eq,
          eq /= eq' -> do
            traceM (Cancel pair eq')
            res <- consider maxBound l1 l2 (Critical noCritInfo eq')
            s <- get
            queueCP enqueueM (groundJoinableEq s) l1 l2 pair
            return res
      -- Otherwise orient the equation and try each resulting rule. The
      -- result is True if any of them was added as a new rule.
      Right (Critical info eq) ->
        fmap or $ forM (map canonicalise (orient eq)) $ \r0@(Rule _ t u0) -> do
          s <- get
          -- Sub-normalise the rhs below the lhs; the lhs becomes the new top.
          let u = result (normaliseSub s t u0)
              r = rule t u
              info' = info { top = t }
          case normaliseCP s (Critical info' (t :=: u)) of
            -- Joinable after all: keep both versions as extra rules.
            Left reason -> do
              when (hard reason) $ record reason
              addExtraRule r
              addExtraRule r0
              return False
            Right eq ->
              case groundJoin s (branches (And [])) eq of
                -- Ground joinable: consider each case split instead, and
                -- keep both versions as extra rules.
                Right eqs -> do
                  record GroundJoined
                  mapM_ (consider maxBound l1 l2) [ eq { critInfo = info' } | eq <- eqs ]
                  addExtraRule r
                  addExtraRule r0
                  return False
                -- Not ground joinable: add as a real rule under the model,
                -- queue its critical pairs against all existing rules, and
                -- interreduce the old rules with it.
                Left model -> do
                  traceM (NewRule r)
                  l <- addRule (Modelled model (ruleOverlaps s (lhs r)) (Critical info r))
                  queueCPsSplit enqueueM noLabel l (Labelled l r)
                  interreduce r
                  return True

-- 'groundJoinable' for a bare equation with no critical-pair information.
groundJoinableEq :: Function f => Twee f -> Equation f -> Bool
groundJoinableEq s = groundJoinable s . Critical noCritInfo

-- A pair is ground joinable if it can be decided without a model:
--   * normaliseCP joins it directly; or
--   * groundJoin splits it into instances that are all ground joinable
--     (checked recursively).
groundJoinable :: Function f => Twee f -> Critical (Equation f) -> Bool
groundJoinable s pair =
  either (const True) viaCases (normaliseCP s pair)
  where
    viaCases reduced =
      either (const False) (all (groundJoinable s))
        (groundJoin s (branches (And [])) reduced)

-- Try to join a pair for all ground instances by case analysis over the
-- constraint branches in ctx. The atoms of both sides are solved in every
-- branch:
--   * if no branch yields a model, Right the pair instantiated with each
--     branch's substitution (canonicalised, duplicates removed);
--   * otherwise take the first model. Return Left model if ground joining
--     is disabled or the pair does not join when rewritten in that model;
--   * otherwise weaken the model while it still works, exclude it from
--     ctx, and recurse.
groundJoin :: Function f =>
  Twee f -> [Branch f] -> Critical (Equation f) -> Either (Model f) [Critical (Equation f)]
groundJoin s ctx r@(Critical info (t :=: u)) =
  case partitionEithers (map (solve (usort (atoms t ++ atoms u))) ctx) of
    ([], instances) ->
      let rs = [ subst sub r | sub <- instances ] in
      Right (usort (map canonicalise rs))
    (model:_, _)
      | not (useGroundJoining s) -> Left model
      | isRight (normaliseCP s (Critical info (t' :=: u'))) -> Left model
      | otherwise ->
          -- model1: weakest model in which the original rewrite steps are
          -- still valid. model2: weaken further while the pair stays
          -- joinable when normalised in the model.
          let model1 = optimise model weakenModel (\m -> valid m nt && valid m nu)
              model2 = optimise model1 weakenModel (\m -> isLeft (normaliseCP s (Critical info (result (normaliseIn s m t) :=: result (normaliseIn s m u)))))

              -- Formula excluding model2: some literal is negated, with the
              -- earlier literals kept (LessEq tightened to Less).
              diag [] = Or []
              diag (r:rs) = negateFormula r ||| (weaken r &&& diag rs)
              weaken (LessEq t u) = Less t u
              weaken x = x
              ctx' = formAnd (diag (modelToLiterals model2)) ctx in

          trace s (Discharge r model2) $
          groundJoin s ctx' r
      where
        -- Both sides normalised using only steps valid in the model.
        nt = normaliseIn s model t
        nu = normaliseIn s model u
        t' = result nt
        u' = result nu

-- True iff every (rule, substitution) step recorded in the reduction is
-- allowed in the model according to reducesInModel.
valid :: Function f => Model f -> Reduction f -> Bool
valid model = all (\(r, sub) -> reducesInModel model r sub) . steps

-- Greedy hill-climbing. Repeatedly move to the first candidate from f that
-- satisfies p, and return the first value that has no such candidate.
optimise :: a -> (a -> [a]) -> (a -> Bool) -> a
optimise x f p = climb x
  where
    climb cur = maybe cur climb (find p (f cur))

-- Give the rule a fresh label, insert it into labelledRules, and register
-- its cancellation rule (if it has one). Returns the new label.
addRule :: Function f => Modelled (Critical (Rule f)) -> State (Twee f) Label
addRule rule = do
  lab <- newLabelM
  modify (addCancellationRule lab (critical (modelled rule)) . insertRule lab)
  return lab
  where
    insertRule lab st =
      st { labelledRules = Indexes.insert (Labelled lab rule) (labelledRules st) }

-- Add a rule to extraRules (with a trace event), unless 'extraRuleSafe'
-- rejects it.
addExtraRule :: Function f => Rule f -> State (Twee f) ()
addExtraRule rule = do
  safe <- gets (\st -> extraRuleSafe st rule)
  when safe $ do
    traceM (ExtraRule rule)
    modify (\st -> st { extraRules = Indexes.insert rule (extraRules st) })

-- With useUnsafeInterreduction every rule is accepted. Otherwise a rule is
-- accepted only if no rule in allRules has a lhs that both matches this
-- rule's lhs and is an instance of it (i.e. the same lhs up to renaming).
extraRuleSafe :: Function f => Twee f -> Rule f -> Bool
extraRuleSafe s r
  | useUnsafeInterreduction s = True
  | otherwise =
      null [ () | Index.Match (Rule _ l' _) _ <- Index.matches (lhs r) (allRules s)
                , l' `isInstanceOf` lhs r ]

-- Remove a labelled rule, everything queued under its label, and its
-- cancellation rule.
deleteRule :: Function f => Label -> Modelled (Critical (Rule f)) -> State (Twee f) ()
deleteRule l rule =
  modify (deleteCancellationRule l (critical (modelled rule)) . dropRule)
  where
    dropRule st =
      st { labelledRules = Indexes.delete (Labelled l rule) (labelledRules st),
           queue = deleteLabel l (queue st) }

-- Result of 'reduceWith' for an old rule:
--   Simplify: keep the rule, but renormalise its rhs and store it under the
--             given model ('simplifyRule');
--   Reorient: delete the rule and consider its equation again
--             ('interreduce', only when useInterreduction is set).
data Simplification f = Simplify (Model f) (Modelled (Critical (Rule f))) | Reorient (Modelled (Critical (Rule f))) deriving Show

instance (Numbered f, PrettyTerm f) => Pretty (Simplification f) where
  pPrint (Simplify _ rule) = text "Simplify" <+> pPrint rule
  pPrint (Reorient rule) = text "Reorient" <+> pPrint rule

-- Use a newly added rule to simplify the existing labelled rules. The
-- rules are taken from a snapshot made before the loop; each one is
-- checked against the current state with reduceWith.
interreduce :: Function f => Rule f -> State (Twee f) ()
interreduce new = do
  snapshot <- gets (Indexes.elems . labelledRules)
  forM_ snapshot $ \(Labelled lab old) -> do
    st <- get
    forM_ (reduceWith st lab new old) $ \red -> do
      traceM (Reduce red new)
      case red of
        Simplify m r -> simplifyRule lab m r
        Reorient r@(Modelled _ _ (Critical info (Rule _ t u))) ->
          when (useInterreduction st) $ do
            deleteRule lab r
            _ <- consider maxBound noLabel noLabel (Critical info (t :=: u))
            return ()

-- Decide whether the new rule affects an existing labelled rule (old, with
-- label lab). Returns Nothing if it does not. The model-dependent checks
-- run in s', which is s with the old rule removed.
reduceWith :: Function f => Twee f -> Label -> Rule f -> Modelled (Critical (Rule f)) -> Maybe (Simplification f)
reduceWith s lab new old0@(Modelled model _ (Critical info old@(Rule _ l r)))
  -- The new (non-weak) rule rewrites the old lhs, and is not just a
  -- variant or generalisation of it: the old rule must be reoriented.
  | not (isWeak new) &&
    not (lhs new `isInstanceOf` l) &&
    not (null (anywhere (tryRule reduces new) l)) =
      Just (Reorient old0)
  -- An unoriented new rule matches inside the old lhs and the old rule is
  -- joinable in its model: try to ground join it.
  | not (isWeak new) &&
    not (lhs new `isInstanceOf` l) &&
    not (oriented (orientation new)) &&
    not (all isNothing [ match (lhs new) l' | l' <- subterms l ]) &&
    modelJoinable =
    tryGroundJoin
  -- The new rule rewrites the old rhs: simplify it under the same model.
  | not (null (anywhere (tryRule reduces new) (rhs old))) =
      Just (Simplify model old0)
  -- Both unoriented, and the new rule matches inside the old rhs, and the
  -- old rule is joinable in its model: try to ground join it.
  | not (oriented (orientation old)) &&
    not (oriented (orientation new)) &&
    not (lhs new `isInstanceOf` r) &&
    not (all isNothing [ match (lhs new) r' | r' <- subterms r ]) &&
    modelJoinable =
    tryGroundJoin
  | otherwise = Nothing
  where
    s' = s { labelledRules = Indexes.delete (Labelled lab old0) (labelledRules s) }
    -- Whether l and r join when normalised in the old rule's model.
    modelJoinable = isLeft (normaliseCP s' (Critical info (lm :=: rm)))
    lm = result (normaliseIn s' model l)
    rm = result (normaliseIn s' model r)
    -- If ground joining finds a model, keep the rule under that model.
    -- Otherwise it is ground joinable and can be reoriented/removed.
    tryGroundJoin =
      case groundJoin s' (branches (And [])) (Critical info (l :=: r)) of
        Left model' ->
          Just (Simplify model' old0)
        Right _ ->
          Just (Reorient old0)
    isWeak (Rule (WeaklyOriented _) _ _) = True
    isWeak _ = False

-- Replace the labelled rule r by a version whose rhs is normalised with
-- 'normalise' in the current state, stored under the given model. The
-- label, positions and critical-pair info are kept.
--
-- The cancellation index is updated to match. The entry derived from the
-- old rule is removed, and one derived from the simplified rule is added.
-- This matters because a later 'deleteRule' on the stored rule derives its
-- cancellation rule from the simplified rule, so the entries must agree.
simplifyRule :: Function f => Label -> Model f -> Modelled (Critical (Rule f)) -> State (Twee f) ()
simplifyRule l model r@(Modelled _ positions (Critical info (Rule _ lhs rhs))) = do
  s <- get
  let simplified = rule lhs (result (normalise s rhs))
  put s {
    labelledRules =
      Indexes.insert (Labelled l (Modelled model positions (Critical info simplified)))
        (Indexes.delete (Labelled l r) (labelledRules s)) }
  modify (deleteCancellationRule l (critical (modelled r)))
  modify (addCancellationRule l simplified)

-- Feed an input equation into completion with an unlimited weight budget
-- and no labels, then renormalise the goals.
newEquation :: Function f => Equation f -> State (Twee f) ()
newEquation (t :=: u) = do
  _ <- consider maxBound noLabel noLabel (Critical noCritInfo (t :=: u))
  renormaliseGoals

-- Placeholder critical-pair info for equations not produced by an overlap.
-- The top term is minimalTerm and the overlap size is 0.
noCritInfo :: Function f => CritInfo f
noCritInfo = CritInfo { top = minimalTerm, overlap = 0 }

--------------------------------------------------------------------------------
-- Cancellation rules.
--------------------------------------------------------------------------------

-- A rule usable for cancellation, built by 'toCancellationRule'.
-- cr_unified holds groups of terms. 'cancellations' unifies the first term
-- of each group with the others before rewriting with cr_rule's rhs.
data CancellationRule f =
  CancellationRule {
    cr_unified :: [[Term f]],
    cr_rule :: {-# UNPACK #-} !(Rule f) }
  deriving Show

instance (Numbered f, PrettyTerm f) => Pretty (CancellationRule f) where
  pPrint (CancellationRule tss rule) =
    pPrint rule <+> text "cancelling" <+> pPrint tss

-- The indexed term is that of the underlying rule. Substitution applies to
-- both the term groups and the rule.
instance Symbolic (CancellationRule f) where
  type ConstantOf (CancellationRule f) = f
  term (CancellationRule _ rule) = term rule
  termsDL (CancellationRule tss rule) =
    termsDL rule `mplus` termsDL tss
  replace sub (CancellationRule tss rule) =
    CancellationRule (replace sub tss) (replace sub rule)

-- Derive a cancellation rule from a rule, if it has one.
-- Permutative and weakly oriented rules never have one. Otherwise the lhs
-- must contain a variable that occurs more than once there and does not
-- occur in the rhs. With atomicCancellation set, the rhs must also be a
-- variable or a constant.
--
-- The rule's lhs l is replaced by l'. In l', every occurrence of such a
-- variable gets its own fresh variable, and (with
-- unifyConstantsInCancellation) every constant of l is replaced by a
-- variable. The groups in tss record which fresh variables must be unified
-- again, and which variable stands for which constant.
toCancellationRule :: Function f => Twee f -> Rule f -> Maybe (CancellationRule f)
toCancellationRule _ (Rule Permutative{} _ _) = Nothing
toCancellationRule _ (Rule WeaklyOriented{} _ _) = Nothing
toCancellationRule s (Rule or l r)
  | not (null vs) &&
    (not (atomicCancellation s) || atomic r) =
    Just (CancellationRule tss (Rule or' l' r))
  | otherwise = Nothing
  where
    consts = unifyConstantsInCancellation s
    atomic (Var _) = True
    atomic (Fun _ Empty) = True
    atomic _ = False

    -- Variables that occur on lhs more than once, but not rhs
    vs = usort (vars l \\ usort (vars l)) \\ usort (vars r)
    -- Constants of the lhs to replace by variables (empty unless consts).
    cs = usort [ c | consts, Fun c Empty <- subterms l ]

    -- Variable numbers from n upwards do not clash with those of l or r.
    n = bound l `max` bound r

    -- Constants become variables n .. n + length cs - 1. Fresh variables for
    -- repeated-variable occurrences are numbered by position, starting at
    -- n + length cs. Note that the counter argument shadows the outer n in
    -- the Var and Fun-with-children cases, while the constant case (counter
    -- named i) uses the outer n.
    l' = build (freshenVars (n + length cs) (singleton l))
    freshenVars !_ Empty = mempty
    freshenVars n (Cons (Var x) ts) =
      var y `mappend` freshenVars (n+1) ts
      where
        y = if x `elem` vs then MkVar n else x
    freshenVars i (Cons (Fun f Empty) ts) | f `elem` cs =
      var (MkVar m) `mappend` freshenVars (i+1) ts
      where
        m = n + fromMaybe __ (elemIndex f cs)
    freshenVars n (Cons (Fun f ts) us) =
      fun f (freshenVars (n+1) ts) `mappend`
      freshenVars (n+lenList ts+1) us

    -- One group per repeated variable (its fresh copies in l'), then one
    -- [constant, replacement variable] group per replaced constant.
    tss =
      map (map (build . var . snd)) (partitionBy fst pairs) ++
      zipWith (\i c -> [build (con c), build (var (MkVar i))]) [n..] cs
    -- (original variable, fresh variable) for each position where l has a
    -- repeated variable; l and l' are walked in parallel.
    pairs = concat (zipWith f (subterms l) (subterms l'))
      where
        f (Var x) (Var y)
          | x `elem` vs = [(x, y)]
        f _ _ = []

    -- Rename variables in the orientation via pairs. NOTE(review): the
    -- orientations that reach here (not Permutative / WeaklyOriented)
    -- presumably contain no variables, so f should never be called -
    -- confirm. `__` (errors.h) marks the supposedly impossible case.
    or' = subst (var . f) or
      where
        f x = fromMaybe __ (lookup x pairs)

-- Insert the cancellation rule derived from r (if any) under label l.
-- Rules whose equation is larger than maxCancellationSize are skipped. The
-- insertion is traced when moreTracing is set.
addCancellationRule :: Function f => Label -> Rule f -> Twee f -> Twee f
addCancellationRule l r@(Rule _ t u) s
  | Just limit <- maxCancellationSize s, size (t :=: u) > limit = s
  | Just c <- toCancellationRule s r =
      announce c (s { cancellationRules = Index.insert (Labelled l c) (cancellationRules s) })
  | otherwise = s
  where
    announce c
      | moreTracing s = Debug.Trace.traceShow (sep [text "Adding cancellation rule", nest 2 (pPrint c)])
      | otherwise     = id

-- Remove the cancellation rule derived from r (if any) under label l.
deleteCancellationRule :: Function f => Label -> Rule f -> Twee f -> Twee f
deleteCancellationRule l r s =
  maybe s remove (toCancellationRule s r)
  where
    remove c = s { cancellationRules = Index.delete (Labelled l c) (cancellationRules s) }

--------------------------------------------------------------------------------
-- Critical pairs.
--------------------------------------------------------------------------------

-- A value (equation or rule) together with information about the critical
-- overlap it came from.
data Critical a =
  Critical {
    critInfo :: CritInfo (ConstantOf a),
    critical :: a }

-- top: the overlapped term (criticalPairs1 sets it to the instantiated lhs).
-- It bounds connectedness rewriting in normaliseSub/setJoin.
-- overlap: size measure computed in criticalPairs1; toCP uses its
-- negation as the secondary priority.
data CritInfo f =
  CritInfo {
    top      :: Term f,
    overlap  :: Int }

-- Equality and ordering ignore the critical-pair information.
instance Eq a => Eq (Critical a) where x == y = critical x == critical y
instance Ord a => Ord (Critical a) where compare = comparing critical

instance (PrettyTerm (ConstantOf a), Pretty a) => Pretty (Critical a) where
  pPrint Critical{..} = pPrint critical

deriving instance (Show a, Show (ConstantOf a)) => Show (Critical a)
deriving instance Show f => Show (CritInfo f)

-- Substitution applies to both the value and its top term.
instance Symbolic a => Symbolic (Critical a) where
  type ConstantOf (Critical a) = ConstantOf a

  term = term . critical
  termsDL Critical{..} = termsDL (critical, critInfo)
  replace f Critical{..} = Critical (replace f critInfo) (replace f critical)

-- 'term' is not defined for CritInfo (`__` from errors.h).
instance Symbolic (CritInfo f) where
  type ConstantOf (CritInfo f) = f

  term = __
  termsDL = termsDL . top
  replace f CritInfo{..} = CritInfo (replace f top) overlap

-- Queue priority of a critical pair. The derived Ord compares fields in
-- declaration order, so field order matters. toCP fills them as:
-- cancelled weight, negated overlap size, partner label (l2), own label (l1).
data CPInfo =
  CPInfo {
    cpWeight  :: {-# UNPACK #-} !Int,
    cpWeight2 :: {-# UNPACK #-} !Int,
    cpAge1    :: {-# UNPACK #-} !Label,
    cpAge2    :: {-# UNPACK #-} !Label }
    deriving (Eq, Ord, Show)

-- A single queued critical pair. For pairs built by toCPs, l1 is the label
-- of the rule the pairs were generated from and l2 is the partner rule's.
data CP f =
  CP {
    info :: {-# UNPACK #-} !CPInfo,
    cp   :: {-# UNPACK #-} !(Critical (Equation f)),
    l1   :: {-# UNPACK #-} !Label,
    l2   :: {-# UNPACK #-} !Label }
  deriving Show

-- Equality and ordering use only the priority.
instance Eq (CP f) where x == y = info x == info y
instance Ord (CP f) where compare = comparing info
-- Both labels are reported; presumably the queue drops entries whose
-- labels are deleted (see deleteLabel in deleteRule) - confirm.
instance Labels (CP f) where labels x = [l1 x, l2 x]
instance (Numbered f, PrettyTerm f) => Pretty (CP f) where
  pPrint = pPrint . cp

-- A compressed family of critical pairs between rule `from` and the rules
-- whose labels lie in [lower, upper]. complete1 regenerates it by queueing
-- the partner ranges [lower, label-1] and [label+1, upper] again and the
-- pairs with partner `label` individually.
--   best:  priority of the best member (the family's queue priority);
--   label: partner label of the best member;
--   count: number of pairs in the family.
data CPs f =
  CPs {
    best  :: {-# UNPACK #-} !CPInfo,
    label :: {-# UNPACK #-} !Label,
    lower :: {-# UNPACK #-} !Label,
    upper :: {-# UNPACK #-} !Label,
    count :: {-# UNPACK #-} !Int,
    from  :: {-# UNPACK #-} !(Labelled (Rule f)) }
  deriving Show

instance Eq (CPs f) where x == y = best x == best y
instance Ord (CPs f) where compare = comparing best
-- Only the generating rule's label is reported.
instance Labels (CPs f) where labels (CPs _ _ _ _ _ (Labelled l _)) = [l]
instance (Numbered f, PrettyTerm f) => Pretty (CPs f) where
  pPrint CPs{..} = text "Family of size" <+> pPrint count <+> text "from" <+> pPrint from

-- An entry in the passive queue: one critical pair or a family of them.
-- NOTE(review): the derived Eq distinguishes the constructors, while Ord
-- compares only the CPInfo priority. So compare can return EQ for values
-- that are not (==) - confirm the queue does not rely on these agreeing.
data Passive f =
    SingleCP {-# UNPACK #-} !(CP f)
  | ManyCPs  {-# UNPACK #-} !(CPs f)
  deriving (Eq, Show)

instance Ord (Passive f) where
  compare = comparing f
    where
      f (SingleCP x) = info x
      f (ManyCPs  x) = best x
instance Labels (Passive f) where
  labels (SingleCP x) = labels x
  labels (ManyCPs x) = labels x
instance (Numbered f, PrettyTerm f) => Pretty (Passive f) where
  pPrint (SingleCP cp) = pPrint cp
  pPrint (ManyCPs cps) = pPrint cps

-- Number of individual critical pairs a queue entry represents.
passiveCount :: Passive f -> Int
passiveCount p =
  case p of
    SingleCP{}  -> 1
    ManyCPs cps -> count cps

-- An identified critical pair with a validity flag.
-- NOTE(review): not referenced in the visible part of this module; check
-- whether it is still used elsewhere.
data InitialCP f =
  InitialCP {
    cpId :: (Term f, Label),
    cpOK :: Bool,
    cpCP :: Labelled (Critical (Equation f)) }

-- All critical pairs between a rule and the labelled rules whose labels
-- lie in [lower, upper], in both directions:
--   outward: the rule's lhs overlapped with each partner;
--   inward:  each partner's lhs (at its stored positions) overlapped with
--            the rule.
-- Every result is labelled with the partner's label.
criticalPairs :: Function f => Twee f -> Label -> Label -> Rule f -> [Labelled (Critical (Equation f))]
criticalPairs s lower upper rule = outward ++ inward
  where
    inRange =
      [ r | r <- Indexes.elems (labelledRules s), let l = labelOf r, lower <= l, l <= upper ]
    outward =
      criticalPairs1 s (ruleOverlaps s (lhs rule)) rule
        [ fmap (critical . modelled) r | r <- inRange ]
    inward =
      concat [ criticalPairs1 s ns old [Labelled l' rule]
             | Labelled l' (Modelled _ ns (Critical _ old)) <- inRange ]

-- Positions (in preorder, ascending) of the non-variable subterms of t at
-- which overlaps are tried. With useGeneralSuperpositions, a subterm equal
-- to one already seen earlier is skipped.
ruleOverlaps :: Twee f -> Term f -> [Int]
ruleOverlaps s t = go 0 Set.empty (singleton t)
  where
    go !_ !_ Empty = []
    go pos seen (Cons (Var _) rest) = go (pos+1) seen rest
    go pos seen (ConsSym subterm@Fun{} rest)
      | repeated  = go (pos+1) seen rest
      | otherwise = pos : go (pos+1) (Set.insert subterm seen) rest
      where
        repeated = useGeneralSuperpositions s && subterm `Set.member` seen

-- Unify t2 with the subterms of t1 at the given positions. ns must be in
-- ascending order (as produced by ruleOverlaps) and name non-variable
-- positions. Returns each unifier with its position. If t2 is a variable
-- there are no overlaps.
overlaps :: [Int] -> Term f -> Term f -> [(Subst f, Int)]
overlaps ns t1 t2@(Fun g _) = go 0 ns (singleton t1) []
  where
    -- Never-true clause; its bang patterns presumably serve to make go
    -- strict in the counter and the term list.
    go !_ _ !_ _ | False = __
    go _ [] _ rest = rest
    go _ _ Empty rest = rest
    -- Walk the flattened term. The lazy pattern means the subterm is only
    -- inspected at requested positions (which are expected to be Fun).
    go n (m:ms) (ConsSym ~t@(Fun f _) u) rest
      | m == n && f == g = here ++ go (n+1) ms u rest
      | m == n = go (n+1) ms u rest
      | otherwise = go (n+1) (m:ms) u rest
      where
        here =
          case unify t t2 of
            Nothing -> []
            Just sub -> [(sub, n)]
overlaps _ _ _ = []

-- Rebuild a term list, replacing the subterm at (preorder) position n with
-- t. Everything else is copied unchanged.
emitReplacement :: Int -> Term f -> TermList f -> Builder f
emitReplacement n t = aux n
  where
    -- Never-true clause; presumably forces strictness in both arguments.
    aux !_ !_ | False = __
    aux _ Empty = mempty
    -- Reached the target position: emit t and copy the remainder.
    aux 0 (Cons _ u) = builder t `mappend` builder u
    aux n (Cons (Var x) u) = var x `mappend` aux (n-1) u
    aux n (Cons t@(Fun f ts) u)
      -- Target lies inside this subterm: descend into its arguments.
      | n < len t =
        fun f (aux (n-1) ts) `mappend` builder u
      -- Otherwise skip the whole subterm.
      | otherwise =
        builder t `mappend` aux (n-len t) u

-- Critical pairs from overlapping rule r's lhs (at positions ns) with the
-- lhs of each rule in rs. Each pair is labelled with the rs rule's label.
criticalPairs1 :: Function f => Twee f -> [Int] -> Rule f -> [Labelled (Rule f)] -> [Labelled (Critical (Equation f))]
criticalPairs1 s ns r rs = do
  -- Shift r's variables above every variable of rs so they are disjoint.
  let b = maximum (0:[ bound t | Labelled _ (Rule _ t _) <- rs ])
      Rule or t u = subst (\(MkVar x) -> var (MkVar (x+b))) r
  Labelled l (Rule or' t' u') <- rs
  (sub, pos) <- overlaps ns t t'
  -- top is the overlapped term. left rewrites it with r at the root;
  -- right rewrites it with the partner rule at position pos.
  let left = subst sub u
      right = subst sub (build (emitReplacement pos u' (singleton t)))
      top = subst sub t
      overlap = at pos (singleton t)

      inner = subst sub overlap
      osz = size overlap + (size u - size t) + (size u' - size t')

  -- Drop pairs where a step changes nothing or both sides coincide.
  guard (left /= top && right /= top && left /= right)
  -- For non-oriented rules, drop pairs where top is lessEq than one side.
  -- NOTE(review): r's orientation (or) is checked against `right` and the
  -- partner's (or') against `left`, although r produces left - confirm
  -- this pairing is intended.
  when (or  /= Oriented) $ guard (not (lessEq top right))
  when (or' /= Oriented) $ guard (not (lessEq top left))
  -- Optionally skip overlaps whose instantiated inner term can be
  -- simplified by the easy rules below its root.
  when (skipCompositeSuperpositions s) $
    guard (null (nested (anywhere (rewrite "prime" simplifies (easyRules s))) inner))
  return (Labelled l (Critical (CritInfo top osz) (left :=: right)))

-- Normalise a single critical pair with toCP and pass it to enq, unless
-- it became trivial.
queueCP ::
  Function f =>
  (Passive f -> State (Twee f) ()) ->
  (Equation f -> Bool) -> Label -> Label -> Critical (Equation f) -> State (Twee f) ()
queueCP enq joinable l1 l2 eq = do
  st <- get
  forM_ (toCP st l1 l2 joinable eq) (enq . SingleCP)

-- Generate the critical pairs between a rule and the rules labelled
-- lower..upper, group them by f applied to the partner label, and pass
-- each group to enq. Groups of at most minimumCPSetSize pairs are queued
-- as individual SingleCPs; larger groups become one ManyCPs family.
--
-- toCPs gives every pair l1 = the rule's own label and l2 = the partner's
-- label. So a family's [lower, upper] range must span the members' l2
-- labels: complete1 regenerates the family by splitting exactly that range
-- around the best member's partner label.
queueCPs ::
  (Function f, Ord a) =>
  (Passive f -> State (Twee f) ()) ->
  Label -> Label -> (Label -> a) -> Labelled (Rule f) -> State (Twee f) ()
queueCPs enq lower upper f rule = do
  s <- get
  let cps = toCPs s lower upper rule
      cpss = partitionBy (f . l2) cps
  forM_ cpss $ \xs ->
    if length xs <= minimumCPSetSize s then
      mapM_ (enq . SingleCP) xs
    else
      let best = minimum xs
          partners = map l2 xs
          lo = minimum partners
          hi = maximum partners in
      enq (ManyCPs (CPs (info best) (l2 best) lo hi (length xs) rule))

-- queueCPs with the partner labels l..u bucketed linearly into about
-- cpSplits groups. The division is only evaluated if a pair is produced,
-- which cannot happen when u < l.
queueCPsSplit ::
  Function f =>
  (Passive f -> State (Twee f) ()) ->
  Label -> Label -> Labelled (Rule f) -> State (Twee f) ()
queueCPsSplit enq l u rule = do
  splits <- gets cpSplits
  let bucket x = fromIntegral splits * (x - l) `div` (u - l + 1)
  queueCPs enq l u bucket rule

-- The critical pairs of a labelled rule against the rules labelled
-- lower..upper, as queue entries. l1 is the rule's own label, l2 the
-- partner's. Pairs that become trivial are dropped.
toCPs ::
  Function f =>
  Twee f -> Label -> Label -> Labelled (Rule f) -> [CP f]
toCPs s lower upper (Labelled l rule) =
  mapMaybe fromPair (criticalPairs s lower upper rule)
  where
    fromPair (Labelled partner eqn) = toCP s l partner trivial eqn

-- Build a queue entry from a critical pair. Returns Nothing if the sides
-- are equal, or become equal after normaliseQuickly. The priority is the
-- cancelled weight of the normalised pair, then the negated overlap size,
-- then the labels.
toCP ::
  Function f =>
  Twee f -> Label -> Label -> (Equation f -> Bool) -> Critical (Equation f) -> Maybe (CP f)
toCP s l1 l2 joinable (Critical info (t :=: u))
  | t == u || t' == u' = Nothing
  | otherwise = Just (CP (CPInfo w (negate (overlap info)) l2 l1) (Critical info (t' :=: u')) l1 l2)
  where
    t' = result (normaliseQuickly s t)
    u' = result (normaliseQuickly s u)
    w = cancelledWeight s joinable (t' :=: u')

-- The weight part of 'bestCancellation'.
cancelledWeight :: Function f => Twee f -> (Equation f -> Bool) -> Equation f -> Int
cancelledWeight s joinable = fst . bestCancellation s joinable

-- Pick the best equation in the cancellation chain of eq. Each candidate's
-- weight is increased by its index in the chain. The result is the lowest
-- such weight and the first equation that achieves it. With cancellation
-- disabled, this is just eq and its weight.
bestCancellation :: Function f => Twee f -> (Equation f -> Bool) -> Equation f -> (Int, Equation f)
bestCancellation s joinable eq
  | not (useCancellation s) = (weight s eq, eq)
  | otherwise = minimumBy (comparing fst) (zip penalised candidates)
  where
    candidates = cancellations s joinable eq
    penalised  = zipWith (+) [0 ..] (map (weight s) candidates)

-- Heuristic weight of an equation, used to prioritise critical pairs.
-- 'weight' first orients the equation with 'order'. 'weight'' then scales
-- the size measure of each side by lhsWeight and rhsWeight.
weight, weight' :: Function f => Twee f -> Equation f -> Int
weight s eq = weight' s (order eq)

weight' s (t :=: u) =
  lhsWeight s*size' t + rhsWeight s*size' u
  where
    -- 4*(size + len), minus the number of variable occurrences and minus
    -- the number of distinct variables. The distinct variables are counted
    -- with a Set (same count as 'nub', without its quadratic cost).
    size' x = 4*(size x + len x) - length xs - Set.size (Set.fromList xs)
      where
        xs = vars x

-- The cancellation chain of an equation. It starts with the equation
-- itself; each following element is the smallest strictly smaller,
-- non-joinable equation obtained by applying one cancellation rule to
-- either side of the previous one. The chain ends when there is no such
-- candidate.
cancellations :: Function f => Twee f -> (Equation f -> Bool) -> Equation f -> [Equation f]
cancellations s joinable (t :=: u) =
  t :=: u:
  case cands of
    [] -> []
    _  -> cancellations s joinable (minimumBy (comparing size) cands)
  where
    -- Cancel inside one side, apply the unifier to the other side and
    -- renormalise it. Keep results that shrink the equation and that the
    -- joinable test does not already accept.
    cands =
      filter (\eq -> size eq < size (t :=: u)) $
      [ t' :=: u' | (sub, t') <- cancel t, let u' = result (normaliseQuickly s (subst sub u)), not (joinable (t' :=: u')) ] ++
      [ t' :=: u' | (sub, u') <- cancel u, let t' = result (normaliseQuickly s (subst sub t)), not (joinable (t' :=: u')) ]
    -- For every subterm (at preorder position i) matched by a cancellation
    -- rule, unify the rule's term groups. If that succeeds, replace the
    -- subterm by the rule's rhs, apply the unifier, and normalise.
    -- Presumably Index.lookup returns the matching entries already
    -- instantiated to the subterm - confirm.
    cancel t = do
      (i, u) <- zip [0..] (subterms t)
      Labelled _ (CancellationRule tss (Rule _ _ u')) <-
        Index.lookup u (Index.freeze (cancellationRules s))
      sub <- maybeToList (unifyMany [(t, u) | t:ts <- tss, u <- ts])
      let t' = result (normaliseQuickly s (subst sub (build (emitReplacement i u' (singleton t)))))
      return (sub, t')

    -- Simultaneously unify all the given pairs of terms.
    unifyMany ps =
      unifyList (buildList (map fst ps)) (buildList (map snd ps))

--------------------------------------------------------------------------------
-- Tracing.
--------------------------------------------------------------------------------

-- Events reported through 'trace'/'traceM'. Each event is printed only
-- when the tracing flag that 'trace' associates with it is set.
data Event f =
    NewRule (Rule f)
  | ExtraRule (Rule f)
  | NewCP (Passive f)
  | Reduce (Simplification f) (Rule f)
  | Consider (Critical (Equation f))
  | Joined (Critical (Equation f)) JoinReason
  | Delay (Critical (Equation f))
  | Cancel (Critical (Equation f)) (Equation f)
  | Discharge (Critical (Equation f)) (Model f)
  | NormaliseCPs (Twee f)

-- Print an event (via traceIf) if the flag for its kind is set, then
-- return the third argument. Rule-level events use 'tracing'; per-pair
-- events use 'moreTracing'.
trace :: Function f => Twee f -> Event f -> a -> a
trace Twee{..} event = traceIf enabled message
  where
    (enabled, message) =
      case event of
        NewRule rule ->
          (tracing, hang (text "New rule") 2 (pPrint rule))
        ExtraRule rule ->
          (tracing, hang (text "Extra rule") 2 (pPrint rule))
        NewCP cp ->
          (moreTracing, hang (text "Critical pair") 2 (pPrint cp))
        Reduce red rule ->
          (tracing, sep [pPrint red, nest 2 (text "using"), nest 2 (pPrint rule)])
        Consider eq ->
          (moreTracing, sep [text "Considering", nest 2 (pPrint eq), text "under", nest 2 (pPrint (top (critInfo eq)))])
        Joined eq reason ->
          (moreTracing, sep [text "Joined", nest 2 (pPrint eq), text "under", nest 2 (pPrint (top (critInfo eq))), text "by", nest 2 (pPrint reason)])
        Delay eq ->
          (moreTracing, sep [text "Delaying", nest 2 (pPrint eq)])
        Cancel eq eq' ->
          (tracing, sep [text "Cancelled", nest 2 (pPrint eq), text "into", nest 2 (pPrint eq')])
        Discharge eq fs ->
          (tracing, sep [text "Discharge", nest 2 (pPrint eq), text "under", nest 2 (pPrint fs)])
        NormaliseCPs st ->
          (tracing, text "" $$ text "Normalising unprocessed critical pairs." $$ text (report st) $$ text "")

-- 'trace' an event using the current state's tracing flags.
traceM :: Function f => Event f -> State (Twee f) ()
traceM event = get >>= \st -> trace st event (return ())

-- When enabled, print the rendered document with Debug.Trace.trace before
-- returning the value. Otherwise return the value untouched (the document
-- is never evaluated).
traceIf :: Bool -> Doc -> a -> a
traceIf enabled doc
  | enabled   = Debug.Trace.trace (show doc)
  | otherwise = id