{- CAO Compiler Copyright (C) 2014 Cryptography and Information Security Group, HASLab - INESC TEC and Universidade do Minho This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see <http://www.gnu.org/licenses/>. -} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveTraversable #-} {- | Module : $Header$ Description : Evaluation of index language Copyright : (C) 2014 Cryptography and Information Security Group, HASLab - INESC TEC and Universidade do Minho License : GPL Maintainer : Paulo Silva Stability : experimental Portability : non-portable () --} module Language.CAO.Index.Eval ( evalCond , evalExpr ) where import Data.List import Language.CAO.Index import Language.CAO.Semantics.Integer import Language.CAO.Semantics.Bool -------------------------------------------------------------------------------- -- This implements the properties of the several boolean operators, -- except conjunction. It does not use equality on expressions, only -- on variables. 
-- | Simplify a single boolean node using the truth tables of disjunction
-- and exclusive-or: constant folding, absorption by literals, and
-- idempotence / self-cancellation on identical variables. Only the node
-- itself (and the operand selected by a literal) is simplified; operands
-- are expected to be already evaluated.
truthTable :: Eq id => ICond id -> ICond id
-- Both operands are literals: compute the result directly.
truthTable (IBoolOp op (IBool b1) (IBool b2)) = IBool $ mapIBOp op b1 b2
-- True || x = True ; False || x = x
truthTable (IBoolOp IOr (IBool b1) b2)
  | b1 = IBool True
  | otherwise = truthTable b2
truthTable (IBoolOp IOr b1 (IBool b2))
  | b2 = IBool True
  | otherwise = truthTable b1
-- True xor x = not x ; False xor x = x
truthTable (IBoolOp IXor (IBool b1) b2)
  | b1 = truthTable $ deMorgan b2
  | otherwise = truthTable b2
truthTable (IBoolOp IXor b1 (IBool b2))
  | b2 = truthTable $ deMorgan b1
  | otherwise = truthTable b1
-- x xor x = False. Two distinct variables are independent, so their
-- exclusive-or cannot be decided here: the guard fails and the node falls
-- through unchanged (folding it to True would be unsound, e.g. x = y = True).
truthTable (IBoolOp IXor (IBInd b1) (IBInd b2))
  | b1 == b2 = IBool False
-- Idempotence: x || x = x
truthTable (IBoolOp _ (IBInd b1) (IBInd b2))
  | b1 == b2 = IBInd b1
truthTable e = e

--------------------------------------------------------------------------------
-- Application of deMorgan rules

-- | Negate a condition, pushing the negation inwards with De Morgan's laws
-- where possible and wrapping the condition in 'INot' otherwise.
deMorgan :: ICond id -> ICond id
deMorgan (IBool b) = IBool $ not b
deMorgan (INot c) = c
-- not (a || b) = not a && not b
deMorgan (IBoolOp IOr c1 c2) = IAnd [deMorgan c1, deMorgan c2]
-- not (a xor b) = (not a) xor b = a xor (not b): negate whichever operand
-- can absorb the negation most cheaply.
deMorgan (IBoolOp IXor c1 c2) = negXor c1 c2
  where
    negXor (IBool b) e = IBoolOp IXor (IBool $ not b) e
    negXor e (IBool b) = IBoolOp IXor e (IBool $ not b)
    negXor (INot b) e = IBoolOp IXor b e
    negXor e (INot b) = IBoolOp IXor e b
    negXor i@(IBInd _) e = IBoolOp IXor (INot i) e
    negXor e i@(IBInd _) = IBoolOp IXor e (INot i)
    negXor e1 e2 = IBoolOp IXor (deMorgan e1) e2
-- The empty conjunction is True, so its negation is False.
deMorgan (IAnd []) = IBool False
-- not (a && b && ...) = not a || (not b || ...)
deMorgan (IAnd lc) = foldr1 (IBoolOp IOr) $ map deMorgan lc
deMorgan i = INot i

--------------------------------------------------------------------------------

-- | Simplify an index condition: evaluate its arithmetic sub-expressions,
-- fold constants, flatten nested conjunctions and distribute disjunctions
-- over conjunctions so that conjunctions float to the top.
evalCond :: (Ord id, Eq id) => ICond id -> ICond id
evalCond (INot c) = case deMorgan (evalCond c) of
  IAnd l -> flatAnd l
  c' -> truthTable c'
evalCond (IAnd l) = flatAnd $ map evalCond l
-- Canonical form for nested expressions: (a && b) || c = (a || c) && (b || c).
-- This law only holds for disjunction; other operators are not distributed.
evalCond (IBoolOp IOr c1 c2) = case (evalCond c1, evalCond c2) of
  (l1@(IAnd _), l2) -> flatAnd $ distrOr l1 l2
  (l1, l2@(IAnd _)) -> flatAnd $ distrOr l1 l2
  (l1, l2) -> truthTable $ IBoolOp IOr l1 l2
evalCond (IBoolOp op c1 c2) =
  truthTable $ IBoolOp op (evalCond c1) (evalCond c2)
-- When the expression reduces to a constant i, ILeq is decided as 0 <= i
-- and IEq as 0 == i.
evalCond (ILeq e) = case evalExpr e of
  IInt i -> IBool $ 0 <= i
  e' -> ILeq e'
evalCond (IEq e) = case evalExpr e of
  IInt i -> IBool $ 0 == i
  e' -> IEq e'
evalCond c = c

--------------------------------------------------------------------------------

-- | Distribute a disjunction over the conjunction(s) among its operands and
-- return the conjuncts of the result. At least one operand must be an 'IAnd'.
distrOr :: ICond id -> ICond id -> [ICond id]
distrOr (IAnd l1) (IAnd l2) = concatMap (distrOr' l1) l2
distrOr c (IAnd l2) = distrOr' l2 c
distrOr (IAnd l1) c = distrOr' l1 c
distrOr _ _ = error "distrOr: not expected"

-- | @distrOr' l c@ builds the disjunction @c || x@ for every @x@ in @l@.
distrOr' :: [ICond id] -> ICond id -> [ICond id]
distrOr' l c = map (IBoolOp IOr c) l

--------------------------------------------------------------------------------

-- | Build a conjunction from a list of conditions:
--
--   * @True@ conjuncts are removed;
--   * any conjunct that reduces to @False@ makes the result @False@;
--   * nested conjunctions are brought out to the top level;
--   * duplicated boolean variables ('IBInd') are removed.
--
-- Returns @IBool True@ when no conjunct is left.
flatAnd :: Eq id => [ICond id] -> ICond id
flatAnd c =
  let (v, var, i) = foldr aux (True, [], []) c
  in if v && not (null i && null var) then IAnd (nub var ++ i) else IBool v
  where
    aux (IBool False) _ = (False, [], [])
    aux (IBool True) r = r
    aux (IAnd l) (v, vs, r) = case flatAnd l of
      IAnd l' -> (v, vs, l' ++ r)
      IBool False -> (False, [], [])
      IBool True -> (v, vs, r)
      _ -> error "flatAnd.aux: Not expected case"
    -- Variables are kept apart so duplicates can be removed with 'nub'.
    aux i@(IBInd _) (v, vs, r) = (v, i : vs, r)
    aux x (v, vs, r) = case truthTable x of
      IBool False -> (False, [], [])
      IBool True -> (v, vs, r)
      x' -> (v, vs, x' : r)

--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
-- Partially inspired by
-- Producing Proofs from an Arithmetic Decision Procedure in Elliptical LF
-- Aaron Stump, Clark W. Barrett, and David L. Dill

{-
Flat form:
1. All sums are non-empty
2. The first element of a sum is always a constant
3. Variables are never alone. They are always part of a product 1 * v
4. Symmetric is moved downwards to values and variables
5. There are not nested sums
6. The outer symbol of a flatExpr is always a ISum
7. Operations on literals are always computed
-}

--------------------------------------------------------------------------------

-- | Simplify an index expression: flatten it into a sum of products (see
-- the flat-form invariants above), then bring it to a canonical form.
evalExpr :: (Eq id, Ord id) => IExpr id -> IExpr id
evalExpr = canonicalExpr . flatExpr

-- | Bring an expression to the flat form described above.
flatExpr :: Eq id => IExpr id -> IExpr id
flatExpr i@(IInt _) = ISum [i]
flatExpr i@(IInd _) = ISum [IInt 0, IArith ITimes (IInt 1) i]
flatExpr (ISym (IInt n)) = resInt $ negate n -- shortcut
flatExpr (ISym t) = distrSym $ flatExpr t
flatExpr (ISum l) = ISum $ flatSum $ map flatExpr l
flatExpr (IArith IMinus (IInt n) (IInt n')) = resInt $ n - n'
flatExpr (IArith IMinus (IInd i) (IInd i')) | i == i' = resInt 0
flatExpr (IArith IMinus i1 i2) =
  ISum $ flatSum [flatExpr i1, distrSym $ flatExpr i2]
flatExpr (IArith ITimes (IInt n) (IInt n')) = resInt $ n * n'
flatExpr e@(IArith ITimes _ _) = flatTimes e
flatExpr (IArith IDiv (IInt n) (IInt n')) = resInt $ div n n'
-- NOTE(review): when 'listDiv' cannot fold the division, this sum does not
-- start with a constant (flat-form invariant 2).
flatExpr (IArith IDiv p q) =
  let ISum p' = flatExpr p
      ISum q' = flatExpr q
  in ISum [listDiv p' q']
flatExpr (IArith IPower (IInt b) (IInt e)) = resInt $ b ^ e
flatExpr (IArith IPower b e) =
  ISum [IInt 0, IArith ITimes (IInt 1) (IArith IPower (flatExpr b) (flatExpr e))]
flatExpr (IArith IModOp (IInt a) (IInt b)) = resInt $ mod a b
flatExpr _ = error "flatExpr: TODO: unsupported expression"

-- | A flat expression made of a single constant.
resInt :: Integer -> IExpr id
resInt n = ISum [IInt n]

-- | Flatten a product. Constant factors are multiplied together, variables
-- are collected into one monomial, and sum-like factors are multiplied out
-- term by term.
flatTimes :: Eq id => IExpr id -> IExpr id
flatTimes e@(IArith ITimes _ _) =
  let (ci, si, mi) = sepTimes e
      -- Monomial built from the constant and variable factors.
      mm = toMult (product ci) mi
      -- Product of the sum-like factors: leading term and remaining terms.
      (sc, ssi) = iTimesConcat $ map flatExpr si
      (pc, pe) = iTimesLst [mm] ssi
      (pc', pe') = iTimesLst [mm] sc
      sumCi = constSum $ pc ++ pc'
  in ISum $ sumCi : pe ++ pe'
flatTimes _ = error "flatTimes: not expected"

-- | @toMult n vs@ builds the monomial @n * (v1 * (v2 * ...))@, or just the
-- constant @n@ when there are no variables.
toMult :: Integer -> [IExpr id] -> IExpr id
toMult n [] = IInt n
toMult n xs = IArith ITimes (IInt n) (foldr1 (IArith ITimes) xs)

-- | Split the factors of a product into constants, sum-like factors (sums,
-- divisions and symbolic powers) and index variables.
sepTimes :: Eq id => IExpr id -> ([Integer], [IExpr id], [IExpr id])
sepTimes (IInt n) = ([n], [], [])
sepTimes i@(IInd _) = ([], [], [i])
sepTimes s@(ISum _) = ([], [s], [])
sepTimes (IArith ITimes i1 i2) =
  let (ci1, si1, mi1) = sepTimes i1
      (ci2, si2, mi2) = sepTimes i2
  in (ci1 ++ ci2, si1 ++ si2, mi1 ++ mi2)
sepTimes (IArith IPower (IInt n) (IInt e)) = ([n ^ e], [], [])
-- A symbolic power is flattened by 'flatTimes' like any other sum-like
-- factor (it becomes the single product 1 * b^e).
sepTimes p@(IArith IPower _ _) = ([], [p], [])
sepTimes e@(IArith IMinus _ _) = sepTimes (flatExpr e)
sepTimes (ISym e) = sepTimes (flatExpr e)
sepTimes s@(IArith IDiv _ _) = ([], [s], [])
sepTimes _ = error "sepTimes: not expected"

-- | Add up a list of constants ('IInt'); fails on any other expression.
constSum :: [IExpr id] -> IExpr id
constSum = constOp ((+), 0)

-- | Fold a list of constants with the given operator and initial value.
constOp :: (Integer -> Integer -> Integer, Integer) -> [IExpr id] -> IExpr id
constOp (f, n) = foldr aux (IInt n)
  where
    aux (IInt m) (IInt res) = IInt (f m res)
    aux _ _ = error "constOp: not expected"

-- | Multiply out a list of flat sums. The result is split into a singleton
-- list holding the leading term (the constant, by invariant 2) and the list
-- of remaining terms. The empty product is encoded as 0 plus the term 1.
iTimesConcat :: [IExpr id] -> ([IExpr id], [IExpr id])
iTimesConcat [] = ([IInt 0], [IInt 1])
iTimesConcat [ISum (c : ts)] = ([c], ts)
iTimesConcat (ISum x : xs) =
  let (cs, xs') = iTimesConcat xs
      (c, i) = iTimesLst x cs
      (c', i') = iTimesLst x xs'
  in ([constSum $ c ++ c'], i ++ i')
iTimesConcat (_ : _) = error "iTimesConcat: not expected"

-- | Multiply every term of the first list by every term of the second one.
-- Returns the summed constant part (as a singleton list) and the
-- non-constant products.
iTimesLst :: [IExpr id] -> [IExpr id] -> ([IExpr id], [IExpr id])
iTimesLst [] _ = ([], [])
iTimesLst [x] xr =
  let (nl, ol) = unzip $ map (iTimes x) xr
  in ([constSum $ concat nl], concat ol)
iTimesLst (x : xl) xr =
  let (nl, ol) = iTimesLst xl xr
      (nl', ol') = unzip $ map (iTimes x) xr
  in ([constSum (nl ++ concat nl')], ol ++ concat ol')

-- | Multiply two flat terms. Returns the constant results and the
-- non-constant products; products with a literal 0 factor vanish.
iTimes :: IExpr id -> IExpr id -> ([IExpr id], [IExpr id])
-- Constant * Constant
iTimes (IInt n) (IInt n') = ([IInt $ n * n'], [])
-- Constant * Variable
iTimes (IInt 0) (IInd _) = ([], [])
iTimes (IInt n) (IInd i) = ([], [IArith ITimes (IInt n) (IInd i)])
iTimes (IInd _) (IInt 0) = ([], [])
iTimes (IInd i) (IInt n) = ([], [IArith ITimes (IInt n) (IInd i)])
-- Constant * Product
iTimes (IInt 0) (IArith ITimes (IInt _) _) = ([], [])
iTimes (IInt n) (IArith ITimes (IInt c) i) = ([], [IArith ITimes (IInt $ c * n) i])
iTimes (IArith ITimes (IInt _) _) (IInt 0) = ([], [])
iTimes (IArith ITimes (IInt c) i) (IInt n) = ([], [IArith ITimes (IInt $ c * n) i])
-- Variable * Product
iTimes (IInd i) (IArith ITimes (IInt c) i') =
  ([], [IArith ITimes (IInt c) (IArith ITimes (IInd i) i')])
iTimes (IArith ITimes (IInt c) i') (IInd i) =
  ([], [IArith ITimes (IInt c) (IArith ITimes (IInd i) i')])
-- Product * Product
iTimes (IArith ITimes (IInt c) i) (IArith ITimes (IInt c') i') =
  ([], [IArith ITimes (IInt $ c * c') (IArith ITimes i i')])
-- TODO: Not in the right form
iTimes (IArith ITimes i1 i2) e2@(IArith ITimes _ _) =
  ([], [IArith ITimes (IInt 1) $ IArith ITimes i1 (IArith ITimes i2 e2)])
-- Constant * Division: the constant always comes first, like the other
-- coefficients, so that later passes recognise it.
iTimes l@(IArith IDiv _ _) r@(IInt _) = iTimes r l
iTimes l@(IInt _) r@(IArith IDiv _ _) = ([], [IArith ITimes l r])
-- Error
iTimes _ _ = error "iTimes: not expected"

--------------------------------------------------------------------------------

-- | Negate a flat expression by pushing the sign down to constants and
-- coefficients.
distrSym :: IExpr id -> IExpr id
distrSym e = case e of
  IInt n -> IInt (negate n)
  ISym i -> i
  ISum l -> ISum $ map distrSym l -- always the entry point
  IArith ITimes (IInt c) i -> IArith ITimes (IInt (negate c)) i
  -- -(c * i) = (-c) * i: negate the first factor only.
  IArith ITimes c i -> IArith ITimes (distrSym c) i
  -- NOTE(review): -(p / q) = (-p) / q only holds for truncating division;
  -- confirm the integer-division semantics of the index language.
  IArith IDiv (IInt c) i -> IArith IDiv (IInt (negate c)) i
  IArith IDiv c (IInt i) -> IArith IDiv c (IInt (negate i))
  IArith IDiv c i -> IArith IDiv (distrSym c) i
  IArith IModOp _ _ -> error "distrSym: mod not supported"
  IArith IPower _ _ -> error "distrSym: should never reach a power"
  IInd _ -> error "distrSym: should never reach a index variable"
  _ -> error "distrSym: missing case"

--------------------------------------------------------------------------------

-- | Divide two flat sums, folding the result when both are single constants.
listDiv :: [IExpr id] -> [IExpr id] -> IExpr id
listDiv [IInt l] [IInt r] = IInt $ mapIAOp IDiv l r
listDiv [IInt l] [IInd r] = IArith IDiv (IInt l) (IInd r)
listDiv [IInd l] [IInt r] = IArith IDiv (IInd l) (IInt r)
listDiv [IInd l] [IInd r] = IArith IDiv (IInd l) (IInd r)
listDiv l r = IArith IDiv (ISum l) (ISum r)

--------------------------------------------------------------------------------

-- | Concatenate the terms of a list of flat sums, adding up their leading
-- constants into a single constant at the front.
-- This may not be enough to bring them to the top level.
flatSum :: [IExpr id] -> [IExpr id]
flatSum l =
  let (cs, ts) = aux l
  in IInt (sum cs) : concat ts
  where
    aux :: [IExpr id] -> ([Integer], [[IExpr id]])
    aux [] = ([], [])
    aux (ISum (IInt n : l') : ls) = let (ns, ls') = aux ls in (n : ns, l' : ls')
    aux (ISum l' : ls) = let (ns, ls') = aux ls in (ns, l' : ls')
    aux (x : ls) = let (ns, ls') = aux ls in (ns, [x] : ls')

--------------------------------------------------------------------------------

-- | Ordering used to sort the terms of a flat sum: constants first, then
-- linear terms @c * v@ ordered by variable, then every other term.
-- Keeping linear terms ahead of all non-linear products guarantees that
-- like terms end up adjacent, which 'canonicalExpr' relies on to merge them.
cmp :: Ord id => IExpr id -> IExpr id -> Ordering
cmp (IInt _) _ = LT
cmp _ (IInt _) = GT
cmp (IArith ITimes (IInt _) (IInd i)) (IArith ITimes (IInt _) (IInd i')) = compare i i'
cmp (IArith ITimes (IInt _) (IInd _)) (IArith {}) = LT
cmp (IArith {}) (IArith ITimes (IInt _) (IInd _)) = GT
cmp (IArith op _ _) (IArith op' _ _) = cmpIAOp op op'
cmp (ISum _) (IArith {}) = LT
cmp (IArith {}) (ISum _) = GT
cmp (ISum l) (ISum l') = cmpList l l'
cmp _ _ = error "cmp: Ordering: not expected"

-- | Lexicographic order on lists of terms, using 'cmp' element-wise.
cmpList :: Ord id => [IExpr id] -> [IExpr id] -> Ordering
cmpList [] [] = EQ
cmpList [] _ = LT
cmpList _ [] = GT
cmpList (x : xs) (x' : xs') = case cmp x x' of
  EQ -> cmpList xs xs'
  r -> r

-- | Ordering between arithmetic operators: products, then divisions, then
-- powers, then modulo.
cmpIAOp :: IAOp -> IAOp -> Ordering
cmpIAOp ITimes _ = LT
cmpIAOp _ ITimes = GT
cmpIAOp IDiv _ = LT
cmpIAOp _ IDiv = GT
cmpIAOp IPower _ = LT
cmpIAOp _ IPower = GT
cmpIAOp IModOp _ = GT
cmpIAOp _ _ = error "cmpIAOp: not expected"

-- | Bring a flat expression to canonical form. The steps are: drop terms
-- with a zero coefficient, sort the terms, merge linear terms on the same
-- variable, and strip a zero constant and unit coefficients. A lone
-- constant or a lone @1 * v@ term is returned unwrapped.
-- TODO: non-linear coefficients may need this as well
canonicalExpr :: Ord id => IExpr id -> IExpr id
canonicalExpr (ISum l) =
  revert $ ISum $ combine $ sortBy cmp $ filter (not . isZeroTerm) l
  where
    -- A product with coefficient 0 contributes nothing to the sum. Dropping
    -- it up front keeps e.g. @0 * i@ from collapsing to an empty sum.
    isZeroTerm (IArith ITimes (IInt 0) _) = True
    isZeroTerm _ = False

    combine :: Eq id => [IExpr id] -> [IExpr id]
    combine [] = []
    combine [i] = [i]
    combine (e1@(IArith ITimes (IInt n) (IInd i)) : e2@(IArith ITimes (IInt n') (IInd i')) : xs) =
      let r = n + n'
      in if i == i'
           then if r /= 0
                  then combine (IArith ITimes (IInt r) (IInd i) : xs)
                  else combine xs
           else e1 : combine (e2 : xs)
    combine (x : xs) = x : combine xs

    revert :: IExpr id -> IExpr id
    revert (ISum [i@(IInt _)]) = i
    revert (ISum [IInt 0, IArith ITimes (IInt 1) v]) = v
    revert (ISum (IInt 0 : xs)) = ISum $ concatMap aux xs
    revert (ISum lst) = ISum $ concatMap aux lst
    revert lst = lst

    aux (IArith ITimes (IInt 1) e) = [e]
    aux (IArith ITimes (IInt 0) _) = []
    aux e = [e]
canonicalExpr e = e

--------------------------------------------------------------------------------

-- | Semantics of the binary boolean operators.
mapIBOp :: IBOp -> Bool -> Bool -> Bool
mapIBOp IOr = boolOr
mapIBOp IXor = boolXor

--------------------------------------------------------------------------------

-- | Semantics of the binary arithmetic operators.
mapIAOp :: IAOp -> Integer -> Integer -> Integer
mapIAOp IMinus = integerMinus
mapIAOp ITimes = integerTimes
mapIAOp IPower = integerPower
mapIAOp IDiv = integerDiv
mapIAOp IModOp = integerMod