{- CAO Compiler
Copyright (C) 2014 Cryptography and Information Security Group, HASLab - INESC TEC and Universidade do Minho
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>. -}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{- |
Module : $Header$
Description : Evaluation of index language
Copyright : (C) 2014 Cryptography and Information Security Group, HASLab - INESC TEC and Universidade do Minho
License : GPL
Maintainer : Paulo Silva
Stability : experimental
Portability : non-portable ()
--}
module Language.CAO.Index.Eval
( evalCond
, evalExpr
) where
import Data.List
import Language.CAO.Index
import Language.CAO.Semantics.Integer
import Language.CAO.Semantics.Bool
--------------------------------------------------------------------------------
-- This implements the properties of the several boolean operators,
-- except conjunction. It does not use equality on expressions, only
-- on variables.
--
-- Note that @x `xor` y@ for two /distinct/ boolean variables is left
-- unsimplified: its value depends on the values of the variables.
truthTable :: Eq id => ICond id -> ICond id
truthTable (IBoolOp op (IBool b1) (IBool b2))
  = IBool $ mapIBOp op b1 b2
-- Or with a literal: True absorbs, False is the identity.
truthTable (IBoolOp IOr (IBool b1) b2)
  | b1 = IBool True
  | otherwise = truthTable b2
truthTable (IBoolOp IOr b1 (IBool b2))
  | b2 = IBool True
  | otherwise = truthTable b1
-- Xor with a literal: True negates the other operand, False is the identity.
truthTable (IBoolOp IXor (IBool b1) b2)
  | b1 = truthTable $ deMorgan b2
  | otherwise = truthTable b2
truthTable (IBoolOp IXor b1 (IBool b2))
  | b2 = truthTable $ deMorgan b1
  | otherwise = truthTable b1
-- x xor x == False. Distinct variables fall through and are returned
-- unchanged by the last equation (the guard below also fails for them).
truthTable (IBoolOp IXor (IBInd b1) (IBInd b2))
  | b1 == b2 = IBool False
-- Idempotence: x or x == x (equal variables under xor were handled above).
truthTable (IBoolOp _ (IBInd b1) (IBInd b2))
  | b1 == b2 = IBInd b1
truthTable e = e
--------------------------------------------------------------------------------
-- Application of deMorgan rules: @deMorgan c@ builds a condition equivalent
-- to the negation of @c@. The negation is pushed inwards where a rule
-- applies; any other condition is simply wrapped in 'INot'.
deMorgan :: ICond id -> ICond id
deMorgan (IBool b) = IBool $ not b
deMorgan (INot c) = c
-- not (a or b) == (not a) and (not b)
deMorgan (IBoolOp IOr c1 c2) = IAnd [deMorgan c1, deMorgan c2]
-- not (a xor b) == (not a) xor b == a xor (not b). Prefer negating a literal
-- or removing an existing 'INot'; otherwise negate a variable with 'INot',
-- or recurse into the left operand.
deMorgan (IBoolOp IXor c1 c2) = aux c1 c2
  where
    aux (IBool b) e = IBoolOp IXor (IBool $ not b) e
    aux e (IBool b) = IBoolOp IXor e (IBool $ not b)
    aux (INot b) e = IBoolOp IXor b e
    aux e (INot b) = IBoolOp IXor e b
    aux i@(IBInd _) e = IBoolOp IXor (INot i) e
    aux e i@(IBInd _) = IBoolOp IXor e (INot i)
    aux e1 e2 = IBoolOp IXor (deMorgan e1) e2
-- not (c1 and ... and cn) == (not c1) or ... or (not cn)
deMorgan (IAnd lc) = andToOr $ map deMorgan lc
  where
    -- The empty conjunction is True, so its negation is False.
    andToOr [] = IBool False
    andToOr [x] = x
    andToOr (x:xs) = IBoolOp IOr x (andToOr xs)
deMorgan i = INot i
--------------------------------------------------------------------------------
-- | Simplify a condition.
--
-- Negations are pushed inwards ('deMorgan'), conjunctions are flattened
-- ('flatAnd'), comparisons whose expression evaluates to an integer literal
-- are decided, and binary boolean operators are simplified with
-- 'truthTable'.
--
-- A disjunction with a conjunction operand is distributed into a
-- conjunction of disjunctions. Exclusive-or does NOT distribute over
-- conjunction, so xor nodes are only simplified, never distributed.
evalCond :: (Ord id, Eq id) => ICond id -> ICond id
evalCond (INot c) = case deMorgan (evalCond c) of
  IAnd l -> flatAnd l
  c' -> truthTable c'
evalCond (IAnd l) = flatAnd $ map evalCond l
-- Canonical form for nested expressions:
-- (a and b) or c == (a or c) and (b or c)
evalCond (IBoolOp op c1 c2) = case (op, evalCond c1, evalCond c2) of
  (IOr, l1@(IAnd _), l2) -> flatAnd $ distrOr l1 l2
  (IOr, l1, l2@(IAnd _)) -> flatAnd $ distrOr l1 l2
  (_, l1, l2) -> truthTable $ IBoolOp op l1 l2
-- ILeq e / IEq e are decided against 0 once e evaluates to a literal.
evalCond (ILeq e) = case evalExpr e of
  IInt i -> IBool $ 0 <= i
  e' -> ILeq e'
evalCond (IEq e) = case evalExpr e of
  IInt i -> IBool $ 0 == i
  e' -> IEq e'
evalCond c = c
--------------------------------------------------------------------------------
-- | Distribute a disjunction over conjunction(s), returning the list of
-- resulting conjuncts: @(a1 and ... and an) or c@ yields
-- @[c or a1, ..., c or an]@. When both sides are conjunctions, every pair of
-- conjuncts is combined. At least one argument must be an 'IAnd'.
distrOr :: ICond id -> ICond id -> [ICond id]
distrOr lhs rhs = case (lhs, rhs) of
  (IAnd ls, IAnd rs) -> [ IBoolOp IOr r l | r <- rs, l <- ls ]
  (c, IAnd rs) -> distrOr' rs c
  (IAnd ls, c) -> distrOr' ls c
  _ -> error ": not expected"

-- | @distrOr' cs c@ builds @c or x@ (with @c@ on the left) for each @x@ in @cs@.
distrOr' :: [ICond id] -> ICond id -> [ICond id]
distrOr' conjuncts c = [ IBoolOp IOr c x | x <- conjuncts ]
--------------------------------------------------------------------------------
-- | Flatten a list of conjuncts into a single conjunction:
--
--   * @IBool True@ conjuncts are removed;
--   * any conjunct that is, or simplifies to, @IBool False@ reduces the whole
--     result to @IBool False@;
--   * nested 'IAnd's are brought out to the top level;
--   * boolean variables ('IBInd') are collected first and de-duplicated.
--
-- Returns @IBool True@ when no conjunct remains.
flatAnd :: Eq id => [ICond id] -> ICond id
flatAnd conds
  | ok && not (null vars && null rest) = IAnd (nub vars ++ rest)
  | otherwise = IBool ok
  where
    (ok, vars, rest) = foldr step (True, [], []) conds
    contradiction = (False, [], [])
    -- Literals do not inspect the accumulator, so a False short-circuits
    -- the remainder of the fold.
    step (IBool b) acc = if b then acc else contradiction
    step (IAnd inner) (v, vs, rs) = case flatAnd inner of
      IAnd flat -> (v, vs, flat ++ rs)
      IBool True -> (v, vs, rs)
      IBool False -> contradiction
      _ -> error "flatAnd.aux: Not expected case"
    step ind@(IBInd _) (v, vs, rs) = (v, ind : vs, rs)
    step other (v, vs, rs) = case truthTable other of
      IBool True -> (v, vs, rs)
      IBool False -> contradiction
      simplified -> (v, vs, simplified : rs)
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
-- Partially inspired by
-- Producing Proofs from an Arithmetic Decision Procedure in Elliptical LF
-- Aaron Stump, Clark W. Barrett, and David L. Dill
{-
Flat form:
1. All sums are non-empty
2. The first element of a sum is always a constant
3. Variables are never alone. They are always part of a product 1 * v
4. Negation (ISym, the symmetric) is pushed down to values and variables
5. There are no nested sums
6. The outer symbol of a flatExpr is always an ISum
7. Operations on literals are always computed
-}
--------------------------------------------------------------------------------
-- | Normalise an index expression: flatten it with 'flatExpr', then sort the
-- resulting flat sum and merge like terms with 'canonicalExpr'.
evalExpr :: (Eq id, Ord id) => IExpr id -> IExpr id
evalExpr e = canonicalExpr (flatExpr e)
-- | Convert an index expression into the flat form described above.
--
-- Every equation returns an 'ISum'. Operations whose operands are literals
-- are computed immediately (via 'resInt'). Equations are tried top to bottom,
-- so the literal-only shortcuts must stay before the general equation for the
-- same operator. Shapes with no equation (e.g. 'IModOp' on non-literal
-- operands) reach the final 'error'.
flatExpr :: Eq id => IExpr id -> IExpr id
flatExpr i@(IInt _) = ISum [i]
-- A lone variable v becomes 0 + 1 * v (flat-form rules 2 and 3).
flatExpr i@(IInd _) = ISum [IInt 0, IArith ITimes (IInt 1) i]
flatExpr (ISym (IInt n)) = resInt $ negate n -- shortcut
-- Negation is pushed through the already-flattened operand.
flatExpr (ISym t) = distrSym $ flatExpr t
flatExpr (ISum l) = ISum $ flatSum $ map flatExpr l
flatExpr (IArith IMinus (IInt n) (IInt n')) = resInt $ n - n'
-- v - v cancels; for different variables the guard fails and the general
-- subtraction equation below applies.
flatExpr (IArith IMinus (IInd i) (IInd i'))
  | i == i' = resInt 0
-- a - b is flattened as the sum a + (-b).
flatExpr (IArith IMinus i1 i2) =
  ISum $ flatSum [flatExpr i1, distrSym $ flatExpr i2]
flatExpr (IArith ITimes (IInt n) (IInt n')) = resInt $ n * n'
flatExpr e@(IArith ITimes _ _) = flatTimes e
-- NOTE(review): uses Prelude 'div' here while 'listDiv' uses 'mapIAOp IDiv';
-- a zero divisor makes 'div' throw.
flatExpr (IArith IDiv (IInt n) (IInt n')) = resInt $ div n n'
-- The lazy 'ISum' patterns in the let rely on 'flatExpr' always returning an
-- 'ISum'. NOTE(review): when 'listDiv' returns a symbolic quotient, the
-- resulting sum does not start with a constant (flat-form rule 2).
flatExpr (IArith IDiv p q) = let
    ISum p' = flatExpr p
    ISum q' = flatExpr q
  in ISum [listDiv p' q']
flatExpr (IArith IPower (IInt b) (IInt e)) = resInt $ b^e
-- A symbolic power is kept as an opaque term 1 * (b' ^ e'), with both
-- operands flattened.
flatExpr (IArith IPower b e) =
  ISum [IInt 0, IArith ITimes (IInt 1) (IArith IPower (flatExpr b) (flatExpr e))]
flatExpr (IArith IModOp (IInt a) (IInt b)) = resInt $ mod a b
flatExpr _ = error "TODO: flatExpr: "
-- | Wrap an integer literal as a single-element (flat) sum.
resInt :: Integer -> IExpr id
resInt = ISum . (: []) . IInt
-- | Flatten a product node.
--
-- The factors are split by 'sepTimes' into integer constants (ci), sum-like
-- factors (si) and bare variables (mi). Constants and variables are packed
-- into one monomial mm ('toMult'). The sum-like factors are flattened and
-- multiplied out by 'iTimesConcat' into a constant part (sc) and non-constant
-- terms (ssi). Finally mm is multiplied into both parts with 'iTimesLst', and
-- every literal product is added into the leading constant of the result.
flatTimes :: Eq id => IExpr id -> IExpr id
flatTimes e@(IArith ITimes _ _) = let
    (ci, si, mi) = sepTimes e
    mm = toMult (product ci) mi
    (sc, ssi) = iTimesConcat $ map flatExpr si
    (pc, pe) = iTimesLst [mm] ssi
    (pc', pe') = iTimesLst [mm] sc
    sumCi = constSum $ pc ++ pc'
  in ISum $ sumCi : pe ++ pe'
-- Only product nodes are accepted ('flatExpr' dispatches ITimes here).
flatTimes _ = error ": not expected"
-- | @toMult n vs@ builds the monomial @n * (v1 * (v2 * ... vk))@. With no
-- variables the result is just the literal @n@.
toMult :: Integer -> [IExpr id] -> IExpr id
toMult n factors = case factors of
  [] -> IInt n
  _ -> IArith ITimes (IInt n) (foldr1 (IArith ITimes) factors)
-- | Split a product tree into its factors:
-- (integer constants, sum-like factors still to be flattened and multiplied
-- out, bare index variables).
--
-- Subtractions and negations are flattened first, so they end up as sums
-- with their sign preserved. Divisions are kept as sum-like factors.
sepTimes :: Eq id => IExpr id -> ([Integer], [IExpr id], [IExpr id])
sepTimes (IInt n) = ([n], [], [])
sepTimes i@(IInd _) = ([], [], [i])
sepTimes s@(ISum _) = ([], [s], [])
sepTimes (IArith ITimes i1 i2) = let
    (ci1, si1, mi1) = sepTimes i1
    (ci2, si2, mi2) = sepTimes i2
  in (ci1 ++ ci2, si1 ++ si2, mi1 ++ mi2)
sepTimes (IArith IPower (IInt n) (IInt e)) = ([n ^ e], [], [])
sepTimes e@(IArith IMinus _ _) = sepTimes (flatExpr e)
-- Flatten the whole negation, not just its operand, so the sign is kept.
sepTimes s@(ISym _) = sepTimes (flatExpr s)
sepTimes s@(IArith IDiv _ _) = ([], [s], [])
sepTimes _ = error "sepTimes: not expected"

-- | Add up a list of integer literals into a single literal.
constSum :: [IExpr id] -> IExpr id
constSum = constOp ((+), 0)
-- | Right-fold a list of integer literals with the given operator, starting
-- from the given initial value. Any non-literal element is an error.
constOp :: (Integer -> Integer -> Integer, Integer) -> [IExpr id] -> IExpr id
constOp (op, z) = foldr step (IInt z)
  where
    step x acc = case (x, acc) of
      (IInt m, IInt res) -> IInt (op m res)
      _ -> error ": not expected"
-- | Multiply out a list of flat sums.
--
-- Returns (constant part, non-constant terms) of the product. The constant
-- part is always a singleton list. Each input must be an 'ISum' whose first
-- element is its constant (flat-form rule 2); 'head'/'tail' fail on an
-- empty sum, and any non-'ISum' element hits the final 'error'.
iTimesConcat :: [IExpr id] -> ([IExpr id], [IExpr id])
-- No sum factors (only reached from a product without sums): constant part 0
-- and a single unit term 1, so the caller's monomial is multiplied by 1.
iTimesConcat [] = ([IInt 0],[IInt 1])
iTimesConcat [ISum x] = ([head x], tail x)
iTimesConcat (ISum x:xs) = let
    (cs, xs') = iTimesConcat xs
    -- every term of x times the constant part of the remaining product ...
    (c, i) = iTimesLst x cs
    -- ... and times its non-constant terms
    (c', i') = iTimesLst x xs'
  in ([constSum $ c ++ c'], i ++ i')
iTimesConcat (_:_) = error ": not expected"
-- | Multiply every term of the first list by every term of the second with
-- 'iTimes'.
--
-- Returns (constants, terms). All literal products are summed into one
-- literal (a singleton list; the list is empty only when the first argument
-- is empty). The non-literal products are grouped by the element of the
-- first list they came from, with the groups in reverse order (last element
-- first).
iTimesLst :: [IExpr id] -> [IExpr id] -> ([IExpr id], [IExpr id])
iTimesLst [] _ = ([], [])
iTimesLst xs ys = ([constSum lits], concat (reverse groups))
  where
    products = [ unzip (map (iTimes x) ys) | x <- xs ]
    lits = concatMap (concat . fst) products
    groups = map (concat . snd) products
-- | Multiply two flat terms.
--
-- Returns (constants, terms). A product of two literals goes into the first
-- component; every other product goes into the second. A variable or a
-- coefficient-headed product multiplied by the literal 0 yields nothing.
-- Equations are tried in order; unsupported combinations (e.g. two bare
-- variables) reach the final 'error'.
iTimes :: IExpr id -> IExpr id -> ([IExpr id], [IExpr id])
-- Constant * Constant
iTimes (IInt n) (IInt n') = ([IInt $ n * n'], [])
-- Constant * Variable
iTimes (IInt 0) (IInd _) = ([], [])
iTimes (IInt n) (IInd i) = ([], [IArith ITimes (IInt n) (IInd i)])
iTimes (IInd _) (IInt 0) = ([], [])
iTimes (IInd i) (IInt n) = ([], [IArith ITimes (IInt n) (IInd i)])
-- Constant * Product: the constant is folded into the coefficient
iTimes (IInt 0) (IArith ITimes (IInt _) _) = ([], [])
iTimes (IInt n) (IArith ITimes (IInt c) i) = ([], [IArith ITimes (IInt $ c * n) i])
iTimes (IArith ITimes (IInt _) _) (IInt 0) = ([], [])
iTimes (IArith ITimes (IInt c) i) (IInt n) = ([], [IArith ITimes (IInt $ c * n) i])
-- Variable * Product: the variable joins the non-coefficient factors
iTimes (IInd i) (IArith ITimes (IInt c) i') = ([], [IArith ITimes (IInt c) (IArith ITimes (IInd i) i')])
iTimes (IArith ITimes (IInt c) i') (IInd i) = ([], [IArith ITimes (IInt c) (IArith ITimes (IInd i) i')])
-- Product * Product
iTimes (IArith ITimes (IInt c) i) (IArith ITimes (IInt c') i') =
  ([], [IArith ITimes (IInt $ c * c') (IArith ITimes i i')]) -- TODO: Not in the right form
-- Product without a literal coefficient: wrapped as 1 * (...) so the result
-- still starts with a coefficient.
iTimes (IArith ITimes i1 i2) e2@(IArith ITimes _ _) =
  ([], [IArith ITimes (IInt 1) $ IArith ITimes i1 (IArith ITimes i2 e2)])
-- Division * Constant and Constant * Division (kept as an explicit product)
iTimes l@(IArith IDiv _ _) r@(IInt _) = ([], [IArith ITimes l r])
iTimes l@(IInt _) r@(IArith IDiv _ _) = ([], [IArith ITimes l r])
-- Error
iTimes _ _ = error ": not expected"
--------------------------------------------------------------------------------
-- | Negate a flat expression by pushing the negation down to literals,
-- product coefficients and quotient operands. Expects a flat expression;
-- powers, modulo and bare index variables are errors.
distrSym :: IExpr id -> IExpr id
distrSym e = case e of
  IInt n -> IInt (negate n)
  ISym i -> i
  ISum l -> ISum $ map distrSym l -- always the entry point
  IArith ITimes (IInt c) i -> IArith ITimes (IInt (negate c)) i
  -- Negate only the left factor; the right factor is kept unchanged.
  IArith ITimes c i -> IArith ITimes (distrSym c) i
  IArith IDiv (IInt c) i -> IArith IDiv (IInt (negate c)) i
  IArith IDiv c (IInt i) -> IArith IDiv c (IInt (negate i))
  IArith IDiv c i -> IArith IDiv (distrSym c) i
  IArith IModOp _ _ -> error ": <> mod"
  IArith IPower _ _ -> error ": should never reach a power"
  IInd _ -> error ": should never reach a index variable"
  _ -> error "<>: missing case: "
--------------------------------------------------------------------------------
-- | Divide two flat term lists.
--
-- Two single literals are divided with @mapIAOp IDiv@. Two single atoms
-- (literals or variables) become a direct 'IDiv' node. Anything else divides
-- the two lists re-wrapped as sums.
listDiv :: [IExpr id] -> [IExpr id] -> IExpr id
listDiv [IInt n] [IInt d] = IInt $ mapIAOp IDiv n d
listDiv [num] [den]
  | isAtom num && isAtom den = IArith IDiv num den
  where
    isAtom (IInt _) = True
    isAtom (IInd _) = True
    isAtom _ = False
listDiv nums dens = IArith IDiv (ISum nums) (ISum dens)
--------------------------------------------------------------------------------
-- | Merge already-flattened summands into one flat term list.
--
-- The leading constant of each nested sum is added into a single literal
-- placed first. The remaining terms of each nested sum are spliced in order,
-- and non-sum summands are kept as single terms.
--
-- NOTE: only one level of nesting is spliced; this may not be enough to
-- bring deeper sums to the top level.
flatSum :: [IExpr id] -> [IExpr id]
flatSum summands = IInt (sum lits) : concat pieces
  where
    (lits, pieces) = foldr absorb ([], []) summands
    -- Lazy patterns keep the fold as lazy as explicit recursion with 'let'.
    absorb (ISum (IInt n : rest)) ~(ns, ps) = (n : ns, rest : ps)
    absorb (ISum rest) ~(ns, ps) = (ns, rest : ps)
    absorb other ~(ns, ps) = (ns, [other] : ps)
--------------------------------------------------------------------------------
-- | Ordering used to sort the terms of a flat sum.
--
-- Literals come first. Linear terms @c * v@ are ordered by their variable,
-- other arithmetic terms by operator ('cmpIAOp'), nested sums precede
-- arithmetic terms, and two sums are compared lexicographically.
--
-- NOTE(review): this is not a lawful total order. For example, two literals
-- compare 'LT' in both directions.
cmp :: Ord id => IExpr id -> IExpr id -> Ordering
cmp lhs rhs = case (lhs, rhs) of
  (IInt _, _) -> LT
  (_, IInt _) -> GT
  (IArith ITimes (IInt _) (IInd v), IArith ITimes (IInt _) (IInd w)) ->
    compare v w
  (IArith op _ _, IArith op' _ _) -> cmpIAOp op op'
  (ISum _, IArith {}) -> LT
  (IArith {}, ISum _) -> GT
  (ISum xs, ISum ys) -> cmpList xs ys
  _ -> error "Ordering: not expected: "
-- | Lexicographic comparison of two term lists using 'cmp'; a proper prefix
-- sorts before the longer list.
cmpList :: Ord id => [IExpr id] -> [IExpr id] -> Ordering
cmpList (x : xs) (y : ys)
  | headOrd == EQ = cmpList xs ys
  | otherwise = headOrd
  where
    headOrd = cmp x y
-- At least one list is exhausted: the non-empty one is greater (False < True).
cmpList xs ys = compare (not (null xs)) (not (null ys))
-- | Rank arithmetic operators when ordering terms: products first, then
-- divisions, then powers, then modulo. The first argument wins ties (so
-- equal operators never compare 'EQ'). Pairs not covered (e.g. two
-- 'IMinus') are an error.
cmpIAOp :: IAOp -> IAOp -> Ordering
cmpIAOp a b
  | ITimes <- a = LT
  | ITimes <- b = GT
  | IDiv <- a = LT
  | IDiv <- b = GT
  | IPower <- a = LT
  | IPower <- b = GT
  | IModOp <- a = GT
  | otherwise = error ": not expected"
-- TODO: non-linear coefficients may need this as well
-- | Bring a flat sum into canonical form.
--
-- The terms are sorted with 'cmp', and adjacent linear terms on the same
-- variable are merged (dropping them when the coefficients cancel). The
-- result is then simplified: a lone literal or a lone @0 + 1 * v@ is
-- unwrapped, a leading 0 is dropped, @1 * e@ becomes @e@, and @0 * e@ is
-- removed. Non-sum expressions are returned unchanged.
canonicalExpr :: Ord id => IExpr id -> IExpr id
canonicalExpr (ISum terms) = finish $ ISum $ mergeLike $ sortBy cmp terms
  where
    mergeLike :: Eq id => [IExpr id] -> [IExpr id]
    mergeLike (t1@(IArith ITimes (IInt a) (IInd v)) : t2@(IArith ITimes (IInt b) (IInd w)) : rest)
      | v == w && coeff /= 0 = mergeLike (IArith ITimes (IInt coeff) (IInd v) : rest)
      | v == w = mergeLike rest
      | otherwise = t1 : mergeLike (t2 : rest)
      where
        coeff = a + b
    mergeLike (t : rest) = t : mergeLike rest
    mergeLike [] = []

    finish :: IExpr id -> IExpr id
    finish (ISum [lit@(IInt _)]) = lit
    finish (ISum [IInt 0, IArith ITimes (IInt 1) v]) = v
    finish (ISum ts) = ISum [ t' | t <- dropZero ts, t' <- unitCoeff t ]
    finish other = other

    dropZero (IInt 0 : ts) = ts
    dropZero ts = ts

    unitCoeff (IArith ITimes (IInt 1) t) = [t]
    unitCoeff (IArith ITimes (IInt 0) _) = []
    unitCoeff t = [t]
canonicalExpr e = e
--------------------------------------------------------------------------------
-- | Boolean semantics of the binary boolean operators.
mapIBOp :: IBOp -> Bool -> Bool -> Bool
mapIBOp op = case op of
  IOr -> boolOr
  IXor -> boolXor
--------------------------------------------------------------------------------
-- | Integer semantics of the binary arithmetic operators.
mapIAOp :: IAOp -> Integer -> Integer -> Integer
mapIAOp op = case op of
  IMinus -> integerMinus
  ITimes -> integerTimes
  IPower -> integerPower
  IDiv -> integerDiv
  IModOp -> integerMod