module Cryptol.TypeCheck.Solver.Numeric
( cryIsEqual, cryIsNotEqual, cryIsGeq
) where
import Control.Applicative(Alternative(..))
import Control.Monad (guard,mzero)
import Data.Foldable (asum)
import Cryptol.Utils.Patterns
import Cryptol.TypeCheck.PP
import Cryptol.TypeCheck.Type hiding (tMul)
import Cryptol.TypeCheck.TypePat
import Cryptol.TypeCheck.Solver.Types
import Cryptol.TypeCheck.Solver.InfNat
import Cryptol.TypeCheck.Solver.Numeric.Interval
import Cryptol.TypeCheck.SimpType
cryIsEqual :: Ctxt -> Type -> Type -> Solved
cryIsEqual ctxt t1 t2 =
matchDefault Unsolved $
(pBin PEqual (==) t1 t2)
<|> (aNat' t1 >>= tryEqK ctxt t2)
<|> (aNat' t2 >>= tryEqK ctxt t1)
<|> (aTVar t1 >>= tryEqVar t2)
<|> (aTVar t2 >>= tryEqVar t1)
<|> ( guard (t1 == t2) >> return (SolvedIf []))
<|> tryEqMin t1 t2
<|> tryEqMin t2 t1
<|> tryEqMulConst t1 t2
<|> tryEqAddInf ctxt t1 t2
<|> tryAddConst (=#=) t1 t2
cryIsNotEqual :: Ctxt -> Type -> Type -> Solved
cryIsNotEqual _i t1 t2 = matchDefault Unsolved (pBin PNeq (/=) t1 t2)
cryIsGeq :: Ctxt -> Type -> Type -> Solved
cryIsGeq i t1 t2 =
matchDefault Unsolved $
(pBin PGeq (>=) t1 t2)
<|> (aNat' t1 >>= tryGeqKThan i t2)
<|> (aNat' t2 >>= tryGeqThanK i t1)
<|> (aTVar t2 >>= tryGeqThanVar i t1)
<|> tryGeqThanSub i t1 t2
<|> (geqByInterval i t1 t2)
<|> (guard (t1 == t2) >> return (SolvedIf []))
<|> tryAddConst (>==) t1 t2
<|> tryMinIsGeq t1 t2
pBin :: PC -> (Nat' -> Nat' -> Bool) -> Type -> Type -> Match Solved
pBin tf p t1 t2 =
Unsolvable <$> anError KNum t1
<|> Unsolvable <$> anError KNum t2
<|> (do x <- aNat' t1
y <- aNat' t2
return $ if p x y
then SolvedIf []
else Unsolvable $ TCErrorMessage
$ "Unsolvable constraint: " ++
show (pp (TCon (PC tf) [ tNat' x, tNat' y ])))
tryGeqKThan :: Ctxt -> Type -> Nat' -> Match Solved
tryGeqKThan _ _ Inf = return (SolvedIf [])
tryGeqKThan _ ty (Nat n) =
do (a,b) <- aMul ty
m <- aNat' a
return $ SolvedIf
$ case m of
Inf -> [ b =#= tZero ]
Nat 0 -> []
Nat k -> [ tNum (div n k) >== b ]
tryGeqThanK :: Ctxt -> Type -> Nat' -> Match Solved
tryGeqThanK _ t Inf = return (SolvedIf [ t =#= tInf ])
tryGeqThanK _ t (Nat k) =
do (a,b) <- anAdd t
n <- aNat a
return $ SolvedIf $ if n >= k
then []
else [ b >== tNum (k n) ]
tryGeqThanSub :: Ctxt -> Type -> Type -> Match Solved
tryGeqThanSub _ x y =
do (a,_) <- (|-|) y
guard (x == a)
return (SolvedIf [])
tryGeqThanVar :: Ctxt -> Type -> TVar -> Match Solved
tryGeqThanVar _ctxt ty x =
do (a,b) <- anAdd ty
let check y = do x' <- aTVar y
guard (x == x')
return (SolvedIf [])
check a <|> check b
geqByInterval :: Ctxt -> Type -> Type -> Match Solved
geqByInterval ctxt x y =
let ix = typeInterval ctxt x
iy = typeInterval ctxt y
in case (iLower ix, iUpper iy) of
(l,Just n) | l >= n -> return (SolvedIf [])
_ -> mzero
tryMinIsGeq :: Type -> Type -> Match Solved
tryMinIsGeq t1 t2 =
do (a,b) <- aMin t1
k1 <- aNat a
k2 <- aNat t2
return $ if k1 >= k2
then SolvedIf [ b >== t2 ]
else Unsolvable $ TCErrorMessage $
show k1 ++ " can't be greater than " ++ show k2
tryEqMin :: Type -> Type -> Match Solved
tryEqMin x y =
do (a,b) <- aMin x
let check m1 m2 = do guard (m1 == y)
return $ SolvedIf [ m2 >== m1 ]
check a b <|> check b a
tryEqVar :: Type -> TVar -> Match Solved
tryEqVar ty x =
(do (k,tv) <- matches ty (anAdd, aNat, aTVar)
guard (tv == x && k >= 1)
return $ SolvedIf [ TVar x =#= tInf ]
)
<|>
(do (l,r) <- aMin ty
let check this other =
do (k,x') <- matches this (anAdd, aNat', aTVar)
guard (x == x' && k >= Nat 1)
return $ SolvedIf [ TVar x =#= other ]
check l r <|> check r l
)
<|>
(do (k,(l,r)) <- matches ty (anAdd, aNat, aMin)
guard (k >= 1)
let check a b = do x' <- aTVar a
guard (x' == x)
return (SolvedIf [ TVar x =#= tAdd (tNum k) b ])
check l r <|> check r l
)
tryEqK :: Ctxt -> Type -> Nat' -> Match Solved
tryEqK ctxt ty lk =
do guard (lk == Inf)
(a,b) <- anAdd ty
let check x y = do guard (iIsFin (typeInterval ctxt x))
return $ SolvedIf [ y =#= tInf ]
check a b <|> check b a
<|>
do (rk, b) <- matches ty (anAdd, aNat', __)
return $
case nSub lk rk of
Nothing -> Unsolvable
$ TCErrorMessage
$ "Adding " ++ show rk ++ " will always exceed "
++ show lk
Just r -> SolvedIf [ b =#= tNat' r ]
<|>
do (rk, b) <- matches ty (aMul, aNat', __)
return $
case (lk,rk) of
(Inf,Inf) -> SolvedIf [ b >== tOne ]
(Inf,Nat _) -> SolvedIf [ b =#= tInf ]
(Nat 0, Inf) -> SolvedIf [ b =#= tZero ]
(Nat k, Inf) -> Unsolvable
$ TCErrorMessage
$ show k ++ " /= inf * anything"
(Nat lk', Nat rk')
| rk' == 0 -> SolvedIf [ tNat' lk =#= tZero ]
| (q,0) <- divMod lk' rk' -> SolvedIf [ b =#= tNum q ]
| otherwise ->
Unsolvable
$ TCErrorMessage
$ show lk ++ " /= " ++ show rk ++ " * anything"
tryEqMulConst :: Type -> Type -> Match Solved
tryEqMulConst l r =
do (l1,l2) <- aMul l
(r1,r2) <- aMul r
asum [ do l1' <- aNat l1
r1' <- aNat r1
return (build l1' l2 r1' r2)
, do l2' <- aNat l2
r1' <- aNat r1
return (build l2' l1 r1' r2)
, do l1' <- aNat l1
r2' <- aNat r2
return (build l1' l2 r2' r1)
, do l2' <- aNat l2
r2' <- aNat r2
return (build l2' l1 r2' r1)
]
where
build lk l' rk r' =
let d = gcd lk rk
lk' = lk `div` d
rk' = rk `div` d
in if d == 1
then Unsolved
else (SolvedIf [ tMul (tNum lk') l' =#= tMul (tNum rk') r' ])
tryEqAddInf :: Ctxt -> Type -> Type -> Match Solved
tryEqAddInf ctxt l r = check l r <|> check r l
where
check x y =
do (x1,x2) <- anAdd x
aInf y
let x1Fin = iIsFin (typeInterval ctxt x1)
let x2Fin = iIsFin (typeInterval ctxt x2)
return $!
if | x1Fin ->
SolvedIf [ x2 =#= y ]
| x2Fin ->
SolvedIf [ x1 =#= y ]
| otherwise ->
Unsolved
tryAddConst :: (Type -> Type -> Prop) -> Type -> Type -> Match Solved
tryAddConst rel l r =
do (x1,x2) <- anAdd l
(y1,y2) <- anAdd r
k1 <- aNat x1
k2 <- aNat y1
if k1 > k2
then return (SolvedIf [ tAdd (tNum (k1 k2)) x2 `rel` y2 ])
else return (SolvedIf [ x2 `rel` tAdd (tNum (k2 k1)) y2 ])