{-# LANGUAGE PatternGuards, MagicHash, MultiWayIf, TypeOperators #-}
module Cryptol.TypeCheck.Solver.Numeric
  ( cryIsEqual, cryIsNotEqual, cryIsGeq, cryIsPrime, primeTable
  ) where

import           Control.Applicative(Alternative(..))
import           Control.Monad (guard,mzero)
import qualified Control.Monad.Fail as Fail
import           Data.List (sortBy)
import           Data.MemoTrie

import Math.NumberTheory.Primes.Testing (isPrime)

import Cryptol.Utils.Patterns
import Cryptol.TypeCheck.Type hiding (tMul)
import Cryptol.TypeCheck.TypePat
import Cryptol.TypeCheck.Solver.Types
import Cryptol.TypeCheck.Solver.InfNat
import Cryptol.TypeCheck.Solver.Numeric.Interval
import Cryptol.TypeCheck.SimpType as Simp

{- Convention for comments:

  K1, K2 ...          Concrete constants
  s1, s2, t1, t2 ...  Arbitrary type expressions
  a, b, c ...         Type variables

-}


-- | Try to solve @t1 = t2@.
--
-- The rules below are combined with '<|>', so the first rule whose
-- pattern applies determines the result; 'matchDefault' yields
-- 'Unsolved' when no rule applies.
cryIsEqual :: Ctxt -> Type -> Type -> Solved
cryIsEqual ctxt t1 t2 =
  matchDefault Unsolved $
        (pBin (==) t1 t2)
    <|> (aNat' t1 >>= tryEqK ctxt t2)
    <|> (aNat' t2 >>= tryEqK ctxt t1)
    <|> (aTVar t1 >>= tryEqVar t2)
    <|> (aTVar t2 >>= tryEqVar t1)
    <|> ( guard (t1 == t2) >> return (SolvedIf []))
    <|> tryEqMin t1 t2
    <|> tryEqMin t2 t1
    <|> tryEqMins t1 t2
    <|> tryEqMins t2 t1
    <|> tryEqMulConst t1 t2
    <|> tryEqAddInf ctxt t1 t2
    <|> tryAddConst (=#=) t1 t2
    <|> tryCancelVar ctxt (=#=) t1 t2
    <|> tryLinearSolution t1 t2
    <|> tryLinearSolution t2 t1

-- | Try to solve @t1 /= t2@.
--
-- Only goals whose two sides are known constants (or error types) are
-- decided here; every other goal is left 'Unsolved'.
cryIsNotEqual :: Ctxt -> Type -> Type -> Solved
cryIsNotEqual _i t1 t2 = matchDefault Unsolved (pBin (/=) t1 t2)

-- | Try to solve @t1 >= t2@.
--
-- The rules are combined with '<|>', so the first applicable rule decides
-- the result; 'matchDefault' yields 'Unsolved' when none applies.
cryIsGeq :: Ctxt -> Type -> Type -> Solved
cryIsGeq i t1 t2 =
  matchDefault Unsolved $
        (pBin (>=) t1 t2)
    <|> (aNat' t1 >>= tryGeqKThan i t2)
    <|> (aNat' t2 >>= tryGeqThanK i t1)
    <|> (aTVar t2 >>= tryGeqThanVar i t1)
    <|> tryGeqThanSub i t1 t2
    <|> (geqByInterval i t1 t2)
    <|> (guard (t1 == t2) >> return (SolvedIf []))
    <|> tryAddConst (>==) t1 t2
    <|> tryCancelVar i (>==) t1 t2
    <|> tryMinIsGeq t1 t2
    -- XXX: k >= width e
    -- XXX: width e >= k


  -- XXX: max t 10 >= 2 --> True
  -- XXX: max t 2 >= 10 --> t >= 10

{-# NOINLINE primeTable #-}
-- | Memoised primality test ('isPrime') for type-level numbers.
-- 'NOINLINE' keeps this a single top-level trie, so every lookup through
-- 'untrie' shares the same memo table.
primeTable :: Integer :->: Bool
primeTable = trie isPrime

-- | Try to solve @prime t@.
--
-- A finite numeric constant is decided with 'primeTable', @inf@ is never
-- prime, and any other type is left 'Unsolved'.
cryIsPrime :: Ctxt -> Type -> Solved
cryIsPrime _varInfo ty =
  case tNoUser ty of

    TCon (TC tc) []
      | TCNum n <- tc ->
          if untrie primeTable n then
            SolvedIf []
          else
            Unsolvable

      | TCInf <- tc -> Unsolvable

    _ -> Unsolved


-- | Try to solve something by evaluation.
--
-- An error type on either side makes the goal 'Unsolvable'.  Otherwise
-- both sides must be known constants ('aNat''); the relation @p@ then
-- decides between 'SolvedIf' @[]@ and 'Unsolvable'.  The match fails when
-- either side is not a constant.
pBin :: (Nat' -> Nat' -> Bool) -> Type -> Type -> Match Solved
pBin p t1 t2
  | Just _ <- tIsError t1 = pure Unsolvable
  | Just _ <- tIsError t2 = pure Unsolvable
  | otherwise = do x <- aNat' t1
                   y <- aNat' t2
                   return $ if p x y
                              then SolvedIf []
                              else Unsolvable


--------------------------------------------------------------------------------
-- GEQ

-- | Try to solve @K >= t@.
tryGeqKThan :: Ctxt -> Type -> Nat' -> Match Solved
-- inf is at least as large as anything.
tryGeqKThan _ _ Inf = return (SolvedIf [])
tryGeqKThan _ ty (Nat n) =

  -- K1 >= K2 * t
  do (a,b) <- aMul ty
     m     <- aNat' a
     return $ SolvedIf
            $ case m of
                -- inf * t is only finite (namely 0) when t = 0.
                Inf   -> [ b =#= tZero ]
                -- 0 * t = 0, which never exceeds K1.
                Nat 0 -> []
                -- K2 * t <= K1  iff  t <= K1 `div` K2.
                Nat k -> [ tNum (div n k) >== b ]

-- | Try to solve @t >= K@.
tryGeqThanK :: Ctxt -> Type -> Nat' -> Match Solved
-- Only inf is at least inf.
tryGeqThanK _ t Inf = return (SolvedIf [ t =#= tInf ])
tryGeqThanK _ t (Nat k) =

  -- K1 + t >= K2
  do (a,b) <- anAdd t
     n     <- aNat a
     return $ SolvedIf $ if n >= k
                            -- The constant alone already reaches K2.
                            then []
                            -- Otherwise t must make up the difference.
                            else [ b >== tNum (k - n) ]
  -- XXX: K1 ^^ n >= K2


-- | Try to solve @t1 >= t1 - t2@, which always holds.
-- Fails unless the right-hand side is a subtraction whose minuend is
-- syntactically equal to the left-hand side.
tryGeqThanSub :: Ctxt -> Type -> Type -> Match Solved
tryGeqThanSub _ x y =

  -- t1 >= t1 - t2
  do (a,_) <- (|-|) y
     guard (x == a)
     return (SolvedIf [])

-- | Try to solve @t >= a@ for a type variable @a@.
-- Succeeds when @t@ is a sum with @a@ itself as one of its two operands.
tryGeqThanVar :: Ctxt -> Type -> TVar -> Match Solved
tryGeqThanVar _ctxt ty x =
  -- (t + a) >= a
  do (a,b) <- anAdd ty
     let check :: Type -> Match Solved
         check y = do x' <- aTVar y
                      guard (x == x')
                      return (SolvedIf [])
     check a <|> check b

-- | Try to prove GEQ by considering the known intervals for the given types.
--
-- Succeeds when the lower bound of @x@'s interval is at least the upper
-- bound of @y@'s interval.  Fails when @y@ has no upper bound
-- ('iUpper' is 'Nothing') or the bounds do not settle the question.
geqByInterval :: Ctxt -> Type -> Type -> Match Solved
geqByInterval ctxt x y =
  let ix = typeInterval (intervals ctxt) x
      iy = typeInterval (intervals ctxt) y
  in case (iLower ix, iUpper iy) of
       (l,Just n) | l >= n -> return (SolvedIf [])
       _                   -> mzero

-- | @min K1 t >= K2 ~~> t >= K2@, if @K1 >= K2@; 'Unsolvable' otherwise.
-- Fails unless the left side is a @min@ whose first operand is a finite
-- constant and the right side is a finite constant.
tryMinIsGeq :: Type -> Type -> Match Solved
tryMinIsGeq t1 t2 =
  do (a,b) <- aMin t1
     k1    <- aNat a
     k2    <- aNat t2
     return $ if k1 >= k2
               then SolvedIf [ b >== t2 ]
               else Unsolvable

--------------------------------------------------------------------------------

-- | Cancel finite positive variables from both sides.
-- @(fin a, a >= 1) =>  a * t1 == a * t2 ~~~> t1 == t2@
-- @(fin a, a >= 1) =>  a * t1 >= a * t2 ~~~> t1 >= t2@
--
-- Each side is split into its multiplicative factors.  A factor is
-- cancellable when it is a type variable whose interval in the context is
-- positive and finite ('iIsPosFin').  One cancellable variable common to
-- both sides is removed per call, and the relation @p@ is rebuilt from the
-- remaining factors.  The match fails when there is no such variable.
tryCancelVar :: Ctxt -> (Type -> Type -> Prop) -> Type -> Type -> Match Solved
tryCancelVar ctxt p t1 t2 =
  case check [] [] (preproc t1) (preproc t2) of
    Nothing -> Fail.fail "tryCancelVar"
    Just x  -> return x

  where
  -- Merge-style walk over both factor lists.  Cancellable factors come
  -- first, in ascending order, so once either head is not cancellable
  -- there is nothing left to cancel.
  check :: [Type] -> [Type] -> [(Type, Maybe TVar)] -> [(Type, Maybe TVar)]
        -> Maybe Solved
  check doneLHS doneRHS lhs@((a,mbA) : moreLHS) rhs@((b,mbB) : moreRHS) =
    do x <- mbA
       y <- mbB
       case compare x y of
         LT -> check (a : doneLHS) doneRHS moreLHS rhs
         EQ -> return $ SolvedIf [ p (term (doneLHS ++ map fst moreLHS))
                                     (term (doneRHS ++ map fst moreRHS)) ]
         GT -> check doneLHS (b : doneRHS) lhs moreRHS
  check _ _ _ _ = Nothing

  -- Rebuild a product from its factors; the empty product is 1.
  term :: [Type] -> Type
  term xs = case xs of
              [] -> tNum (1::Int)
              _  -> foldr1 tMul xs

  -- Factors of a product, tagged with their cancellable variable (if any)
  -- and sorted with cancellable ones first.
  preproc :: Type -> [(Type, Maybe TVar)]
  preproc t = let fs = splitMul t []
              in sortBy cmpFact (zip fs (map cancelVar fs))

  -- Flatten nested multiplications, prepending the factors to @rest@.
  splitMul :: Type -> [Type] -> [Type]
  splitMul t rest = case matchMaybe (aMul t) of
                      Just (a,b) -> splitMul a (splitMul b rest)
                      Nothing    -> t : rest

  -- A factor is cancellable if it is a variable known to be finite and >= 1.
  cancelVar :: Type -> Maybe TVar
  cancelVar t = matchMaybe $ do x <- aTVar t
                                guard (iIsPosFin (tvarInterval (intervals ctxt) x))
                                return x

  -- Cancellable variables go first, ordered by the 'Ord' instance of 'TVar';
  -- non-cancellable factors are left in their relative order.
  cmpFact :: (Type, Maybe TVar) -> (Type, Maybe TVar) -> Ordering
  cmpFact (_,mbA) (_,mbB) =
    case (mbA,mbB) of
      (Just x, Just y)  -> compare x y
      (Just _, Nothing) -> LT
      (Nothing, Just _) -> GT
      _                 -> EQ



-- | @min t1 t2 = t1 ~> t1 <= t2@ (and symmetrically when the other operand
-- of the @min@ equals the right-hand side).
tryEqMin :: Type -> Type -> Match Solved
tryEqMin x y =
  do (a,b) <- aMin x
     -- If operand m1 is the other side, the min picks it exactly when
     -- the remaining operand m2 is no smaller.
     let check :: Type -> Type -> Match Solved
         check m1 m2 = do guard (m1 == y)
                          return $ SolvedIf [ m2 >== m1 ]
     check a b <|> check b a


-- | @t1 == min (K + t1) t2 ~~> t1 == t2@, if @K >= 1@
-- (also if @(K + t1)@ is one term in a multi-way min).
--
-- A term @K + t1@ with @K >= 1@ can only be the minimum if
-- @t1 = K + t1@, i.e. @t1 = inf@, in which case that term is @inf@ as well.
-- Such terms can therefore be dropped.  If every term is dropped, the
-- minimum of the remaining (empty) set is @inf@.
tryEqMins :: Type -> Type -> Match Solved
tryEqMins x y =
  do (a, b) <- aMin y
     let ys  = splitMin a ++ splitMin b
         ys' = filter (not . isGt) ys
         y'  = if null ys' then tInf else foldr1 Simp.tMin ys'
     -- Fail (rather than succeeding with 'Unsolved') when no term can be
     -- dropped.  A successful match ends the '<|>' chain in 'cryIsEqual',
     -- which would skip every later rule, including the symmetric
     -- @tryEqMins t2 t1@.
     guard (any isGt ys)
     return (SolvedIf [x =#= y'])
  where
    -- Flatten nested 'min's into the list of their operands.
    splitMin :: Type -> [Type]
    splitMin ty =
      case matchMaybe (aMin ty) of
        Just (t1, t2) -> splitMin t1 ++ splitMin t2
        Nothing       -> [ty]

    -- Is this term @K + x@ with a positive constant @K@?
    isGt :: Type -> Bool
    isGt t =
      case matchMaybe (asAddK t) of
        Just (k, t') -> k > 0 && t' == x
        Nothing      -> False

    -- Match @K + t@ where @K@ is a finite constant.
    asAddK :: Type -> Match (Integer, Type)
    asAddK t =
      do (t1, t2) <- anAdd t
         k <- aNat t1
         return (k, t2)


-- | Try to solve @a = t@ for a type variable @a@ (given as @x@).
tryEqVar :: Type -> TVar -> Match Solved
tryEqVar ty x =

  -- a = K + a --> a = inf
  -- (with K >= 1 only inf satisfies this)
  (do (k,tv) <- matches ty (anAdd, aNat, aTVar)
      guard (tv == x && k >= 1)

      return $ SolvedIf [ TVar x =#= tInf ]
  )
  <|>

  -- a = min (K + a) t --> a = t
  -- (K + a can only be the minimum when a = inf, and then min inf t = t)
  (do (l,r) <- aMin ty
      let check :: Type -> Type -> Match Solved
          check this other =
            do (k,x') <- matches this (anAdd, aNat', aTVar)
               guard (x == x' && k >= Nat 1)
               return $ SolvedIf [ TVar x =#= other ]
      check l r <|> check r l
  )
  <|>
  -- a = K + min t a --> a = K + t
  (do (k,(l,r)) <- matches ty (anAdd, aNat, aMin)
      guard (k >= 1)
      let check :: Type -> Type -> Match Solved
          check a b = do x' <- aTVar a
                         guard (x' == x)
                         return (SolvedIf [ TVar x =#= tAdd (tNum k) b ])
      check l r <|> check r l
  )







-- | Try to solve @K = t@ for a known constant @K@ (e.g., @10 = t@).
-- The constant is @lk@ and @ty@ is the other side.
tryEqK :: Ctxt -> Type -> Nat' -> Match Solved
tryEqK ctxt ty lk =

  -- (t1 + t2 = inf, fin t1) ~~~> t2 = inf
  do guard (lk == Inf)
     (a,b) <- anAdd ty
     let check :: Type -> Type -> Match Solved
         check x y = do guard (iIsFin (typeInterval (intervals ctxt) x))
                        return $ SolvedIf [ y =#= tInf ]
     check a b <|> check b a
  <|>

  -- (K1 + t = K2, K2 >= K1) ~~~> t = (K2 - K1)
  do (rk, b) <- matches ty (anAdd, aNat', __)
     return $
       case nSub lk rk of
         -- No solution: K1 > K2, or K1 is inf while K2 is finite.
         -- NOTE: (Inf - Inf) shouldn't be possible
         Nothing -> Unsolvable

         Just r -> SolvedIf [ b =#= tNat' r ]
  <|>

  -- (lk = t - rk) ~~> t = lk + rk
  do (t,rk) <- matches ty ((|-|) , __, aNat')
     return (SolvedIf [ t =#= tNat' (nAdd lk rk) ])

  <|>
  do (rk, b) <- matches ty (aMul, aNat', __)
     return $
       case (lk,rk) of
         -- Inf * t = Inf ~~~>  t >= 1
         (Inf,Inf)    -> SolvedIf [ b >== tOne ]

         -- 0 * t = Inf ~~~> ERR      (0 * t is always 0)
         (Inf,Nat 0)  -> Unsolvable

         -- K * t = Inf ~~~> t = Inf  (K /= 0)
         (Inf,Nat _)  -> SolvedIf [ b =#= tInf ]

         -- Inf * t = 0 ~~~> t = 0
         (Nat 0, Inf) -> SolvedIf [ b =#= tZero ]

         -- Inf * t = K ~~~> ERR      (K /= 0)
         (Nat _k, Inf) -> Unsolvable

         (Nat lk', Nat rk')
           -- 0 * t = K2 ~~> K2 = 0
           | rk' == 0 -> SolvedIf [ tNat' lk =#= tZero ]
              -- shouldn't happen, as `0 * t = 0` should have been simplified

           -- K1 * t = K2 ~~> t = K2/K1
           | (q,0) <- divMod lk' rk' -> SolvedIf [ b =#= tNum q ]
           | otherwise -> Unsolvable

  <|>
  -- K1 == K2 ^^ t    ~~> t = logBase K2 K1
  do (rk, b) <- matches ty ((|^|), aNat, __)
     return $ case lk of
                Inf | rk > 1 -> SolvedIf [ b =#= tInf ]
                Nat n | Just (a,True) <- genLog n rk -> SolvedIf [ b =#= tNum a]
                _ -> Unsolvable

  -- XXX: Min, Max, etc.
  -- 2  = min (10,y)  --> y = 2
  -- 2  = min (2,y)   --> y >= 2
  -- 10 = min (2,y)   --> impossible


-- | @K1 * t1 + K2 * t2 + ... = K3 * t3 + K4 * t4 + ...@
--
-- When both sides match 'matchLinear' and every constant (both constant
-- terms and all coefficients) shares a common divisor @d > 1@, both sides
-- are divided by @d@.
tryEqMulConst :: Type -> Type -> Match Solved
tryEqMulConst l r =
  do (lc,ls) <- matchLinear l
     (rc,rs) <- matchLinear r
     let d = foldr1 gcd (lc : rc : map fst (ls ++ rs))
     guard (d > 1)
     return (SolvedIf [build d lc ls =#= build d rc rs])
  where
  -- Rebuild @k/d + sum (k_i/d * t_i)@.
  build :: Integer -> Integer -> [(Integer, Type)] -> Type
  build d k ts   = foldr tAdd (cancel d k) (map (buildS d) ts)

  buildS :: Integer -> (Integer, Type) -> Type
  buildS d (k,t) = tMul (cancel d k) t

  -- Exact here, because @d@ divides every constant.
  cancel :: Integer -> Integer -> Type
  cancel d x     = tNum (div x d)


-- | @(t1 + t2 = Inf, fin t1)  ~~> t2 = Inf@
--
-- Tried with the sum on either side.  When neither summand is known to be
-- finite the result is 'Unsolved'.
-- NOTE(review): being a successful match, that 'Unsolved' also stops
-- 'cryIsEqual' from trying its later rules for this goal; presumably
-- intentional — confirm before changing it to a failure.
tryEqAddInf :: Ctxt -> Type -> Type -> Match Solved
tryEqAddInf ctxt l r = check l r <|> check r l
  where

  -- Check for x = a + b /\ y = inf.
  check :: Type -> Type -> Match Solved
  check x y =
    do (x1,x2) <- anAdd x
       aInf y

       let x1Fin = iIsFin (typeInterval (intervals ctxt) x1)
       let x2Fin = iIsFin (typeInterval (intervals ctxt) x2)

       return $!
         if | x1Fin ->
              SolvedIf [ x2 =#= y ]

            | x2Fin ->
              SolvedIf [ x1 =#= y ]

            | otherwise ->
              Unsolved



-- | Check for addition of constants to both sides of a relation.
--  @((K1 + K2) + t1) `R` (K1 + t2)  ~~>   (K2 + t1) `R` t2@
--
-- This relies on the fact that constants are floated left during
-- simplification.  The smaller constant is subtracted from both sides, so
-- the leftover constant stays on the side that had the larger one.
tryAddConst :: (Type -> Type -> Prop) -> Type -> Type -> Match Solved
tryAddConst rel l r =
  do (x1,x2) <- anAdd l
     (y1,y2) <- anAdd r

     k1 <- aNat x1
     k2 <- aNat y1

     if k1 > k2
        then return (SolvedIf [ tAdd (tNum (k1 - k2)) x2 `rel` y2 ])
        else return (SolvedIf [ x2 `rel` tAdd (tNum (k2 - k1)) y2 ])


-- | Check for situations where a unification variable is involved in
--   a sum of terms not containing additional unification variables,
--   and replace it with a solution and an inequality.
--   @s1 = ?a + s2 ~~> (?a = s1 - s2, s1 >= s2)@
--
--   @s1@ itself must contain no free variables.
tryLinearSolution :: Type -> Type -> Match Solved
tryLinearSolution s1 t =
  do (a,xs) <- matchLinearUnifier t
     guard (noFreeVariables s1)

     -- NB: matchLinearUnifier only matches if xs is nonempty,
     -- so foldr1 is safe here.
     let s2 = foldr1 Simp.tAdd xs
     return (SolvedIf [ TVar a =#= (Simp.tSub s1 s2), s1 >== s2 ])


-- | Match a sum of the form @(s1 + ... + ?a + ... sn)@ where
--   @s1@ through @sn@ do not contain any free variables.
--   Returns the free variable @?a@ and the list of the other summands.
--
--   Note: a successful match should only occur if @s1 ... sn@ is
--   not empty.
matchLinearUnifier :: Pat Type (TVar,[Type])
matchLinearUnifier = go []
  where
  -- @xs@ accumulates the variable-free summands seen so far.
  go :: [Type] -> Pat Type (TVar, [Type])
  go xs t =
        -- Case where a free variable occurs at the end of a sequence of
        -- additions.  NB: fails if @xs@ is empty.
        (do v <- aFreeTVar t
            guard (not (null xs))
            return (v, xs))
    <|>
        -- Next symbol is an addition.
        (do (x, y) <- anAdd t
            -- Case where a free variable occurs in the middle of an
            -- expression: the rest of the sum must be variable-free.
            (do v <- aFreeTVar x
                guard (noFreeVariables y)
                return (v, reverse (y : xs)))
              <|>
              -- Non-free-variable recursive case.
              (do guard (noFreeVariables x)
                  go (x : xs) y))


-- | Is this a sum of products, where the products have constant coefficients?
--
-- On success, returns the sum of the constant terms together with the
-- @(coefficient, term)@ pairs of the products @K * t@.  Any summand that is
-- neither a constant nor such a product makes the match fail.
matchLinear :: Pat Type (Integer, [(Integer,Type)])
matchLinear = go (0, [])
  where
  -- Thread the accumulated (constant, products) pair through the sum.
  go :: (Integer, [(Integer, Type)]) -> Pat Type (Integer, [(Integer, Type)])
  go (c,ts) t =
    do n <- aNat t
       return (n + c, ts)
    <|>
    do (x,y) <- aMul t
       n     <- aNat x
       return (c, (n,y) : ts)
    <|>
    do (l,r) <- anAdd t
       (c',ts') <- go (c,ts) l
       go (c',ts') r