module Language.SMTLib2.Internals.Optimize (optimizeBackend,optimizeExpr) where import Language.SMTLib2.Internals import Language.SMTLib2.Internals.Instances (bvSigned,bvUnsigned,bvRestrict,eqExpr) import Language.SMTLib2.Internals.Operators import Data.Proxy import Data.Bits import Data.Either (partitionEithers) import Data.Typeable (cast) optimizeBackend :: b -> OptimizeBackend b optimizeBackend = OptB data OptimizeBackend b = OptB b instance SMTBackend b m => SMTBackend (OptimizeBackend b) m where smtHandle (OptB b) (SMTAssert expr grp cid) = let nexpr = case optimizeExpr expr of Just e -> e Nothing -> expr in case nexpr of Const True _ -> return ((),OptB b) _ -> do (res,nb) <- smtHandle b (SMTAssert nexpr grp cid) return (res,OptB nb) smtHandle (OptB b) (SMTDefineFun name prx ann body) = do let nbody = case optimizeExpr body of Just e -> e Nothing -> body (res,nb) <- smtHandle b (SMTDefineFun name prx ann nbody) return (res,OptB nb) smtHandle (OptB b) (SMTGetValue expr) = do let nexpr = case optimizeExpr expr of Just e -> e Nothing -> expr (res,nb) <- smtHandle b (SMTGetValue nexpr) return (res,OptB nb) smtHandle (OptB b) SMTGetProof = do (res,nb) <- smtHandle b SMTGetProof return (case optimizeExpr res of Just e -> e Nothing -> res,OptB nb) smtHandle (OptB b) (SMTSimplify expr) = do let nexpr = case optimizeExpr expr of Just e -> e Nothing -> expr (simp,nb) <- smtHandle b (SMTSimplify nexpr) return (case optimizeExpr simp of Nothing -> simp Just simp' -> simp',OptB nb) smtHandle (OptB b) (SMTGetInterpolant grps) = do (inter,nb) <- smtHandle b (SMTGetInterpolant grps) return (case optimizeExpr inter of Nothing -> inter Just e -> e,OptB nb) smtHandle (OptB b) req = do (res,nb) <- smtHandle b req return (res,OptB nb) smtGetNames (OptB b) = smtGetNames b smtNextName (OptB b) = smtNextName b optimizeExpr :: SMTExpr t -> Maybe (SMTExpr t) optimizeExpr (App fun x) = let (opt,x') = foldExprsId (\opt expr ann -> case optimizeExpr expr of Nothing -> (opt,expr) Just expr' -> (True,expr') ) False x (extractArgAnnotation x) in case optimizeCall fun x' of Nothing -> if opt then Just $ App fun x' else Nothing Just res -> Just res optimizeExpr _ = Nothing optimizeCall :: SMTFunction arg res -> arg -> Maybe (SMTExpr res) optimizeCall SMTEq [] = Just (Const True ()) optimizeCall SMTEq [_] = Just (Const True ()) optimizeCall SMTEq [x,y] = case eqExpr x y of Nothing -> Nothing Just res -> Just (Const res ()) optimizeCall SMTNot (Const x _) = Just $ Const (not x) () optimizeCall (SMTLogic _) [x] = Just x optimizeCall (SMTLogic And) xs = case removeConstsOf False xs of Just _ -> Just $ Const False () Nothing -> case removeConstsOf True xs of Nothing -> case xs of [] -> Just $ Const True () _ -> Nothing Just [] -> Just $ Const True () Just [x] -> Just x Just xs' -> Just $ App (SMTLogic And) xs' optimizeCall (SMTLogic Or) xs = case removeConstsOf True xs of Just _ -> Just $ Const True () Nothing -> case removeConstsOf False xs of Nothing -> case xs of [] -> Just $ Const False () _ -> Nothing Just [] -> Just $ Const False () Just [x] -> Just x Just xs' -> Just $ App (SMTLogic Or) xs' optimizeCall (SMTLogic XOr) [] = Just $ Const False () optimizeCall (SMTLogic Implies) [] = Just $ Const True () optimizeCall (SMTLogic Implies) xs = let (args,res) = splitLast xs in case res of Const True _ -> Just (Const True ()) _ -> case removeConstsOf False args of Just _ -> Just $ Const True () Nothing -> case removeConstsOf True args of Nothing -> case args of [] -> Just res _ -> Nothing Just [] -> Just res Just args' -> Just $ App (SMTLogic Implies) (args'++[res]) optimizeCall SMTITE (Const True _,ifT,_) = Just ifT optimizeCall SMTITE (Const False _,_,ifF) = Just ifF optimizeCall SMTITE (_,ifT,ifF) = case eqExpr ifT ifF of Just True -> Just ifT _ -> Nothing optimizeCall (SMTBVBin op) args = bvBinOpOptimize op args optimizeCall SMTConcat (Const (BitVector v1::BitVector b1) ann1,Const (BitVector v2::BitVector b2) ann2) = Just (Const (BitVector $ (v1 `shiftL` (fromInteger $ getBVSize (Proxy::Proxy b2) ann2)) .|. v2) (concatAnnotation (undefined::b1) (undefined::b2) ann1 ann2)) optimizeCall (SMTExtract pstart plen) (Const from@(BitVector v) ann) = let start = reflectNat pstart 0 undefFrom :: BitVector from -> from undefFrom _ = undefined undefLen :: SMTExpr (BitVector len) -> len undefLen _ = undefined len = reflectNat plen 0 res = Const (BitVector $ (v `shiftR` (fromInteger start)) .&. (1 `shiftL` (fromInteger $ reflectNat plen 0) - 1)) (extractAnn (undefFrom from) (undefLen res) len ann) in Just res optimizeCall (SMTBVComp op) args = bvCompOptimize op args optimizeCall (SMTArith op) args = case cast args of Just args' -> case cast (intArithOptimize op args') of Just res -> res Nothing -> Nothing optimizeCall SMTMinus args = case cast args of Just args' -> case cast (intMinusOptimize args') of Just res -> res Nothing -> Nothing optimizeCall (SMTOrd op) args = case cast args of Just args' -> case cast (intCmpOptimize op args') of Just res -> res Nothing -> Nothing optimizeCall _ _ = Nothing removeConstsOf :: Bool -> [SMTExpr Bool] -> Maybe [SMTExpr Bool] removeConstsOf val = removeItems (\e -> case e of Const c _ -> c==val _ -> False) removeItems :: (a -> Bool) -> [a] -> Maybe [a] removeItems f [] = Nothing removeItems f (x:xs) = if f x then (case removeItems f xs of Nothing -> Just xs Just xs' -> Just xs') else (case removeItems f xs of Nothing -> Nothing Just xs' -> Just (x:xs')) splitLast :: [a] -> ([a],a) splitLast [x] = ([],x) splitLast (x:xs) = let (xs',last) = splitLast xs in (x:xs',last) bvBinOpOptimize :: IsBitVector a => SMTBVBinOp -> (SMTExpr (BitVector a),SMTExpr (BitVector a)) -> Maybe (SMTExpr (BitVector a)) bvBinOpOptimize BVAdd (Const (BitVector 0) _,y) = Just y bvBinOpOptimize BVAdd (x,Const (BitVector 0) _) = Just x bvBinOpOptimize BVAdd (Const (BitVector x) w,Const (BitVector y) _) = Just (Const (bvRestrict (BitVector $ x+y) w) w) bvBinOpOptimize BVAnd (Const (BitVector x) w,Const (BitVector y) _) = Just (Const (BitVector $ x .&. y) w) bvBinOpOptimize BVOr (Const (BitVector x) w,Const (BitVector y) _) = Just (Const (BitVector $ x .|. y) w) bvBinOpOptimize BVOr (Const (BitVector 0) _,oth) = Just oth bvBinOpOptimize BVOr (oth,Const (BitVector 0) _) = Just oth bvBinOpOptimize BVSHL (Const (BitVector x) w,Const (BitVector y) _) = Just (Const (bvRestrict (BitVector $ x `shiftL` (fromInteger y)) w) w) bvBinOpOptimize BVSHL (Const (BitVector 0) w,_) = Just (Const (BitVector 0) w) bvBinOpOptimize BVSHL (oth,Const (BitVector 0) w) = Just oth bvBinOpOptimize _ _ = Nothing bvCompOptimize :: IsBitVector a => SMTBVCompOp -> (SMTExpr (BitVector a),SMTExpr (BitVector a)) -> Maybe (SMTExpr Bool) bvCompOptimize op (Const b1 ann1,Const b2 ann2) = Just $ Const (case op of BVULE -> u1 <= u2 BVULT -> u1 < u2 BVUGE -> u1 >= u2 BVUGT -> u1 > u2 BVSLE -> s1 <= s2 BVSLT -> s1 < s2 BVSGE -> s1 >= s2 BVSGT -> s1 > s2) () where u1 = bvUnsigned b1 ann1 u2 = bvUnsigned b2 ann2 s1 = bvSigned b1 ann1 s2 = bvSigned b2 ann2 bvCompOptimize _ _ = Nothing intArithOptimize :: SMTArithOp -> [SMTExpr Integer] -> Maybe (SMTExpr Integer) intArithOptimize Plus xs = let (consts,nonconsts) = partitionEithers $ fmap (\e -> case e of Const i _ -> Left i _ -> Right e ) xs in case consts of [] -> Nothing [x] -> case nonconsts of [] -> Just (Const x ()) [y] -> if x==0 then Just y else Nothing _ -> Nothing _ -> let s = sum consts in case nonconsts of [] -> Just (Const s ()) [x] -> if s==0 then Just x else Just (App (SMTArith Plus) [x,Const s ()]) _ -> Just (App (SMTArith Plus) (nonconsts++(if s==0 then [] else [Const s ()]))) intArithOptimize Mult xs = let (consts,nonconsts) = partitionEithers $ fmap (\e -> case e of Const i _ -> Left i _ -> Right e ) xs in case consts of [] -> Nothing [_] -> Nothing _ -> case nonconsts of [] -> Just (Const (product consts) ()) _ -> Just (App (SMTArith Mult) (nonconsts++[Const (product consts) ()])) intMinusOptimize :: (SMTExpr Integer,SMTExpr Integer) -> Maybe (SMTExpr Integer) intMinusOptimize (Const x _,Const y _) = Just (Const (x-y) ()) intMinusOptimize (x,Const 0 _) = Just x intMinusOptimize _ = Nothing intCmpOptimize :: SMTOrdOp -> (SMTExpr Integer,SMTExpr Integer) -> Maybe (SMTExpr Bool) intCmpOptimize op (Const x _,Const y _) = Just (Const (case op of Ge -> x >= y Gt -> x > y Le -> x <= y Lt -> x < y) ()) intCmpOptimize _ _ = Nothing