-- | An implementation of Knuth-Bendix ordering.

{-# LANGUAGE PatternGuards, BangPatterns #-}
module Twee.KBO(lessEq, lessIn, lessEqSkolem, Sized(..), Weighted(..)) where

import Twee.Base hiding (lessEq, lessIn, lessEqSkolem)
import Twee.Equation
import Twee.Constraints hiding (lessEq, lessIn, lessEqSkolem)
import qualified Data.Map.Strict as Map
import Data.Map.Strict(Map)
import Data.Maybe
import Control.Monad
import Twee.Utils

-- | Compare two terms in KBO, treating variables as if they were Skolem
-- constants.  Returns 'True' when the first term is less than or equal to
-- the second.
--
-- Terms are compared first by 'size'.  When the sizes are equal:
--
-- * the 'minimal' constant is below every other term;
--
-- * variables are ordered among themselves by their 'Ord' instance and a
--   variable is below every term except the minimal constant;
--
-- * two applications are compared by their underlying function values and,
--   if those are equal, lexicographically on their arguments.
lessEqSkolem :: (Function f, Sized f, Weighted f) => Term f -> Term f -> Bool
lessEqSkolem !t !u
  | m < n = True
  | m > n = False
  where
    m = size t
    n = size u
-- The sizes are equal from here on (both guards above failed).
lessEqSkolem (App x Empty) _
  | x == minimal = True
lessEqSkolem _ (App x Empty)
  | x == minimal = False
lessEqSkolem (Var x) (Var y) = x <= y
lessEqSkolem (Var _) _ = True
lessEqSkolem _ (Var _) = False
lessEqSkolem (App (F _ f) ts) (App (F _ g) us) =
  case compare f g of
    LT -> True
    GT -> False
    EQ ->
      let loop Empty Empty = True
          loop (Cons t ts) (Cons u us)
            | t == u = loop ts us
            | otherwise = lessEqSkolem t u
          -- Same function symbol, so the argument lists must have the same
          -- length.  Report a violation explicitly (as 'lessEq' does)
          -- instead of failing with an anonymous pattern-match error.
          loop _ _ = error "incorrect function arity"
      in loop ts us

-- | Check if one term is less than another in KBO.
--
-- Returns 'True' when the first term is less than or equal to the second.
-- Two applications are ordered by 'size', then by head symbol ('<<'), then
-- lexicographically on their arguments.  In addition, every variable of
-- the first term must occur in the second with at least as much
-- accumulated weight (see 'weightedVars').
lessEq :: (Function f, Sized f, Weighted f) => Term f -> Term f -> Bool
lessEq (App f Empty) _ | f == minimal = True
lessEq (Var x) (Var y) | x == y = True
lessEq _ (Var _) = False
lessEq (Var x) t = x `elem` vars t
lessEq t@(App f ts) u@(App g us) =
  ordered && covers (weightedVars t) (weightedVars u)
  where
    -- Size first, then the head symbols, then the arguments.
    ordered =
      case compare (size t) (size u) of
        LT -> True
        GT -> False
        EQ -> f << g || (f == g && argsLessEq ts us)

    -- Lexicographic comparison of two argument lists of equal length.
    -- After the first differing pair, any unifier of that pair (which must
    -- bind variables only to minimal terms) is applied to the remaining
    -- arguments before continuing.
    argsLessEq Empty Empty = True
    argsLessEq (Cons l ls) (Cons r rs)
      | l == r = argsLessEq ls rs
      | otherwise =
          lessEq l r &&
            case unify l r of
              Nothing -> True
              Just sub
                | not (allSubst (\_ (Cons s Empty) -> isMinimal s) sub) ->
                    error "weird term inequality"
                | otherwise -> argsLessEq (subst sub ls) (subst sub rs)
    argsLessEq _ _ = error "incorrect function arity"

    -- covers small big: every (variable, weight) pair of small has a pair
    -- for the same variable in big with at least that weight.  The
    -- skip-ahead step assumes both lists are sorted by variable
    -- (presumably guaranteed by 'weightedVars' via 'collate' - confirm).
    covers [] _ = True
    covers small@((x, k1) : smallRest) ((y, k2) : bigRest)
      | x == y = k1 <= k2 && covers smallRest bigRest
      | x > y = covers small bigRest
    covers _ _ = False

-- | Check if one term is less than another in a given model.
--
-- See "notes/kbo under assumptions" for how this works.
--
-- The size comparison ('sizeLessIn') decides the answer unless it returns
-- 'Nonstrict'.  Only in that case is the lexicographic comparison
-- ('lexLessIn') consulted.
lessIn :: (Function f, Sized f, Weighted f) => Model f -> Term f -> Term f -> Maybe Strictness
lessIn model t u =
  sizeLessIn model t u >>= \strictness ->
    case strictness of
      Strict -> Just Strict
      Nonstrict -> lexLessIn model t u

-- | Size part of 'lessIn'.
--
-- Builds @size u - size t@ as a constant plus a coefficient per variable.
-- It then asks 'minimumIn' for the least value of the variable part under
-- the model.  The result is @Just Strict@ if the difference is then always
-- positive, @Just Nonstrict@ if its minimum is exactly zero, and 'Nothing'
-- otherwise (including when 'minimumIn' gives no bound).
sizeLessIn :: (Function f, Sized f, Weighted f) => Model f -> Term f -> Term f -> Maybe Strictness
sizeLessIn model t u =
  case minimumIn model varCoeffs of
    Nothing -> Nothing
    Just lowest ->
      case compare lowest (negate constant) of
        GT -> Just Strict
        EQ -> Just Nonstrict
        LT -> Nothing
  where
    -- size u - size t, split into its constant part and per-variable
    -- coefficients.
    (constant, varCoeffs) =
      accumulate 1 u (accumulate (-1) t (0, Map.empty))

    -- accumulate sign term acc adds sign * size term to acc.  The sizes of
    -- the arguments of a symbol are scaled by that symbol's 'argWeight'.
    accumulate sign (Var x) (c, coeffs) =
      (c, Map.insertWith (+) x sign coeffs)
    accumulate sign (App f ts) (c, coeffs) =
      foldr (accumulate (sign * argWeight f)) (c + sign * size f, coeffs) (unpack ts)

-- | Lower bound, under the model, of a linear combination of variable
-- sizes.
--
-- The map gives a coefficient for each variable.  The result is the sum of
-- the per-group minima ('minGroup') and the per-variable contributions
-- ('minOrphan').  It is 'Nothing' if any of them is 'Nothing', i.e. no
-- finite lower bound was found.
minimumIn :: (Function f, Sized f) => Model f -> Map Var Integer -> Maybe Integer
minimumIn model t =
  liftM2 (+)
    (fmap sum (mapM minGroup (varGroups model)))
    (fmap sum (mapM minOrphan (Map.toList t)))
  where
    -- A group (lo, xs, mhi).  The xs are presumably the model's variables
    -- in ascending order between lo and (if present) hi, as returned by
    -- 'varGroups'; the suffix-sum test below relies on that order.
    --
    -- Writing size xi = size lo + d1 + ... + di with every dj >= 0, the
    -- objective is
    --   sum coeffs * size lo + sum_j dj * (j-th suffix sum of coeffs).
    -- If every suffix sum is non-negative, the minimum is at dj = 0.
    -- Otherwise the whole budget (size hi - size lo) goes to the most
    -- negative suffix sum, which needs an upper bound hi.
    minGroup (lo, xs, mhi)
      | all (>= 0) sums = Just (sum coeffs * size lo)
      | otherwise =
        case mhi of
          Nothing -> Nothing
          Just hi ->
            -- Use the most negative *suffix sum*, consistent with the test
            -- above.  The smallest single coefficient would overestimate
            -- the minimum (e.g. coeffs [-1,-1] would give -lo-hi, not
            -- -2*hi).  sums is non-empty here because some sum is negative.
            let coeff = negate (minimum sums) in
            Just $
              sum coeffs * size lo +
              coeff * (size lo - size hi)
      where
        coeffs = map (\x -> Map.findWithDefault 0 x t) xs
        sums = scanr1 (+) coeffs

    -- A variable with its coefficient.  Variables in the model contribute 0
    -- here (presumably they are accounted for by minGroup).  Any other
    -- variable with a negative coefficient has no finite bound.  A
    -- non-negative coefficient contributes k, i.e. k times a minimum
    -- variable size of 1 (NOTE(review): assumes minimal size is 1 - confirm).
    minOrphan (x, k)
      | varInModel model x = Just 0
      | k < 0 = Nothing
      | otherwise = Just k

-- | Lexicographic part of 'lessIn'.  It is called when the size comparison
-- returned 'Nonstrict'.
--
-- The cases are tried in order:
--
-- 1. the terms are identical;
-- 2. both are atoms and the model relates them directly;
-- 3. the first is an atom related in the model to some atom properly
--    inside the second (a strict result);
-- 4. both are applications, compared by head symbol and then argument by
--    argument;
-- 5. the first term is minimal.
lexLessIn :: (Function f, Sized f, Weighted f) => Model f -> Term f -> Term f -> Maybe Strictness
lexLessIn _ t u | t == u = Just Nonstrict
lexLessIn model t u
  | Just a <- fromTerm t,
    Just b <- fromTerm u,
    Just x <- lessEqInModel model a b = Just x
  | Just a <- fromTerm t,
    any (isJust . lessEqInModel model a) innerAtoms = Just Strict
  where
    innerAtoms = [b | v <- properSubterms u, Just b <- [fromTerm v]]
lexLessIn model (App f ts) (App g us)
  | f == g = go ts us
  | f << g = Just Strict
  | otherwise = Nothing
  where
    -- Walk the argument lists.  After a non-strict step, apply the unifier
    -- of the differing pair to the rest.
    go Empty Empty = Just Nonstrict
    go (Cons l ls) (Cons r rs)
      | l == r = go ls rs
      | otherwise =
          case lessIn model l r of
            Nothing -> Nothing
            Just Strict -> Just Strict
            Just Nonstrict ->
              let Just sub = unify l r in
              go (subst sub ls) (subst sub rs)
    go _ _ = error "incorrect function arity"
lexLessIn _ t _ | isMinimal t = Just Nonstrict
lexLessIn _ _ _ = Nothing

-- | Types with a size, as used by the Knuth-Bendix comparisons in this
-- module.
class Sized a where
  -- | Compute the size.
  size  :: a -> Integer

-- | Types with an argument weight.  In this module, a function symbol's
-- argument weight multiplies the sizes of its arguments (see the 'Sized'
-- instance for 'TermList') and the weights of the variables beneath it
-- (see 'weightedVars').
class Weighted f where
  -- | The factor applied to each argument.
  argWeight :: f -> Integer

-- | A function symbol's argument weight is that of its underlying value.
instance (Weighted f, Labelled f) => Weighted (Fun f) where
  argWeight fun = argWeight (fun_value fun)

-- | Each variable of a term together with its total weight.
--
-- The weight of one occurrence is the product of the 'argWeight's of all
-- function symbols above it.  Occurrences of the same variable are
-- combined with 'collate' 'sum'.
weightedVars :: (Weighted f, Labelled f) => Term f -> [(Var, Integer)]
weightedVars t = collate sum (occurrences 1 t)
  where
    occurrences w (Var x) = [(x, w)]
    occurrences w (App f ts) =
      [ occ | arg <- unpack ts, occ <- occurrences scaled arg ]
      where
        scaled = w * argWeight f

-- | A function symbol's size is that of its underlying value.
instance (Labelled f, Sized f) => Sized (Fun f) where
  size fun = size (fun_value fun)

-- | The size of a term list is the sum of the sizes of its terms.  A
-- variable has size 1.  An application has the size of its function symbol
-- plus 'argWeight' of that symbol times the size of its arguments.
instance (Labelled f, Sized f, Weighted f) => Sized (TermList f) where
  size = aux 0
    where
      -- The running total is forced at every step so that sizing a long
      -- list does not build a chain of unevaluated additions.
      aux !n Empty = n
      aux !n (Cons (App f t) u) =
        aux (n + size f + argWeight f * size t) u
      aux !n (Cons (Var _) t) = aux (n+1) t

-- | A term's size is that of the one-element term list containing it.
instance (Labelled f, Sized f, Weighted f) => Sized (Term f) where
  size t = size (singleton t)

-- | An equation's size is the sum of the sizes of its two sides.
instance (Labelled f, Sized f, Weighted f) => Sized (Equation f) where
  size (lhs :=: rhs) = size lhs + size rhs