-- |
-- Module      : Language.Sally.Expr
-- Description : Smart constructors for Sally AST types
-- Copyright   : Benjamin Jones 2016-2017
-- License     : BSD3
--
-- Maintainer  : bjones@galois.com
-- Stability   : experimental
-- Portability : unknown
--
-- Better constructors for Sally expressions and predicates than the raw ones
-- defined in "Language.Sally.Types".
--
{-# LANGUAGE ViewPatterns #-}
module Language.Sally.Expr
  ( -- * better constructors
    boolExpr
  , boolPred
  , intExpr
  , zeroExpr
  , oneExpr
  , realExpr
  , addExpr
  , subExpr
  , multExpr
  , divExpr
  , notExpr
  , eqExpr
  , neqExpr
  , ltExpr
  , leqExpr
  , gtExpr
  , geqExpr
  , muxExpr
  , andExprs
  , andPreds
  , orExprs
  , orPreds
  , varExpr
  , varExpr'
  , xorExpr
    -- * complex expression builders
  , minExpr
  , countExpr
    -- * expression rewriting
  , constFold
  , simplifyAnds
  , simplifyOrs
  , flattenAnds
  , flattenOrs
  ) where

import Data.Sequence (Seq, (<|), (><), viewl, ViewL(..))
import qualified Data.Sequence as Seq

import Language.Sally.Types

-- Better Constructors ---------------------------------------------------------

-- | Lift a Haskell 'Bool' into a constant Sally boolean expression.
boolExpr :: Bool -> SallyExpr
boolExpr b = SELit $ SConstBool b

-- | Lift a Haskell 'Bool' into a constant Sally predicate.
boolPred :: Bool -> SallyPred
boolPred b = SPConst b

-- | Lift any integral Haskell value into a constant Sally integer
-- expression (converted with 'fromIntegral').
intExpr :: Integral a => a -> SallyExpr
intExpr n = SELit $ SConstInt $ fromIntegral n

-- | The Sally integer constant @0@.
zeroExpr :: SallyExpr
zeroExpr = intExpr (0 :: Int)

-- | The Sally integer constant @1@.
oneExpr :: SallyExpr
oneExpr = intExpr (1 :: Int)

-- | Lift any real Haskell value into a constant Sally real expression
-- (converted exactly with 'toRational').
realExpr :: Real a => a -> SallyExpr
realExpr r = SELit $ SConstReal $ toRational r

-- | Sum of two expressions. No normalisation is performed.
-- TODO maintain normal form
addExpr :: SallyExpr -> SallyExpr -> SallyExpr
addExpr x y = SEArith $ SAAdd x y

-- | Subtract the second 'SallyExpr' from the first.
subExpr :: SallyExpr -> SallyExpr -> SallyExpr
subExpr x y = SEArith (SAAdd x ny)
  where
    -- x - y is encoded as x + (-1) * y
    ny = multExpr (SELit (SConstInt (-1))) y

-- | Better constructor for multiplying expressions; checks (via
-- 'isMultConst') that one of the operands is a constant, because only linear
-- arithmetic is supported. Calls 'error' otherwise.
multExpr :: SallyExpr -> SallyExpr -> SallyExpr
multExpr x y
  | isMultConst x || isMultConst y = SEArith (SAMult x y)
  | otherwise = error "multExpr: non-linear arithmetic is not supported"

-- | Better constructor for dividing expressions; checks (via 'isMultConst')
-- that the denominator is a constant. Calls 'error' otherwise.
divExpr :: SallyExpr -> SallyExpr -> SallyExpr
divExpr x y
  | isMultConst y = SEArith (SADiv x y)
  | otherwise = error "divExpr: non-linear arithmetic is not supported"

-- | Determine if a Sally expression is a constant for the purposes of linear
-- multiplication. The check is conservative: 'True' guarantees the
-- expression is constant, but some constant expressions are rejected, e.g.
-- @(x + (-x))@ is always @0@ yet does not pass this predicate.
isMultConst :: SallyExpr -> Bool
isMultConst (SELit _) = True
isMultConst (SEVar _) = False
isMultConst (SEPre _) = False
isMultConst (SEArith (SAAdd x y)) = isMultConst x && isMultConst y
isMultConst (SEArith (SAMult x y)) = isMultConst x && isMultConst y
-- 'divExpr' builds 'SADiv'; without this equation, multiplying by such a
-- term failed with a pattern-match error. A quotient of constants is constant.
isMultConst (SEArith (SADiv x y)) = isMultConst x && isMultConst y
-- 'SAExpr' merely wraps an expression, so look through it.
isMultConst (SEArith (SAExpr e)) = isMultConst e
isMultConst SEMux{} = False

-- | Create the expression equating two given expressions.
eqExpr :: SallyExpr -> SallyExpr -> SallyExpr
eqExpr x y = SEPre (SPEq x y)

-- | @a \`ltExpr\` b@ represents the expression @a \< b@.
ltExpr :: SallyExpr -> SallyExpr -> SallyExpr
ltExpr x y = SEPre (SPLt x y)

-- | @a \`leqExpr\` b@ represents the expression @a \<= b@.
leqExpr :: SallyExpr -> SallyExpr -> SallyExpr
leqExpr x y = SEPre (SPLEq x y)

-- | @a \`gtExpr\` b@ represents the expression @a > b@.
gtExpr :: SallyExpr -> SallyExpr -> SallyExpr
gtExpr x y = SEPre (SPGt x y)

-- | @a \`geqExpr\` b@ represents the expression @a >= b@.
geqExpr :: SallyExpr -> SallyExpr -> SallyExpr
geqExpr x y = SEPre (SPGEq x y)

-- | Create the expression that is the boolean negation of the given one.
notExpr :: SallyExpr -> SallyExpr
notExpr x = SEPre (SPNot (getPred x))

-- | Create the XOR of two Sally expressions.
xorExpr :: SallyExpr -> SallyExpr -> SallyExpr
xorExpr x y = andExprs [orExprs [x, y], notExpr (andExprs [x, y])]

-- | Create the expression representing non-equality.
neqExpr :: SallyExpr -> SallyExpr -> SallyExpr
neqExpr x y = notExpr (eqExpr x y)

-- | Turn a 'SallyExpr' into a 'SallyPred' (if possible). Predicate
-- expressions are unwrapped; literals, variables and muxes are wrapped with
-- 'SPExpr'; arithmetic expressions cause an 'error'.
getPred :: SallyExpr -> SallyPred
getPred x = case x of
  SEPre w   -> w
  SELit{}   -> SPExpr x
  SEVar{}   -> SPExpr x
  SEMux{}   -> SPExpr x
  -- shared by 'notExpr', 'andExprs' and 'orExprs', so the message names this
  -- helper rather than a single caller
  SEArith{} -> error ("getPred: cannot turn expression into predicate: " ++ show x)

-- | Create an if-then-else expression: @mux b x y@ represents the statement
-- @if b then x else y@.
muxExpr :: SallyExpr -> SallyExpr -> SallyExpr -> SallyExpr
muxExpr = SEMux

-- | Form the conjunction of the given expressions. Each operand is converted
-- with 'getPred', so arithmetic operands cause an 'error'; other operands are
-- not checked to be boolean.
andExprs :: [SallyExpr] -> SallyExpr
andExprs es = SEPre $ andPreds (fmap getPred es)

-- | And over multiple predicates, doing some small inline simplification
andPreds :: [SallyPred] -> SallyPred
andPreds [] = SPConst True  -- intersection over no sets is the whole universe
andPreds [p] = p
andPreds ps = SPAnd . flattenAnds . Seq.fromList $ ps

-- | Form the disjunction of the given expressions. Each operand is converted
-- with 'getPred', so arithmetic operands cause an 'error'; other operands are
-- not checked to be boolean.
orExprs :: [SallyExpr] -> SallyExpr
orExprs es = SEPre $ orPreds (fmap getPred es)

-- | Or over multiple predicates, doing some small inline simplification
orPreds :: [SallyPred] -> SallyPred
orPreds [] = SPConst False  -- union over no sets is the empty set
orPreds [p] = p
orPreds ps = SPOr . flattenOrs . Seq.fromList $ ps

-- | Create a variable expression.
varExpr :: SallyVar -> SallyExpr
varExpr = SEVar

-- | Create a variable expression from a name.
varExpr' :: Name -> SallyExpr
varExpr' = SEVar . varFromName

-- More Complicated expression builders ----------------------------------------

-- | Given a non-empty finite list of expressions, build an expression to
-- compute their minimum. The second argument is an optional special value;
-- when present, list elements equal to it are ignored in the calculation,
-- regardless of where the special value falls in the ordering. If the list
-- contains only the special value, then the special value itself is
-- returned. Calls 'error' on an empty list.
--
-- Note: each step duplicates the comparisons for the rest of the list in
-- both mux branches, so the resulting expression grows exponentially with
-- the list length.
minExpr :: [SallyExpr] -> Maybe SallyExpr -> SallyExpr
minExpr [] _ = error "minExpr: cannot apply minExpr to empty list"
minExpr (x:rest) sp' = go x rest
  where
    -- Condition under which the current candidate @m@ beats the next
    -- element @y@.
    keep m y = case sp' of
      Nothing -> ltExpr m y
      -- @m@ wins only if it is not special, and @y@ is either special or
      -- strictly larger. (Previously @m < y && m /= sp@, which picked the
      -- special value whenever it was smaller than the candidate.)
      Just sp -> andExprs [neqExpr m sp, orExprs [eqExpr y sp, ltExpr m y]]
    go m [] = m
    go m (y:more) = muxExpr (keep m y) (go m more) (go y more)

-- | Build a Sally expression representing the number of times a particular
-- item appears in a list of expressions.
countExpr :: SallyExpr -> [SallyExpr] -> SallyExpr
countExpr _ [] = zeroExpr
countExpr x (y:rest) =
  -- bind the tail count once so both branches share it, instead of
  -- rebuilding it (which made construction exponential in the list length)
  let restCount = countExpr x rest
  in muxExpr (eqExpr x y) (addExpr oneExpr restCount) restCount

-- Expression Rewriting --------------------------------------------------------

-- | A basic recursive constant folding function.
constFold :: SallyExpr -> SallyExpr
constFold = simplifyExpr . constFold'
  where
    constFold' e@(SELit _) = e
    constFold' e@(SEVar _) = e
    constFold' (SEPre p) = SEPre (constFoldP p)
    constFold' (SEArith a) = SEArith (constFoldA a)
    constFold' (SEMux i t e) = constFoldM i t e

-- | Perform constant folding over a Sally predicate.
constFoldP :: SallyPred -> SallyPred
constFoldP = simplifyOrs . simplifyAnds

-- | Perform constant folding over a Sally arithmetic expression.
constFoldA :: SallyArith -> SallyArith
-- Operands are folded first (bottom-up). Rules therefore also fire on
-- operands that only become literals after folding, and operands of a sum or
-- product that matches no rule are still folded. (Previously the fall-through
-- cases returned their operands untouched, and 'SADiv' was not handled.)
constFoldA (SAAdd x y) = constFoldAdd (constFold x) (constFold y)
constFoldA (SAMult x y) = constFoldMult (constFold x) (constFold y)
constFoldA (SADiv x y) = SADiv (constFold x) (constFold y)
constFoldA (SAExpr e) = SAExpr (constFold e)

-- | Simplify the sum of two operands that have already been constant folded.
constFoldAdd :: SallyExpr -> SallyExpr -> SallyArith
-- add zero
constFoldAdd (SELit (SConstInt 0)) e = SAExpr e
constFoldAdd e (SELit (SConstInt 0)) = SAExpr e
constFoldAdd (SELit (SConstReal 0)) e = SAExpr e
constFoldAdd e (SELit (SConstReal 0)) = SAExpr e
-- add two constant literals of the same kind
constFoldAdd (SELit (SConstInt x)) (SELit (SConstInt y)) =
  SAExpr (SELit (SConstInt (x + y)))
constFoldAdd (SELit (SConstReal x)) (SELit (SConstReal y)) =
  SAExpr (SELit (SConstReal (x + y)))
-- nothing to fold
constFoldAdd x y = SAAdd x y

-- | Simplify the product of two operands that have already been constant
-- folded.
constFoldMult :: SallyExpr -> SallyExpr -> SallyArith
-- multiply by one
constFoldMult (SELit (SConstInt 1)) e = SAExpr e
constFoldMult e (SELit (SConstInt 1)) = SAExpr e
constFoldMult (SELit (SConstReal 1)) e = SAExpr e
constFoldMult e (SELit (SConstReal 1)) = SAExpr e
-- multiply by zero; a real zero stays a real literal (the old code turned it
-- into the integer 'zeroExpr')
constFoldMult (SELit (SConstInt 0)) _ = SAExpr zeroExpr
constFoldMult _ (SELit (SConstInt 0)) = SAExpr zeroExpr
constFoldMult (SELit (SConstReal 0)) _ = SAExpr (SELit (SConstReal 0))
constFoldMult _ (SELit (SConstReal 0)) = SAExpr (SELit (SConstReal 0))
-- multiply two constant literals of the same kind
constFoldMult (SELit (SConstInt x)) (SELit (SConstInt y)) =
  SAExpr (SELit (SConstInt (x * y)))
constFoldMult (SELit (SConstReal x)) (SELit (SConstReal y)) =
  SAExpr (SELit (SConstReal (x * y)))
-- nothing to fold
constFoldMult x y = SAMult x y

-- | Constant fold a mux expression. The condition is folded first; if it
-- reduces to a boolean literal or a constant predicate, only the selected
-- branch is kept (and folded). Otherwise both branches are folded.
constFoldM :: SallyExpr -> SallyExpr -> SallyExpr -> SallyExpr
constFoldM i t e =
  case constFold i of
    SELit (SConstBool b) -> pick b
    SEPre (SPConst b)    -> pick b
    i'                   -> SEMux i' (constFold t) (constFold e)
  where
    pick b = if b then constFold t else constFold e

-- | Recursively flatten a tree of @and@ expressions into an @and@ sequence.
flattenAnds :: Seq SallyPred -> Seq SallyPred
flattenAnds ps = case viewl ps of
  EmptyL -> Seq.empty
  SPAnd ys :< rest -> flattenAnds ys >< flattenAnds rest
  -- TODO enable rewriting here?
  -- SPConst True :< rest -> flattenAnds rest
  -- SPConst False :< _ -> Seq.singleton (SPConst False)
  a :< rest -> a <| flattenAnds rest

-- | Recursively flatten a tree of @or@ expressions into an @or@ sequence.
-- Written as a single total 'case' (mirroring 'flattenAnds') instead of view
-- patterns plus an 'undefined' catch-all.
flattenOrs :: Seq SallyPred -> Seq SallyPred
flattenOrs ps = case viewl ps of
  EmptyL -> Seq.empty
  SPOr ys :< rest -> flattenOrs ys >< flattenOrs rest
  a :< rest -> a <| flattenOrs rest

-- | Rebuild an n-ary connective from its flattened operands: the given unit
-- when there are none, the operand itself when there is exactly one, and the
-- connective applied to all of them otherwise.
collapseSeq :: (Seq SallyPred -> SallyPred) -> SallyPred -> Seq SallyPred -> SallyPred
collapseSeq mk unit ys = case viewl ys of
  EmptyL -> unit
  z :< zs
    | Seq.null zs -> z
    | otherwise   -> mk ys

-- | Constant fold the operands of an atomic predicate (a constant, a
-- comparison, or a wrapped expression). Connectives are returned unchanged;
-- 'simplifyAnds' and 'simplifyOrs' handle them before delegating here.
foldAtom :: SallyPred -> SallyPred
foldAtom p = case p of
  SPConst _  -> p
  SPEq x y   -> SPEq (constFold x) (constFold y)
  SPLEq x y  -> SPLEq (constFold x) (constFold y)
  SPGEq x y  -> SPGEq (constFold x) (constFold y)
  SPLt x y   -> SPLt (constFold x) (constFold y)
  SPGt x y   -> SPGt (constFold x) (constFold y)
  SPExpr e   -> SPExpr (constFold e)
  SPAnd _    -> p
  SPOr _     -> p
  SPImpl _ _ -> p
  SPNot _    -> p

-- | Top-down rewriting of conjunctions of terms including constant folding and
-- constructor reduction.
simplifyAnds :: SallyPred -> SallyPred
simplifyAnds p = case p of
  -- main case: simplify children, flatten nested ands, collapse 0/1 operands
  SPAnd xs -> collapseSeq SPAnd (SPConst True) (flattenAnds (fmap simplifyAnds xs))
  SPExpr (SEPre q) -> simplifyAnds q  -- strip off SPExpr . SEPre
  -- other connectives
  SPOr xs -> SPOr (fmap simplifyAnds xs)
  SPImpl x y -> SPImpl (simplifyAnds x) (simplifyAnds y)
  SPNot x -> SPNot (simplifyAnds x)
  -- atoms
  _ -> foldAtom p

-- | Top-down rewriting of disjunctions including constant folding and
-- constructor reduction.
simplifyOrs :: SallyPred -> SallyPred
simplifyOrs p = case p of
  -- main case: simplify children, flatten nested ors, collapse 0/1 operands
  SPOr xs -> collapseSeq SPOr (SPConst False) (flattenOrs (fmap simplifyOrs xs))
  SPExpr (SEPre q) -> simplifyOrs q  -- strip off SPExpr . SEPre
  -- other connectives
  SPAnd xs -> SPAnd (fmap simplifyOrs xs)
  SPImpl x y -> SPImpl (simplifyOrs x) (simplifyOrs y)
  SPNot x -> SPNot (simplifyOrs x)
  -- atoms
  _ -> foldAtom p

-- | Reduce SallyExpr terms by removing redundant constructors.
simplifyExpr :: SallyExpr -> SallyExpr
simplifyExpr (SEArith (SAExpr e)) = simplifyExpr e
simplifyExpr (SEPre (SPExpr e)) = simplifyExpr e
simplifyExpr e = e