module Language.SMTLib2.Internals.Optimize (optimizeBackend,optimizeExpr) where
import Language.SMTLib2.Internals
import Language.SMTLib2.Internals.Instances (bvSigned,bvUnsigned,bvRestrict,eqExpr)
import Language.SMTLib2.Internals.Operators
import Data.Proxy
import Data.Bits
import Data.Either (partitionEithers)
import Data.Typeable (cast)
-- | Backend transformer: wraps an inner backend @b@ so that the expressions
-- of several requests are passed through 'optimizeExpr' on their way to (and,
-- for some results, back from) the inner backend.
data OptimizeBackend b = OptB b

-- | Wrap a backend in the expression-simplifying 'OptimizeBackend' layer.
optimizeBackend :: b -> OptimizeBackend b
optimizeBackend backend = OptB backend
-- | Run 'optimizeExpr', falling back to the unchanged expression when no
-- rewrite applied.
optimizeOrKeep :: SMTExpr t -> SMTExpr t
optimizeOrKeep e = maybe e id (optimizeExpr e)

-- | Pass an inner backend's result through unchanged while re-wrapping the
-- updated inner backend in 'OptB'.
rewrapOptB :: Monad m => m (r, b) -> m (r, OptimizeBackend b)
rewrapOptB action = do
  (r, nb) <- action
  return (r, OptB nb)

-- | Delegates every request to the wrapped backend. Expressions in
-- assertions, function definitions, get-value and simplify requests are
-- simplified first; proofs, simplification results and interpolants coming
-- back are simplified as well.
instance SMTBackend b m => SMTBackend (OptimizeBackend b) m where
  -- An assertion that simplifies to the literal True is not forwarded at
  -- all; the inner backend is returned untouched.
  smtHandle (OptB b) (SMTAssert expr grp cid) =
    case optimizeOrKeep expr of
      Const True _ -> return ((), OptB b)
      nexpr -> rewrapOptB (smtHandle b (SMTAssert nexpr grp cid))
  smtHandle (OptB b) (SMTDefineFun name prx ann body) =
    rewrapOptB (smtHandle b (SMTDefineFun name prx ann (optimizeOrKeep body)))
  smtHandle (OptB b) (SMTGetValue expr) =
    rewrapOptB (smtHandle b (SMTGetValue (optimizeOrKeep expr)))
  smtHandle (OptB b) SMTGetProof = do
    (proof, nb) <- smtHandle b SMTGetProof
    return (optimizeOrKeep proof, OptB nb)
  -- Simplify both before sending and after receiving the solver's answer.
  smtHandle (OptB b) (SMTSimplify expr) = do
    (simp, nb) <- smtHandle b (SMTSimplify (optimizeOrKeep expr))
    return (optimizeOrKeep simp, OptB nb)
  smtHandle (OptB b) (SMTGetInterpolant grps) = do
    (inter, nb) <- smtHandle b (SMTGetInterpolant grps)
    return (optimizeOrKeep inter, OptB nb)
  -- Every other request is forwarded verbatim.
  smtHandle (OptB b) req = rewrapOptB (smtHandle b req)
  smtGetNames (OptB b) = smtGetNames b
  smtNextName (OptB b) = smtNextName b
-- | Simplify an expression bottom-up.
--
-- Only function applications ('App') are rewritten. First each argument
-- expression visited by 'foldExprsId' is simplified recursively, tracking
-- whether anything changed. Then 'optimizeCall' gets a chance to fold the
-- call itself. Returns 'Nothing' when neither step changed anything, so
-- callers can keep the original expression.
optimizeExpr :: SMTExpr t -> Maybe (SMTExpr t)
optimizeExpr (App fun x) =
  let (changed, x') =
        foldExprsId
          (\flag expr _ -> maybe (flag, expr) (\e -> (True, e)) (optimizeExpr expr))
          False
          x
          (extractArgAnnotation x)
  in case optimizeCall fun x' of
       Just res -> Just res
       Nothing
         | changed -> Just (App fun x')
         | otherwise -> Nothing
optimizeExpr _ = Nothing
-- | Try to fold one function application whose arguments have already been
-- simplified. Returns 'Nothing' when no rule applies.
--
-- The rules cover:
--
-- * trivial or syntactically decidable equalities;
-- * constant negation;
-- * n-ary @and@/@or@/@=>@ with literal operands;
-- * @ite@ with a literal condition or equal branches;
-- * constant bit-vector concatenation and extraction;
-- * delegation to 'bvBinOpOptimize', 'bvCompOptimize', 'intArithOptimize',
--   'intMinusOptimize' and 'intCmpOptimize'.
optimizeCall :: SMTFunction arg res -> arg -> Maybe (SMTExpr res)
-- Equality over zero or one argument is vacuously true.
optimizeCall SMTEq [] = Just (Const True ())
optimizeCall SMTEq [_] = Just (Const True ())
-- 'eqExpr' may return 'Nothing'; the call is then left alone.
optimizeCall SMTEq [x, y] = (\r -> Const r ()) <$> eqExpr x y
optimizeCall SMTNot (Const x _) = Just $ Const (not x) ()
-- A logical connective applied to a single operand is that operand.
optimizeCall (SMTLogic _) [x] = Just x
-- And: any literal False decides the result; literal Trues are dropped.
optimizeCall (SMTLogic And) xs = case removeConstsOf False xs of
  Just _ -> Just $ Const False ()
  Nothing -> case removeConstsOf True xs of
    Nothing -> case xs of
      [] -> Just $ Const True ()
      _ -> Nothing
    Just [] -> Just $ Const True ()
    Just [x] -> Just x
    Just xs' -> Just $ App (SMTLogic And) xs'
-- Or: any literal True decides the result; literal Falses are dropped.
optimizeCall (SMTLogic Or) xs = case removeConstsOf True xs of
  Just _ -> Just $ Const True ()
  Nothing -> case removeConstsOf False xs of
    Nothing -> case xs of
      [] -> Just $ Const False ()
      _ -> Nothing
    Just [] -> Just $ Const False ()
    Just [x] -> Just x
    Just xs' -> Just $ App (SMTLogic Or) xs'
optimizeCall (SMTLogic XOr) [] = Just $ Const False ()
optimizeCall (SMTLogic Implies) [] = Just $ Const True ()
-- Implies: the last operand is the conclusion, the rest are premises.
-- A True conclusion or a False premise makes the implication True; True
-- premises are dropped.
optimizeCall (SMTLogic Implies) xs =
  let (args, res) = splitLast xs
  in case res of
       Const True _ -> Just (Const True ())
       _ -> case removeConstsOf False args of
         Just _ -> Just $ Const True ()
         Nothing -> case removeConstsOf True args of
           Nothing -> case args of
             [] -> Just res
             _ -> Nothing
           Just [] -> Just res
           Just args' -> Just $ App (SMTLogic Implies) (args' ++ [res])
optimizeCall SMTITE (Const True _, ifT, _) = Just ifT
optimizeCall SMTITE (Const False _, _, ifF) = Just ifF
optimizeCall SMTITE (_, ifT, ifF) = case eqExpr ifT ifF of
  Just True -> Just ifT
  _ -> Nothing
optimizeCall (SMTBVBin op) args = bvBinOpOptimize op args
-- Concatenation of two constants: shift the high part left by the width of
-- the low part and OR them together.
optimizeCall SMTConcat (Const (BitVector v1 :: BitVector b1) ann1, Const (BitVector v2 :: BitVector b2) ann2)
  = Just (Const (BitVector $ (v1 `shiftL` fromInteger (getBVSize (Proxy :: Proxy b2) ann2)) .|. v2)
                (concatAnnotation (undefined :: b1) (undefined :: b2) ann1 ann2))
-- Extraction from a constant: keep the @len@ bits starting at bit @start@.
optimizeCall (SMTExtract pstart plen) (Const from@(BitVector v) ann)
  = let start = reflectNat pstart 0
        len = reflectNat plen 0
        undefFrom :: BitVector from -> from
        undefFrom _ = undefined
        undefLen :: SMTExpr (BitVector len) -> len
        undefLen _ = undefined
        -- 2^len - 1, i.e. the lowest @len@ bits set.
        mask = (1 `shiftL` fromInteger len) - 1
        res = Const (BitVector $ (v `shiftR` fromInteger start) .&. mask)
                    (extractAnn (undefFrom from) (undefLen res) len ann)
    in Just res
optimizeCall (SMTBVComp op) args = bvCompOptimize op args
-- The integer rules only apply when the operands are 'Integer' expressions.
-- For any other operand type a 'cast' fails and the call is left alone
-- (both casts are handled, so this never crashes).
optimizeCall (SMTArith op) args = do
  intArgs <- cast args
  optimized <- cast (intArithOptimize op intArgs)
  optimized
optimizeCall SMTMinus args = do
  intArgs <- cast args
  optimized <- cast (intMinusOptimize intArgs)
  optimized
optimizeCall (SMTOrd op) args = do
  intArgs <- cast args
  optimized <- cast (intCmpOptimize op intArgs)
  optimized
optimizeCall _ _ = Nothing
-- | Remove every literal constant equal to the given Boolean from a list of
-- Boolean expressions. Like 'removeItems', this yields 'Nothing' when no such
-- literal was present.
removeConstsOf :: Bool -> [SMTExpr Bool] -> Maybe [SMTExpr Bool]
removeConstsOf target = removeItems isTarget
  where
    isTarget :: SMTExpr Bool -> Bool
    isTarget (Const c _) = c == target
    isTarget _ = False
-- | Drop every element satisfying the predicate.
--
-- Returns 'Nothing' when no element matched, so callers can tell that nothing
-- was removed. Otherwise returns @Just@ the remaining elements in their
-- original order. Note that @Just []@ means every element was removed.
removeItems :: (a -> Bool) -> [a] -> Maybe [a]
removeItems matches items
  | any matches items = Just (filter (not . matches) items)
  | otherwise = Nothing
-- | Split a non-empty list into all-but-the-last elements and the last one,
-- e.g. @splitLast [1,2,3] == ([1,2],3)@.
--
-- The list must be non-empty; an empty list is a caller bug and is reported
-- with an explicit message instead of a bare pattern-match failure.
splitLast :: [a] -> ([a], a)
splitLast [] = error "Language.SMTLib2.Internals.Optimize.splitLast: empty list"
splitLast [x] = ([], x)
splitLast (x:xs) =
  let (front, final) = splitLast xs
  in (x : front, final)
-- | Fold binary bit-vector operations with constant operands or identity
-- elements (0 for addition and OR, a 0 shift amount or 0 operand for SHL).
-- Equations are tried in order: both-constant folding wins over the identity
-- rules wherever both would apply, except for BVAdd, whose zero-operand
-- rules come first.
bvBinOpOptimize :: IsBitVector a => SMTBVBinOp -> (SMTExpr (BitVector a), SMTExpr (BitVector a)) -> Maybe (SMTExpr (BitVector a))
bvBinOpOptimize BVAdd (Const (BitVector 0) _, y) = Just y
bvBinOpOptimize BVAdd (x, Const (BitVector 0) _) = Just x
bvBinOpOptimize BVAdd (Const (BitVector x) w, Const (BitVector y) _)
  = Just (Const (bvRestrict (BitVector $ x + y) w) w)
bvBinOpOptimize BVAnd (Const (BitVector x) w, Const (BitVector y) _)
  = Just (Const (BitVector $ x .&. y) w)
bvBinOpOptimize BVOr (Const (BitVector x) w, Const (BitVector y) _)
  = Just (Const (BitVector $ x .|. y) w)
bvBinOpOptimize BVOr (Const (BitVector 0) _, oth) = Just oth
bvBinOpOptimize BVOr (oth, Const (BitVector 0) _) = Just oth
-- Shifting by at least the bit width yields 0 (SMT-LIB bvshl semantics).
-- Checking this first keeps 'fromInteger' from overflowing 'Int' and
-- 'shiftL' from building a gigantic intermediate Integer for large amounts.
bvBinOpOptimize BVSHL (lhs@(Const (BitVector x) w), Const (BitVector y) _)
  | y >= width = Just (Const (BitVector 0) w)
  | otherwise = Just (Const (bvRestrict (BitVector $ x `shiftL` fromInteger y) w) w)
  where
    width = getBVSize (proxyOf lhs) w
    proxyOf :: SMTExpr (BitVector b) -> Proxy b
    proxyOf _ = Proxy
bvBinOpOptimize BVSHL (Const (BitVector 0) w, _) = Just (Const (BitVector 0) w)
bvBinOpOptimize BVSHL (oth, Const (BitVector 0) _) = Just oth
bvBinOpOptimize _ _ = Nothing
-- | Decide a bit-vector comparison whose operands are both constants.
-- The unsigned (BVU*) operators compare the 'bvUnsigned' readings of the
-- operands. The signed (BVS*) operators compare their 'bvSigned' readings.
bvCompOptimize :: IsBitVector a => SMTBVCompOp -> (SMTExpr (BitVector a), SMTExpr (BitVector a)) -> Maybe (SMTExpr Bool)
bvCompOptimize op (Const lhs annL, Const rhs annR) = Just (Const verdict ())
  where
    unsignedL = bvUnsigned lhs annL
    unsignedR = bvUnsigned rhs annR
    signedL = bvSigned lhs annL
    signedR = bvSigned rhs annR
    verdict = case op of
      BVULE -> unsignedL <= unsignedR
      BVULT -> unsignedL < unsignedR
      BVUGE -> unsignedL >= unsignedR
      BVUGT -> unsignedL > unsignedR
      BVSLE -> signedL <= signedR
      BVSLT -> signedL < signedR
      BVSGE -> signedL >= signedR
      BVSGT -> signedL > signedR
bvCompOptimize _ _ = Nothing
-- | Fold the integer-literal operands of an n-ary addition or multiplication.
--
-- All 'Const' operands are combined into one literal, which is placed after
-- the remaining operands. The neutral element (0 for 'Plus', 1 for 'Mult') is
-- dropped, and a zero factor collapses a product to 0.
--
-- Returns 'Nothing' when there is nothing to fold. That is the case when
-- there are no literals, or when there is a single literal that is neither
-- neutral nor absorbing next to other operands.
intArithOptimize :: SMTArithOp -> [SMTExpr Integer] -> Maybe (SMTExpr Integer)
intArithOptimize Plus xs = foldIntConstants Plus sum 0 Nothing xs
intArithOptimize Mult xs = foldIntConstants Mult product 1 (Just 0) xs

-- | Shared worker for 'intArithOptimize'.
foldIntConstants
  :: SMTArithOp               -- ^ operator used to rebuild the application
  -> ([Integer] -> Integer)   -- ^ combines all literal operands
  -> Integer                  -- ^ neutral element of the operator
  -> Maybe Integer            -- ^ absorbing element of the operator, if any
  -> [SMTExpr Integer]        -- ^ operands
  -> Maybe (SMTExpr Integer)
foldIntConstants op combine neutral absorbing xs =
  case consts of
    [] -> Nothing
    -- e.g. a 0 factor: the whole application is that literal.
    _ | Just folded == absorbing -> Just (Const folded ())
    _ -> case others of
      [] -> Just (Const folded ())
      _ | folded == neutral -> Just (rebuild others)
      _ -> case consts of
        -- A single non-trivial literal is already as folded as it gets.
        [_] -> Nothing
        _ -> Just (App (SMTArith op) (others ++ [Const folded ()]))
  where
    classify :: SMTExpr Integer -> Either Integer (SMTExpr Integer)
    classify (Const i _) = Left i
    classify e = Right e
    (consts, others) = partitionEithers (map classify xs)
    folded = combine consts
    -- Avoid a one-operand application by returning the operand itself.
    rebuild :: [SMTExpr Integer] -> SMTExpr Integer
    rebuild [y] = y
    rebuild ys = App (SMTArith op) ys
-- | Fold a binary integer subtraction: two literals are evaluated, and
-- subtracting a literal 0 returns the left operand unchanged.
intMinusOptimize :: (SMTExpr Integer, SMTExpr Integer) -> Maybe (SMTExpr Integer)
intMinusOptimize (Const x _, Const y _) = Just (Const (x - y) ())
intMinusOptimize (x, Const 0 _) = Just x
intMinusOptimize _ = Nothing
-- | Decide an integer ordering comparison whose operands are both literals.
intCmpOptimize :: SMTOrdOp -> (SMTExpr Integer, SMTExpr Integer) -> Maybe (SMTExpr Bool)
intCmpOptimize op (Const lhs _, Const rhs _) = Just (Const verdict ())
  where
    ordering = compare lhs rhs
    verdict = case op of
      Ge -> ordering /= LT
      Gt -> ordering == GT
      Le -> ordering /= GT
      Lt -> ordering == LT
intCmpOptimize _ _ = Nothing