module Cryptol.TypeCheck.Solver.Numeric.Simplify1 (propToProp', ppProp') where
import Cryptol.TypeCheck.Solver.Numeric.SimplifyExpr
(crySimpExpr, splitSum, normSum, Sign(..))
import Cryptol.TypeCheck.Solver.Numeric.AST
import Cryptol.TypeCheck.Solver.InfNat(genLog,rootExact)
import Cryptol.Utils.Misc ( anyJust )
import Cryptol.Utils.Panic
import Cryptol.Utils.PP
import Control.Monad ( liftM2 )
import Data.Maybe ( fromMaybe )
import qualified Data.Map as Map
import Data.Either(partitionEithers)
data Atom = AFin Name
| AGt Expr Expr
| AEq Expr Expr
deriving Eq
type I = IfExpr' Atom
propToProp' :: Prop -> I Bool
propToProp' prop =
case prop of
Fin e -> pFin e
x :== y -> pEq x y
x :>= y -> pGeq x y
x :> y -> pGt x y
x :>: y -> pAnd (pFin x) (pAnd (pFin y) (pGt x y))
x :==: y -> pAnd (pFin x) (pAnd (pFin y) (pEq x y))
p :&& q -> pAnd (propToProp' p) (propToProp' q)
p :|| q -> pOr (propToProp' p) (propToProp' q)
Not p -> pNot (propToProp' p)
PFalse -> pFalse
PTrue -> pTrue
instance (Eq a, HasVars a) => HasVars (I a) where
apSubst su prop =
case prop of
Impossible -> Nothing
Return _ -> Nothing
If a t e ->
case apSubstAtom su a of
Nothing -> do [x,y] <- branches
return (If a x y)
Just a' -> Just $ fromMaybe (pIf a' t e)
$ do [x,y] <- branches
return (pIf a' x y)
where branches = anyJust (apSubst su) [t,e]
apSubstAtom :: Subst -> Atom -> Maybe (I Bool)
apSubstAtom su atom =
case atom of
AFin x -> do e <- Map.lookup x su
return (pFin e)
AEq e1 e2 -> do [x,y] <- anyJust (apSubst su) [e1,e2]
return (pEq x y)
AGt e1 e2 -> do [x,y] <- anyJust (apSubst su) [e1,e2]
return (pGt x y)
ppAtom :: Atom -> Doc
ppAtom atom =
case atom of
AFin x -> text "fin" <+> ppName x
AGt x y -> ppExpr x <+> text ">" <+> ppExpr y
AEq x y -> ppExpr x <+> text "=" <+> ppExpr y
ppProp' :: I Bool -> Doc
ppProp' = ppIf ppAtom (text . show)
pFalse :: I Bool
pFalse = Return False
pTrue :: I Bool
pTrue = Return True
pNot :: I Bool -> I Bool
pNot p =
case p of
Impossible -> Impossible
Return a -> Return (not a)
If c t e -> If c (pNot t) (pNot e)
pAnd :: I Bool -> I Bool -> I Bool
pAnd p q = pIf p q pFalse
pOr :: I Bool -> I Bool -> I Bool
pOr p q = pIf p pTrue q
mkIf :: (Eq a, HasVars a) => Atom -> I a -> I a -> I a
mkIf a t e
| t == e = t
| otherwise = case a of
AFin x -> If a (pKnownFin x t) (pKnownInf x e)
_ | If b@(AFin y) _ _ <- t -> If b (mkFinIf y) (mkInfIf y)
| If b@(AFin y) _ _ <- e -> If b (mkFinIf y) (mkInfIf y)
AEq x' y'
| x == y -> t
| otherwise -> If (AEq x y) t e
where (x,y) = balanceEq x' y'
_ -> If a t e
where
mkFinIf y = mkIf a (pKnownFin y t) (pKnownFin y e)
mkInfIf y = case apSubstAtom (Map.singleton y (K Inf)) a of
Nothing -> mkIf a (pKnownInf y t) (pKnownInf y t)
Just a' -> pIf a' (pKnownInf y t) (pKnownInf y t)
pIf :: (Eq a, HasVars a) => I Bool -> I a -> I a -> I a
pIf c t e =
case c of
Impossible -> Impossible
Return True -> t
Return False -> e
If p t1 e1
| t2 == e2 -> t2
| otherwise -> mkIf p t2 e2
where
t2 = pIf t1 t e
e2 = pIf e1 t e
pAtom :: Atom -> I Bool
pAtom p = do a <- case p of
AFin _ -> return p
AEq x y -> bin AEq x y
AGt x y -> bin AGt x y
If a pTrue pFalse
where
prep x = do y <- eNoInf x
case y of
K Inf -> Impossible
_ -> return (crySimpExpr y)
bin f x y = liftM2 f (prep x) (prep y)
pEq :: Expr -> Expr -> I Bool
pEq x (K (Nat n)) = pIsNat n x
pEq (K (Nat n)) y = pIsNat n y
pEq x y = pIf (pInf x) (pInf y) (pAnd (pFin y) (pAtom (AEq x y)))
pGeq :: Expr -> Expr -> I Bool
pGeq x y = pIf (pInf x) pTrue (pAnd (pFin y) (pAtom (AGt (one :+ x) y)))
pGt :: Expr -> Expr -> I Bool
pGt x y = pNot (pGeq y x)
pFin :: Expr -> I Bool
pFin expr =
case expr of
K Inf -> pFalse
K (Nat _) -> pTrue
Var x -> pAtom (AFin x)
t1 :+ t2 -> pAnd (pFin t1) (pFin t2)
t1 :- _ -> pFin t1
t1 :* t2 -> pIf (pInf t1) (pEq t2 zero)
$ pIf (pInf t2) (pEq t1 zero)
$ pTrue
Div t1 _ -> pFin t1
Mod _ _ -> pTrue
t1 :^^ t2 -> pIf (pInf t1) (pEq t2 zero)
$ pIf (pInf t2) (pOr (pEq t1 zero) (pEq t1 one))
$ pTrue
Min t1 t2 -> pOr (pFin t1) (pFin t2)
Max t1 t2 -> pAnd (pFin t1) (pFin t2)
Width t1 -> pFin t1
LenFromThen _ _ _ -> pTrue
LenFromThenTo _ _ _ -> pTrue
pInf :: Expr -> I Bool
pInf = pNot . pFin
pIsNat :: Integer -> Expr -> I Bool
pIsNat n expr =
case expr of
K Inf -> pFalse
K (Nat m) -> if m == n then pTrue else pFalse
Var _ -> nothing
K (Nat m) :+ e2 -> if n < m then pFalse
else pIsNat (n m) e2
x :+ y
| n == 0 -> pAnd (pIsNat 0 x) (pIsNat 0 y)
| n == 1 -> pOr (pAnd (pIsNat 0 x) (pIsNat 1 y))
(pAnd (pIsNat 1 x) (pIsNat 0 y))
| otherwise -> nothing
e1 :- e2 -> pEq (K (Nat n) :+ e1) e2
K (Nat m) :* e2 ->
if m == 0
then if n == 0 then pTrue else pFalse
else case divMod n m of
(q,r) -> if r == 0 then pIsNat q e2
else pFalse
e1 :* e2
| n == 0 -> pOr (pIsNat 0 e1) (pIsNat 0 e2)
| n == 1 -> pAnd (pIsNat 1 e1) (pIsNat 1 e2)
| otherwise -> nothing
Div x y -> pAnd (pGt (one :+ x) (K (Nat n) :* y))
(pGt (K (Nat (n + 1)) :* y) x)
Mod _ _ -> nothing
K (Nat m) :^^ y -> case genLog n m of
Just (a, exact) | exact -> pIsNat a y
_ -> pFalse
x :^^ K (Nat m) -> case rootExact n m of
Just a -> pIsNat a x
Nothing -> pFalse
x :^^ y
| n == 0 -> pAnd (pIsNat 0 x) (pGt y zero)
| n == 1 -> pOr (pIsNat 1 x) (pIsNat 0 y)
| otherwise -> nothing
Min x y
| n == 0 -> pOr (pIsNat 0 x) (pIsNat 0 y)
| otherwise -> pOr (pAnd (pIsNat n x) (pGt y (K (Nat (n 1)))))
(pAnd (pIsNat n y) (pGt x (K (Nat (n 1)))))
Max x y -> pOr (pAnd (pIsNat n x) (pGt (K (Nat (n + 1))) y))
(pAnd (pIsNat n y) (pGt (K (Nat (n + 1))) y))
Width x
| n == 0 -> pIsNat 0 x
| otherwise -> pAnd (pGt x (K (Nat (2^(n1) 1))))
(pGt (K (Nat (2 ^ n))) x)
x ->
panic "Cryptol.TypeCheck.Solver.Numeric.Simplify1.pIsNat"
[ "unexpected expression ", show x ]
where
nothing = pAnd (pFin expr) (pAtom (AEq expr (K (Nat n))))
_pIsGtThanNat :: Integer -> Expr -> I Bool
_pIsGtThanNat = undefined
_pNatIsGtThan :: Integer -> Expr -> I Bool
_pNatIsGtThan = undefined
pKnownFin :: (HasVars a, Eq a) => Name -> I a -> I a
pKnownFin x prop =
case prop of
If (AFin y) t _
| x == y -> pKnownFin x t
If p t e -> mkIf p (pKnownFin x t) (pKnownFin x e)
_ -> prop
pKnownInf :: (Eq a, HasVars a) => Name -> I a -> I a
pKnownInf x prop = fromMaybe prop (apSubst (Map.singleton x (K Inf)) prop)
balanceEq :: Expr -> Expr -> (Expr, Expr)
balanceEq (K (Nat a) :+ e1) (K (Nat b) :+ e2)
| a >= b = balanceEq e2 (K (Nat (a b)) :+ e1)
| otherwise = balanceEq e1 (K (Nat (b a)) :+ e2)
balanceEq e1 e2
| not (null neg_es1 && null neg_es2) = balanceEq (mk neg_es2 pos_es1)
(mk neg_es1 pos_es2)
where
(pos_es1, neg_es1) = partitionEithers (map classify (splitSum e1))
(pos_es2, neg_es2) = partitionEithers (map classify (splitSum e2))
classify (sign,e) = case sign of
Pos -> Left e
Neg -> Right e
mk as bs = normSum (foldr (:+) zero (as ++ bs))
balanceEq x y = (x,y)
eNoInf :: Expr -> I Expr
eNoInf expr =
case expr of
x :* y ->
do x' <- eNoInf x
y' <- eNoInf y
case (x', y') of
(K Inf, K Inf) -> return inf
(K Inf, _) -> pIf (pEq y' zero) (return zero) (return inf)
(_, K Inf) -> pIf (pEq x' zero) (return zero) (return inf)
_ -> return (x' :* y')
x :^^ y ->
do x' <- eNoInf x
y' <- eNoInf y
case (x', y') of
(K Inf, K Inf) -> return inf
(K Inf, _) -> pIf (pEq y' zero) (return one) (return inf)
(_, K Inf) -> pIf (pEq x' zero) (return zero)
$ pIf (pEq x' one) (return one)
$ return inf
_ -> return (x' :^^ y')
K _ -> return expr
Var _ -> return expr
x :+ y ->
do x' <- eNoInf x
y' <- eNoInf y
case (x', y') of
(K Inf, _) -> return inf
(_, K Inf) -> return inf
_ -> return (x' :+ y')
x :- y ->
do x' <- eNoInf x
y' <- eNoInf y
case (x', y') of
(_, K Inf) -> Impossible
(K Inf, _) -> return inf
_ -> return (x' :- y')
Div x y ->
do x' <- eNoInf x
y' <- eNoInf y
case (x', y') of
(K Inf, _) -> Impossible
(_, K Inf) -> return zero
_ -> return (Div x' y')
Mod x y ->
do x' <- eNoInf x
pIf (pFin y)
(do y' <- eNoInf y
case (x',y') of
(K Inf, _) -> Impossible
(_, K Inf) -> Impossible
_ -> return (Mod x' y')
)
(return x')
Min x y ->
do x' <- eNoInf x
y' <- eNoInf y
case (x',y') of
(K Inf, _) -> return y'
(_, K Inf) -> return x'
_ -> return (Min x' y')
Max x y ->
do x' <- eNoInf x
y' <- eNoInf y
case (x', y') of
(K Inf, _) -> return inf
(_, K Inf) -> return inf
_ -> return (Max x' y')
Width x ->
do x' <- eNoInf x
case x' of
K Inf -> return inf
_ -> return (Width x')
LenFromThen x y w -> fun3 LenFromThen x y w
LenFromThenTo x y z -> fun3 LenFromThenTo x y z
where
fun3 f x y z =
do x' <- eNoInf x
y' <- eNoInf y
z' <- eNoInf z
case (x',y',z') of
(K Inf, _, _) -> Impossible
(_, K Inf, _) -> Impossible
(_, _, K Inf) -> Impossible
_ -> return (f x' y' z')