{-# LANGUAGE TypeFamilies, StandaloneDeriving, FlexibleContexts, UndecidableInstances, RecordWildCards, PatternGuards, CPP, BangPatterns #-}
module Twee.Rule where

#include "errors.h"
import Twee.Base
import Twee.Constraints
import qualified Twee.Index as Index
import Twee.Index(Frozen)
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.State.Strict
import Data.Maybe
import Data.List
import Twee.Utils
import qualified Data.Set as Set
import Data.Set(Set)
import qualified Twee.Term as Term

--------------------------------------------------------------------------------
-- Rewrite rules.
--------------------------------------------------------------------------------

-- | A rewrite rule from 'lhs' to 'rhs', tagged with its 'orientation'.
--
-- The derived 'Eq' and 'Ord' compare 'orientation' first, but every two
-- 'Orientation' values compare equal. In effect, rules are compared by
-- 'lhs' and then by 'rhs'.
data Rule f = Rule
  { orientation :: Orientation f
    -- ^ How the rule may be used (see 'rule').
  , lhs :: Term f
    -- ^ Left-hand side.
  , rhs :: Term f
    -- ^ Right-hand side.
  }
  deriving (Eq, Ord, Show)

-- | How a rule may be used for rewriting; 'rule' decides which applies.
data Orientation f
  = Oriented
    -- ^ Chosen when the right-hand side is lessEq the left-hand side and
    -- the two sides do not unify.
  | WeaklyOriented [Term f]
    -- ^ As above, except the sides unify by mapping each listed variable
    -- to a minimal term. The rule only counts as reducing when one of
    -- these variables is instantiated to a non-minimal term.
  | Permutative [(Term f, Term f)]
    -- ^ The sides differ only in their variables; the list holds the
    -- variable pairs where they differ.
  | Unoriented
    -- ^ None of the above.
  deriving Show

-- Orientations never influence comparisons: any two are equal and
-- compare as EQ.
instance Eq (Orientation f) where
  (==) _ _ = True

instance Ord (Orientation f) where
  compare _ _ = EQ

-- | True for 'Oriented' and 'WeaklyOriented', False for the other
-- orientations.
oriented :: Orientation f -> Bool
oriented o =
  case o of
    Oriented -> True
    WeaklyOriented{} -> True
    _ -> False

instance Symbolic (Rule f) where
  type ConstantOf (Rule f) = f
  -- The representative term of a rule is its left-hand side.
  term = lhs
  -- Terms of the left side, then the right side, then the orientation.
  termsDL rl@Rule{} = termsDL (lhs rl, (rhs rl, orientation rl))
  -- Apply the replacement to every component of the rule.
  replace f (Rule o l r) = Rule (replace f o) (replace f l) (replace f r)

instance Symbolic (Orientation f) where
  type ConstantOf (Orientation f) = f
  -- An orientation has no single representative term; asking for one is
  -- an internal error.
  term = __
  -- Only the weak and permutative forms carry terms.
  termsDL o =
    case o of
      WeaklyOriented ts -> termsDL ts
      Permutative ps -> termsDL ps
      Oriented -> mempty
      Unoriented -> mempty
  replace f o =
    case o of
      WeaklyOriented ts -> WeaklyOriented (replace f ts)
      Permutative ps -> Permutative (replace f ps)
      Oriented -> Oriented
      Unoriented -> Unoriented

-- Print a rule as l -> r, followed by a note on its orientation unless it
-- is plainly 'Oriented'.
instance (Numbered f, PrettyTerm f) => Pretty (Rule f) where
  pPrint (Rule o l r) =
    case o of
      Oriented -> plain
      WeaklyOriented ts -> note (text "(weak on" <+> pPrint ts <> text ")")
      Permutative ts -> note (text "(permutative on" <+> pPrint ts <> text ")")
      Unoriented -> note (text "(unoriented)")
    where
      plain = pPrintRule l r
      note = hang plain 2

-- | Render @l -> r@, hanging the right-hand side with indent 2.
pPrintRule :: (Numbered f, PrettyTerm f) => Term f -> Term f -> Doc
pPrintRule l r = hang arrow 2 (pPrint r)
  where
    arrow = pPrint l <+> text "->"

--------------------------------------------------------------------------------
-- Equations.
--------------------------------------------------------------------------------

-- | An equation between two terms.
data Equation f = Term f :=: Term f
  deriving (Eq, Ord, Show)

-- | The equation type over the constants of a symbolic value.
type EquationOf a = Equation (ConstantOf a)

instance Symbolic (Equation f) where
  type ConstantOf (Equation f) = f
  -- An equation has no single representative term; asking for one is an
  -- internal error.
  term = __
  -- Terms of the left side, then the right side.
  termsDL (x :=: y) = termsDL (x, y)
  replace f (x :=: y) = (:=:) (replace f x) (replace f y)

-- Print an equation as x = y, hanging the right side with indent 2.
instance (Numbered f, PrettyTerm f) => Pretty (Equation f) where
  pPrint (x :=: y) = hang eqDoc 2 (pPrint y)
    where
      eqDoc = pPrint x <+> text "="

-- The size of an equation is the sum of the sizes of its two sides.
instance (Numbered f, Sized f) => Sized (Equation f) where
  size eq =
    case eq of
      x :=: y -> size x + size y

-- | Put an equation in a canonical direction.
--
-- The side with the larger 'size' goes on the left. When the sizes tie,
-- the sides are swapped if the left side is lessEq the right side. An
-- equation whose sides are equal is returned unchanged.
order :: Function f => Equation f -> Equation f
order eq@(l :=: r)
  | l /= r, swap = r :=: l
  | otherwise = eq
  where
    swap =
      case compare (size l) (size r) of
        LT -> True
        GT -> False
        EQ -> lessEq l r

-- | Forget a rule's orientation, giving the equation @lhs :=: rhs@.
unorient :: Rule f -> Equation f
unorient Rule{lhs = l, rhs = r} = l :=: r

-- | Turn an equation into rewrite rules. An equation whose two sides are
-- equal gives no rules.
--
-- Variables that occur on only one side are erased (replaced by
-- minimalTerm), giving l' and r'. Which rules are produced depends on
-- how orientTerms compares l' and r'; see the comment on each list below.
orient :: Function f => Equation f -> [Rule f]
orient (l :=: r) | l == r = []
orient (l :=: r) =
  -- If we have an equation where some variables appear only on one side, e.g.:
  --   f x y = g x z
  -- then replace it with the equations:
  --   f x y = f x k
  --   g x z = g x k
  --   f x k = g x k
  -- where k is an arbitrary constant.
  -- In the code below, k is minimalTerm. Instead of a rule between the
  -- two erased sides, the rules go from an original side to the other,
  -- erased side: l -> r' and/or r -> l'.
  --
  -- l -> r': unless orientTerms l' r' is Just LT or Just EQ.
  [ rule l r' | ord /= Just LT && ord /= Just EQ ] ++
  -- r -> l': unless orientTerms l' r' is Just GT or Just EQ.
  [ rule r l' | ord /= Just GT && ord /= Just EQ ] ++
  -- l -> l': only when l has one-sided variables, and not when
  -- orientTerms l' r' is Just GT.
  [ rule l l' | not (null ls), ord /= Just GT ] ++
  -- r -> r': only when r has one-sided variables, and not when
  -- orientTerms l' r' is Just LT.
  [ rule r r' | not (null rs), ord /= Just LT ]
  where
    -- Comparison of the two erased sides.
    ord = orientTerms l' r'
    -- Each side with its one-sided variables replaced by minimalTerm.
    l' = erase ls l
    r' = erase rs r
    -- Variables occurring only in l, and only in r, respectively.
    ls = usort (vars l) \\ usort (vars r)
    rs = usort (vars r) \\ usort (vars l)

    -- Replace each variable in xs by minimalTerm; identity when xs is empty.
    erase [] t = t
    erase xs t = subst sub t
      where
        -- The variables in xs come from usort (presumably duplicate-free),
        -- so flattenSubst is not expected to fail; failure is treated as
        -- an internal error.
        sub = fromMaybe __ $ flattenSubst [(x, minimalTerm) | x <- xs]

-- | Make the rule t -> u, choosing its 'Orientation':
--
-- * If u is lessEq t:
--
--   * 'Oriented' when t and u do not unify.
--   * 'WeaklyOriented' when they unify and the unifier maps every variable
--     to a minimal term. The rule records those variables.
--   * 'Unoriented' otherwise.
--
-- * If t is lessEq u: internal error (the rule is the wrong way round).
-- * If u has variables that do not occur in t: internal error.
-- * If makePermutative finds that t and u have the same shape and differ
--   only in their variables, and permutativeOK accepts the resulting
--   variable pairs: 'Permutative'.
-- * Otherwise: 'Unoriented'.
rule :: Function f => Term f -> Term f -> Rule f
rule t u = Rule o t u
  where
    o | lessEq u t =
        case unify t u of
          Nothing -> Oriented
          Just sub
            | allSubst (\_ (Cons t Empty) -> isMinimal t) sub ->
              WeaklyOriented (map (build . var . fst) (listSubst sub))
            | otherwise -> Unoriented
      | lessEq t u = ERROR("wrongly-oriented rule")
      | not (null (usort (vars u) \\ usort (vars t))) =
        ERROR("unbound variables in rule")
      | Just ts <- evalStateT (makePermutative t u) [],
        permutativeOK t u ts =
        Permutative ts
      | otherwise = Unoriented

    -- permutativeOK t u pairs: for each variable pair (x, y) in turn,
    -- u must be strictly below t (lessIn returns Just Strict) in the model
    -- built by modelFromOrder from [y, x]. Then x is renamed to y in both
    -- terms before the remaining pairs are checked. Only variable pairs
    -- are handled, which is all that makePermutative produces.
    permutativeOK _ _ [] = True
    permutativeOK t u ((Var x, Var y):xs) =
      lessIn model u t == Just Strict &&
      permutativeOK t' u' xs
      where
        model = modelFromOrder [Variable y, Variable x]
        -- Rename x to y, leaving other variables alone.
        sub x' = if x == x' then var y else var x'
        t' = subst sub t
        u' = subst sub u

    -- makePermutative t u walks t and u in parallel, in StateT over Maybe.
    -- The state is a list of variable bindings collected so far. Both terms
    -- are first instantiated with those bindings; this fails if
    -- flattenSubst returns Nothing. Then:
    --
    -- * Matching function symbols recurse into the arguments.
    -- * Equal variables contribute nothing.
    -- * Two different variables x and y add the binding x := y and
    --   contribute the pair (x, y).
    -- * Anything else fails (mzero).
    makePermutative t u = do
      msub <- gets flattenSubst
      sub  <- lift msub
      aux (subst sub t) (subst sub u)
        where
          aux (Var x) (Var y)
            | x == y = return []
            | otherwise = do
              modify ((x, build $ var y):)
              return [(build $ var x, build $ var y)]

          aux (Fun f ts) (Fun g us)
            | f == g =
              fmap concat (zipWithM makePermutative (fromTermList ts) (fromTermList us))

          aux _ _ = mzero

-- | Apply a function to both sides of an equation.
bothSides :: (Term f -> Term f') -> Equation f -> Equation f'
bothSides f eq =
  case eq of
    a :=: b -> f a :=: f b

-- | True when both sides of the equation are equal.
trivial :: Eq f => Equation f -> Bool
trivial eq =
  case eq of
    a :=: b -> a == b

--------------------------------------------------------------------------------
-- Rewriting.
--------------------------------------------------------------------------------

-- | A strategy lists the ways a term can be rewritten. An empty list
-- means no rewrite applies.
type Strategy f = Term f -> [Reduction f]

-- | A record of how a term was rewritten.
data Reduction f
  = Step (Rule f) (Subst f)
    -- ^ One application of a rule under a substitution.
  | Trans (Reduction f) (Reduction f)
    -- ^ One reduction followed by another.
  | Parallel [(Int, Reduction f)] (Term f)
    -- ^ Rewrite the given term at several positions at once. Each
    -- position is paired with a reduction of the subterm found there.
  deriving Show

-- | The term obtained by carrying out a reduction.
--
-- * A 'Trans' gives the result of its second half.
-- * A 'Step' gives the rule's right-hand side instantiated by the
--   substitution.
-- * A 'Parallel' rebuilds its term, replacing the subterm at each listed
--   position by the result of the reduction paired with it.
result :: Reduction f -> Term f
-- Fast path: a parallel step with no positions leaves the term unchanged.
result (Parallel [] t) = t
result (Trans _ p) = result p
result t = build (emitReduction t)
  where
    -- Builder for the result of a reduction.
    emitReduction (Step r sub) = Term.subst sub (rhs r)
    emitReduction (Trans _ p) = emitReduction p
    emitReduction (Parallel ps t) = emitParallel 0 ps (singleton t)

    -- emitParallel n ps ts emits the term list ts, whose first symbol is
    -- at position n. Positions count symbols in preorder: a function
    -- symbol at position n has its first argument at position n+1.
    -- Each (m, p) in ps replaces the subterm at position m by the result
    -- of p. The list is consumed front to back, so it presumably must be
    -- in increasing order of position (TODO confirm with callers).
    --
    -- This first clause is never taken; together with the bang it only
    -- makes the position argument strict.
    emitParallel !_ _ _ | False = __
    -- Nothing left to emit.
    emitParallel _ _ Empty = mempty
    -- No positions left: copy the rest verbatim.
    emitParallel _ [] t = builder t
    -- The next position lies beyond this whole list: copy it verbatim.
    emitParallel n ((m, _):_) t  | m >= n + lenList t = builder t
    -- The next position lies beyond the first term: copy that term and go on.
    emitParallel n ps@((m, _):_) (Cons t u) | m >= n + len t =
      builder t `mappend` emitParallel (n + len t) ps u
    -- The position is already behind us (it was handled inside an earlier
    -- subterm): drop it.
    emitParallel n ((m, _):ps) t | m < n = emitParallel n ps t
    -- The position is exactly the first term: emit the reduction result
    -- in its place.
    emitParallel n ((m, p):ps) (Cons t u) | m == n =
      emitReduction p `mappend` emitParallel (n + len t) ps u
    -- Otherwise the position is inside the first term.
    -- A variable is copied.
    emitParallel n ps (Cons (Var x) u) =
      var x `mappend` emitParallel (n + 1) ps u
    -- A function application is rebuilt by descending into its arguments.
    emitParallel n ps (Cons (Fun f t) u) =
      fun f (emitParallel (n+1) ps t) `mappend`
      emitParallel (n + 1 + lenList t) ps u

instance (Numbered f, PrettyTerm f) => Pretty (Reduction f) where
  -- Delegates to pPrintReduction.
  pPrint r = pPrintReduction r

-- | Render a reduction.
--
-- A chain of 'Trans' is split into its parts, and each part is shown
-- together with the term it produces. A single part is printed on its
-- own; several parts are printed as a list.
pPrintReduction :: (Numbered f, PrettyTerm f) => Reduction f -> Doc
pPrintReduction red =
  case linearise red of
    [single] -> ppStep single
    many -> pPrint (map ppStep many)
  where
    -- All non-Trans pieces of a reduction, left to right.
    linearise (Trans a b) = linearise a ++ linearise b
    linearise other = [other]

    -- A piece, followed by the term it produces.
    ppStep r = sep [describe r, nest 2 (text "giving" <+> pPrint (result r))]

    describe (Step rl sub) =
      sep [pPrint rl,
           nest 2 (text "at" <+> pPrint sub)]
    describe (Parallel [] _) = text "refl"
    describe (Parallel [(0, r)] _) = describe r
    describe (Parallel ps _) =
      sep (punctuate (text " and")
        [hang (pPrint n <+> text "->") 2 (pPrint r) | (n, r) <- ps])

-- | Every rule application in a reduction, with its substitution, in
-- order. 'Trans' lists its left part first; 'Parallel' follows its
-- position list.
steps :: Reduction f -> [(Rule f, Subst f)]
steps red = go red []
  where
    -- go r acc puts the steps of r in front of acc.
    go (Step rl s) acc = (rl, s) : acc
    go (Trans a b) acc = go a (go b acc)
    go (Parallel ps _) acc = foldr (go . snd) acc ps

-- | Extend a reduction p by one more parallel rewriting step.
--
-- The term produced by p is scanned left to right:
--
-- * At each non-variable subterm where the strategy gives at least one
--   reduction, the first one is taken and the scan skips the whole
--   subterm, so the chosen positions do not overlap.
-- * Otherwise the scan moves on into the arguments.
--
-- Returns Nothing if no subterm was rewritten. Otherwise returns p
-- followed ('Trans') by a 'Parallel' step over all chosen positions, in
-- increasing order.
anywhere1 :: (Numbered f, PrettyTerm f) => Strategy f -> Reduction f -> Maybe (Reduction f)
anywhere1 strat p = aux [] 0 (singleton t) p t
  where
    -- aux ps n ts p t:
    --
    -- * ps holds the chosen (position, reduction) pairs, in reverse order.
    -- * n is the position of the first symbol of ts.
    -- * p and t are the reduction being extended and its result, passed
    --   along unchanged.
    --
    -- This first clause is never taken; together with the bangs it only
    -- makes those arguments strict.
    aux _ !_ !_ _ !_ | False = __
    -- Scan finished and nothing was rewritten.
    aux [] _ Empty _ _ = Nothing
    -- Scan finished: append the parallel step, with positions ascending.
    aux ps _ Empty p t = Just (p `Trans` Parallel (reverse ps) t)
    -- Variable subterms are skipped.
    aux ps n (Cons (Var _) t) p u = aux ps (n+1) t p u
    -- The strategy rewrites this subterm: take its first reduction and
    -- skip over the whole subterm.
    aux ps n (Cons t u) p v | q:_ <- strat t =
      aux ((n, q):ps) (n+len t) u p v
    -- Otherwise step past the function symbol into its arguments.
    aux ps n (ConsSym (Fun _ _) t) p u =
      aux ps (n+1) t p u

    t = result p

-- | Rewrite t with 'anywhere1' until no step applies, and return the
-- accumulated reduction, starting from the empty reduction of t.
-- After 1000 steps it gives up with an error that shows the reduction
-- built so far.
normaliseWith :: (Numbered f, PrettyTerm f) => Strategy f -> Term f -> Reduction f
normaliseWith strat t = loop 0 (Parallel [] t)
  where
    loop n red
      | n == 1000 =
        ERROR("Possibly nonterminating rewrite:\n" ++
              prettyShow red)
      | otherwise = maybe red (loop (n + 1)) (anywhere1 strat red)

-- | All normal forms reachable from the given terms.
--
-- The search follows single 'anywhere' steps depth-first. A term with no
-- successors is a normal form; a term with successors is marked as seen
-- and its successors are explored. Terms already seen, or already known
-- to be normal, are skipped. Nothing here bounds the search.
normalForms :: Function f => Strategy f -> [Term f] -> Set (Term f)
normalForms strat ts0 = explore Set.empty Set.empty ts0
  where
    explore _ found [] = found
    explore seen found (t:rest)
      | t `Set.member` seen || t `Set.member` found = explore seen found rest
      | otherwise =
          case map result (anywhere strat t) of
            [] -> explore seen (Set.insert t found) rest
            succs -> explore (Set.insert t seen) found (succs ++ rest)

-- | Apply a strategy at every non-variable subterm, in preorder (a term
-- before its arguments, left to right). Each result is a 'Parallel' step
-- on the whole term that rewrites at exactly one position.
anywhere :: Strategy f -> Strategy f
anywhere strat t = go 0 (singleton t)
  where
    go !_ Empty = []
    go pos (Cons Var{} rest) = go (pos + 1) rest
    go pos (ConsSym sub rest) =
      map (\red -> Parallel [(pos, red)] t) (strat sub) ++ go (pos + 1) rest

-- | Apply a strategy to the immediate arguments of a term, but not to the
-- term itself. Each result is a 'Parallel' step on the whole term that
-- rewrites one argument.
--
-- Positions are absolute within the term, as 'result' expects:
--
-- * The root symbol is at position 0, so the first argument starts at
--   position 1.
-- * Each later argument starts right after the previous one ends.
-- * Variable arguments are skipped.
nested :: Strategy f -> Strategy f
nested strat t = aux 1 (children t)
  where
    aux !_ Empty = []
    aux n (Cons Var{} u) = aux (n+1) u
    -- Advance past the current argument u, not the whole outer term.
    aux n (Cons u v) =
      [Parallel [(n,p)] t | p <- strat u] ++ aux (n + len u) v

{-# INLINE rewrite #-}
-- | Rewrite a term using an index of rules. Each match that
-- 'Index.matches' reports for the term, and that passes the predicate,
-- gives one 'Step'. The phase label is not used.
rewrite :: Function f => String -> (Rule f -> Subst f -> Bool) -> Frozen (Rule f) -> Strategy f
rewrite _phase p rules t =
  [ Step r sub | Index.Match r sub <- Index.matches t rules, p r sub ]

-- | Try one rule at the root of a term. Yields a 'Step' when the rule's
-- left-hand side matches the term and the predicate accepts the rule and
-- the matching substitution.
tryRule :: Function f => (Rule f -> Subst f -> Bool) -> Rule f -> Strategy f
tryRule p rl t =
  [ Step rl sub | Just sub <- [match (lhs rl) t], p rl sub ]

-- | Whether applying the rule with this substitution is a simplification,
-- without consulting any term ordering:
--
-- * always for 'Oriented' rules;
-- * for 'WeaklyOriented' rules, only when some listed variable is mapped
--   to a non-minimal term;
-- * never for 'Permutative' or 'Unoriented' rules.
simplifies :: Function f => Rule f -> Subst f -> Bool
simplifies (Rule o _ _) sub =
  case o of
    Oriented -> True
    WeaklyOriented ts -> not (all isMinimal (subst sub ts))
    Permutative{} -> False
    Unoriented -> False

-- | Whether applying the rule with substitution sub reduces the term.
-- The comparison p is called as p rhsInstance lhsInstance, and 'reduces'
-- passes lessEq here.
--
-- * 'Oriented': always.
-- * 'WeaklyOriented': when some listed variable is mapped to a
--   non-minimal term.
-- * 'Permutative': decided by the first variable pair whose two
--   instances differ; False if all instances coincide.
-- * 'Unoriented': when p holds and the two instantiated sides differ.
reducesWith :: Function f => (Term f -> Term f -> Bool) -> Rule f -> Subst f -> Bool
reducesWith _ (Rule Oriented _ _) _ = True
reducesWith _ (Rule (WeaklyOriented ts) _ _) sub =
  any (not . isMinimal) (subst sub ts)
reducesWith p (Rule (Permutative ts) _ _) sub =
  case dropWhile (uncurry (==)) [ (subst sub a, subst sub b) | (a, b) <- ts ] of
    [] -> False
    (a', b'):_ -> p b' a'
reducesWith p (Rule Unoriented a b) sub =
  p b' a' && b' /= a'
  where
    a' = subst sub a
    b' = subst sub b

-- | 'reducesWith' using lessEq as the ordering.
reduces :: Function f => Rule f -> Subst f -> Bool
reduces = reducesWith lessEq

-- | 'reducesWith' using the given model as the ordering: a term is below
-- another when lessIn returns any Just result for the pair.
reducesInModel :: Function f => Model f -> Rule f -> Subst f -> Bool
reducesInModel cond = reducesWith below
  where
    below t u = isJust (lessIn cond t u)

-- | 'reducesWith' using lessEq after replacing every variable by its
-- skolem constant (con . skolem) on both sides of the comparison.
reducesSkolem :: Function f => Rule f -> Subst f -> Bool
reducesSkolem = reducesWith skolemLessEq
  where
    skolemLessEq t u = lessEq (subst skolemise t) (subst skolemise u)
    skolemise = con . skolem

-- | Whether the rule instance is usable below the term top. All of the
-- following must hold, checked in this order:
--
-- * it reduces per 'reducesSkolem';
-- * its instantiated right-hand side u is lessEq top;
-- * u does not unify with top.
reducesSub :: Function f => Term f -> Rule f -> Subst f -> Bool
reducesSub top r sub =
  and [reducesSkolem r sub, lessEq u top, isNothing (unify u top)]
  where
    u = subst sub (rhs r)