{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE Rank2Types #-}
module Ivory.Opts.ConstFold
( constFold
) where
import Ivory.Opts.ConstFoldComp
import qualified Ivory.Language.Array as I
import Ivory.Language.Cast (toMaxSize, toMinSize)
import qualified Ivory.Language.Syntax as I
import Control.Arrow (second)
import qualified Data.DList as D
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe
import MonadLib (Id, StateM (..), StateT, runId,
runStateT)
-- | Copy-propagation environment: maps a variable to the expression that
-- should be substituted for it. 'stmtFold' adds entries when it elides an
-- assignment, and 'cf' consults it when it meets an 'I.ExpVar'.
type CopyMap = Map I.Var I.Expr
-- | An expression optimizer: given the current copy map and the type at which
-- the expression is used, returns a (possibly simplified) expression.
type ExprOpt = CopyMap -> I.Type -> I.Expr -> I.Expr
-- | Run the constant folder 'cf' over every statement of a procedure body.
constFold :: I.Proc -> I.Proc
constFold proc = procFold cf proc
-- | Apply an expression optimizer to the body of a procedure.
--
-- Folding starts from an empty copy map, and the procedure's symbol is passed
-- down as the context string used in error messages.
procFold :: ExprOpt -> I.Proc -> I.Proc
procFold opt proc = proc { I.procBody = D.toList folded }
  where
    folded = blockFold (I.procSym proc) opt Map.empty (I.procBody proc)
-- | Fold every statement of a block in order, starting from the given copy
-- map, and concatenate the resulting statement lists.
--
-- The copy map produced by folding the block is discarded, so copies
-- recorded inside the block are not visible to the caller.
blockFold :: String -> ExprOpt -> CopyMap -> I.Block -> D.DList I.Stmt
blockFold cxt opt copies blk =
  let (chunks, _finalCopies) =
        runId (runStateT copies (mapM (stmtFold cxt opt) blk))
  in D.concat chunks
-- | Fold a single statement, threading the copy map as state.
--
-- The result is a list because a statement may disappear entirely (a
-- provably true assertion, a copy-propagated assignment, an empty
-- conditional) or be replaced by the statements of one of its branches.
--
-- The first argument is the name of the enclosing function; it is only used
-- in error messages.
stmtFold :: String -> ExprOpt -> I.Stmt -> StateT CopyMap Id (D.DList I.Stmt)
stmtFold cxt opt stmt =
  case stmt of
    -- Both branches empty: nothing to execute.
    I.IfTE _ [] [] -> dropped
    -- Empty then-branch: negate the condition and swap the branches.
    I.IfTE c [] els ->
      stmtFold cxt opt (I.IfTE (I.ExpOp I.ExpNot [c]) els [])
    I.IfTE c thn els -> do
      env <- get
      case opt env I.TyBool c of
        -- Literal condition: splice the taken branch in place.
        I.ExpLit (I.LitBool b) ->
          fmap D.concat (mapM (stmtFold cxt opt) (if b then thn else els))
        c' -> keep (I.IfTE c' (nested env thn) (nested env els))

    I.Assert e ->
      mustHold "Constant folding evaluated a False assert()" I.Assert e

    -- A literally true compiler assertion is removed; anything else
    -- (including a literally false one) is kept.
    I.CompilerAssert e -> do
      env <- get
      case opt env I.TyBool e of
        I.ExpLit (I.LitBool True) -> dropped
        e'                        -> keep (I.CompilerAssert e')

    I.Assume e ->
      mustHold "Constant folding evaluated a False assume()" I.Assume e

    I.Return e ->
      rewriteWith $ \env -> I.Return (typedFold opt env e)
    I.ReturnVoid -> keep stmt
    I.Deref t var e ->
      rewriteWith $ \env -> I.Deref t var (opt env t e)
    I.Store t e0 e1 ->
      rewriteWith $ \env -> I.Store t (opt env t e0) (opt env t e1)

    -- Assignments whose folded right-hand side is an atom are removed and
    -- recorded in the copy map instead.
    I.Assign t v e -> do
      env <- get
      let rhs = opt env t e
      if isAtom rhs
        then set (Map.insert v rhs env) >> dropped
        else keep (I.Assign t v rhs)

    I.Call t mv c tys ->
      rewriteWith $ \env -> I.Call t mv c (map (typedFold opt env) tys)
    I.Local t var i ->
      rewriteWith $ \env -> I.Local t var (constFoldInits env i)
    I.RefCopy t e0 e1 ->
      rewriteWith $ \env -> I.RefCopy t (opt env t e0) (opt env t e1)
    I.RefZero t e0 ->
      rewriteWith $ \env -> I.RefZero t (opt env t e0)
    I.AllocRef{} -> keep stmt

    I.Loop m v e incr blk' -> do
      env <- get
      let foldIx = opt env I.ixRep
      case foldIx e of
        I.ExpLit (I.LitBool b) ->
          if b then error $ "Constant folding evaluated True expression "
                         ++ "in a loop bound.  The loop will never terminate!"
               else error $ "Constant folding evaluated False expression "
                         ++ "in a loop bound.  The loop will never execute!"
        start ->
          keep (I.Loop m v start (loopIncrFold foldIx incr) (nested env blk'))

    I.Break -> keep stmt
    I.Forever b ->
      rewriteWith $ \env -> I.Forever (nested env b)
    I.Comment{} -> keep stmt
  where
    -- Emit exactly one statement.
    keep :: I.Stmt -> StateT CopyMap Id (D.DList I.Stmt)
    keep s = return (D.singleton s)

    -- Emit nothing.
    dropped :: StateT CopyMap Id (D.DList I.Stmt)
    dropped = return D.empty

    -- Emit one statement built from the current copy map.
    rewriteWith :: (CopyMap -> I.Stmt) -> StateT CopyMap Id (D.DList I.Stmt)
    rewriteWith mk = get >>= \env -> keep (mk env)

    -- Fold a nested block starting from the given copy map.
    nested :: CopyMap -> I.Block -> [I.Stmt]
    nested env = D.toList . blockFold cxt opt env

    -- Shared handling of assert/assume: drop when literally true, abort
    -- folding when literally false, otherwise keep the folded condition.
    mustHold :: String -> (I.Expr -> I.Stmt) -> I.Expr
             -> StateT CopyMap Id (D.DList I.Stmt)
    mustHold prefix rebuild e = do
      env <- get
      case opt env I.TyBool e of
        I.ExpLit (I.LitBool True)  -> dropped
        I.ExpLit (I.LitBool False) ->
          error $ prefix
               ++ " in evaluating expression " ++ show e
               ++ " of function " ++ cxt
        e' -> keep (rebuild e')

    -- Expression forms that are substituted for the assigned variable.
    isAtom :: I.Expr -> Bool
    isAtom x = case x of
      I.ExpSym{}          -> True
      I.ExpVar{}          -> True
      I.ExpLit{}          -> True
      I.ExpAddrOfGlobal{} -> True
      I.ExpMaxMin{}       -> True
      _                   -> False
-- | Constant-fold every expression inside an initializer, recursing through
-- struct fields and array elements. Zero initializers are returned as-is.
--
-- Note that this always folds with 'cf', whatever optimizer the caller is
-- using elsewhere.
constFoldInits :: CopyMap -> I.Init -> I.Init
constFoldInits copies = go
  where
    go ini = case ini of
      I.InitZero          -> I.InitZero
      I.InitExpr ty expr  -> I.InitExpr ty (cf copies ty expr)
      I.InitStruct fields -> I.InitStruct (map (second go) fields)
      I.InitArray elems b -> I.InitArray (map go elems) b
-- | The constant folder proper: simplify an expression used at type @ty@.
--
-- Variables are replaced by their entry in the copy map (if any), operator
-- applications are handed to 'liftChoice', and the remaining compound forms
-- have their sub-expressions folded. Leaf forms are returned unchanged.
cf :: ExprOpt
cf copies ty e =
  case e of
    -- Leaves with nothing to simplify.
    I.ExpSym{}          -> e
    I.ExpExtern{}       -> e
    I.ExpLit{}          -> e
    I.ExpAddrOfGlobal{} -> e
    I.ExpMaxMin{}       -> e
    I.ExpSizeOf{}       -> e

    -- Copy propagation.
    I.ExpVar v -> Map.findWithDefault e v copies

    I.ExpOp op args -> liftChoice copies ty op args

    I.ExpLabel t base s -> I.ExpLabel t (cf copies t base) s

    I.ExpIndex t base ixTy ix ->
      I.ExpIndex t (cf copies t base) ixTy (cf copies ixTy ix)

    -- A cast applied to a literal is dropped, leaving just the literal.
    I.ExpSafeCast t inner ->
      let inner' = cf copies t inner
      in case destLit inner' of
           Just _  -> inner'
           Nothing -> I.ExpSafeCast t inner'

    -- An integer literal index is reduced modulo the bound right away.
    I.ExpToIx inner maxSz ->
      let inner' = cf copies I.ixRep inner
      in maybe (I.ExpToIx inner' maxSz)
               (\i -> I.ExpLit (I.LitInteger (i `rem` maxSz)))
               (destIntegerLit inner')
-- | Apply an expression rewrite to the bound carried by a loop increment,
-- keeping the direction (increment or decrement) unchanged.
loopIncrFold :: (I.Expr -> I.Expr) -> I.LoopIncr -> I.LoopIncr
loopIncrFold opt (I.IncrTo bound) = I.IncrTo (opt bound)
loopIncrFold opt (I.DecrTo bound) = I.DecrTo (opt bound)
-- | Optimize the expression inside a typed value, using the type it carries
-- as the type at which the expression is folded.
typedFold :: ExprOpt -> CopyMap -> I.Typed I.Expr -> I.Typed I.Expr
typedFold opt copies (I.Typed ty val) = I.Typed ty (opt copies ty val)
-- | First element of an argument list. Partial: fails on an empty list.
arg0 :: [a] -> a
arg0 xs = xs !! 0
-- | Second element of an argument list. Partial: fails when the list has
-- fewer than two elements.
arg1 :: [a] -> a
arg1 xs = xs !! 1
-- | Third element of an argument list. Partial: fails when the list has
-- fewer than three elements.
arg2 :: [a] -> a
arg2 xs = xs !! 2
-- | Fold the arguments of an operator application and then simplify the
-- application itself with @cfOp'@.
--
-- Arguments are folded at the type the operator works on, which is not
-- always the result type: comparisons and the NaN/Inf tests carry their
-- operand type in the operator, and a conditional has a boolean first
-- argument followed by branches of the result type.
cfOp :: CopyMap -> I.Type -> I.ExpOp -> [I.Expr] -> I.Expr
cfOp copies ty op args = cfOp' ty op foldedArgs
  where
    foldedArgs = case op of
      I.ExpEq t    -> foldAt t args
      I.ExpNeq t   -> foldAt t args
      I.ExpGt _ t  -> foldAt t args
      I.ExpLt _ t  -> foldAt t args
      I.ExpIsNan t -> foldAt t args
      I.ExpIsInf t -> foldAt t args
      I.ExpCond    -> foldAt I.TyBool (take 1 args) ++ foldAt ty (drop 1 args)
      _            -> foldAt ty args

    -- Fold each expression at the given type and convert to folder values.
    foldAt argTy es = mkCfArgs argTy (map (cf copies argTy) es)
-- | Simplify one operator application whose arguments have already been
-- converted to 'CfVal's (see 'cfOp', which folds the arguments first).
--
-- @ty@ is the type at which the result is used. When no rule applies the
-- application is rebuilt unchanged from its arguments (@noop@) or handed to
-- one of the evaluators bound in the @where@ clause.
--
-- The order of the guards within each operator matters: earlier rules take
-- priority over later ones.
cfOp' :: I.Type -> I.ExpOp -> [CfVal] -> I.Expr
cfOp' ty op args = case op of
I.ExpEq _ -> cfOrd
I.ExpNeq _ -> cfOrd
-- Conditional: a literal condition selects a branch outright. For
-- boolean-typed conditionals with one literal branch, rewrite to and/or:
--   c ? True  : x  ==>  c || x
--   c ? False : x  ==>  !c && x
--   c ? x : True   ==>  !c || x
--   c ? x : False  ==>  c && x
-- Finally, identical branches make the condition irrelevant.
I.ExpCond
| CfBool b <- arg0 args
-> if b then a1 else a2
| ty == I.TyBool && arg1 args == CfBool True -> cfOp' ty I.ExpOr [arg0 args, arg2 args]
| ty == I.TyBool && arg1 args == CfBool False -> cfOp' ty I.ExpAnd $ mkCfArgs ty [cfOp' ty I.ExpNot [arg0 args]] ++ [arg2 args]
| ty == I.TyBool && arg2 args == CfBool True -> cfOp' ty I.ExpOr $ mkCfArgs ty [cfOp' ty I.ExpNot [arg0 args]] ++ [arg1 args]
| ty == I.TyBool && arg2 args == CfBool False -> cfOp' ty I.ExpAnd [arg0 args, arg1 args]
| a1 == a2
-> a1
| otherwise -> noop
where a1 = toExpr $ arg1 args
a2 = toExpr $ arg2 args
-- Ordering: first try the range-based shortcuts ('gteCheck' / 'gtCheck'),
-- then fall back to evaluating the comparison on literals.
I.ExpGt orEq t
| orEq -> goOrd t gteCheck args
| otherwise -> goOrd t gtCheck args
-- Less-than reuses the greater-than checks on the reversed operands.
I.ExpLt orEq t
| orEq -> goOrd t gteCheck (reverse args)
| otherwise -> goOrd t gtCheck (reverse args)
-- Negation: evaluate on a literal, otherwise push the negation into a
-- comparison by flipping the operator.
-- NOTE(review): for floating-point operands, not (a > b) and a <= b differ
-- when either side is NaN -- confirm these rewrites are acceptable there.
I.ExpNot -> case arg0 args of
CfBool b -> I.ExpLit (I.LitBool (not b))
CfExpr (I.ExpOp (I.ExpEq t) args') -> I.ExpOp (I.ExpNeq t) args'
CfExpr (I.ExpOp (I.ExpNeq t) args') -> I.ExpOp (I.ExpEq t) args'
CfExpr (I.ExpOp (I.ExpGt orEq t) args') -> I.ExpOp (I.ExpLt (not orEq) t) args'
CfExpr (I.ExpOp (I.ExpLt orEq t) args') -> I.ExpOp (I.ExpGt (not orEq) t) args'
_ -> noop
-- Conjunction: evaluate when both sides are literals; a single literal side
-- either decides the result (False) or is dropped (True).
I.ExpAnd
| CfBool lb <- arg0 args
, CfBool rb <- arg1 args
-> I.ExpLit (I.LitBool (lb && rb))
| CfBool lb <- arg0 args
-> if lb then toExpr $ arg1 args else I.ExpLit (I.LitBool False)
| CfBool rb <- arg1 args
-> if rb then toExpr $ arg0 args else I.ExpLit (I.LitBool False)
| otherwise -> noop
-- Disjunction: dual of the conjunction rules above.
I.ExpOr
| CfBool lb <- arg0 args
, CfBool rb <- arg1 args
-> I.ExpLit (I.LitBool (lb || rb))
| CfBool lb <- arg0 args
-> if lb then I.ExpLit (I.LitBool True) else toExpr $ arg1 args
| CfBool rb <- arg1 args
-> if rb then I.ExpLit (I.LitBool True) else toExpr $ arg0 args
| otherwise -> noop
-- Multiplication: 0, 1 and -1 on either side, and hoisting a negation out
-- of either factor, before falling back to numeric evaluation.
-- NOTE(review): 0 * x ==> 0 discards x; for floating-point x that is NaN or
-- infinite the exact result would be NaN -- confirm this is intended.
I.ExpMul
| isLitValue 0 $ arg0 args -> toExpr $ arg0 args
| isLitValue 1 $ arg0 args -> toExpr $ arg1 args
| isLitValue (-1) $ arg0 args -> cfOp' ty I.ExpNegate [arg1 args]
| CfExpr (I.ExpOp I.ExpNegate [e']) <- arg0 args -> cfOp' ty I.ExpNegate $ mkCfArgs ty [cfOp' ty I.ExpMul $ mkCfArgs ty [e'] ++ [arg1 args]]
| isLitValue 0 $ arg1 args -> toExpr $ arg1 args
| isLitValue 1 $ arg1 args -> toExpr $ arg0 args
| isLitValue (-1) $ arg1 args -> cfOp' ty I.ExpNegate [arg0 args]
| CfExpr (I.ExpOp I.ExpNegate [e']) <- arg1 args -> cfOp' ty I.ExpNegate $ mkCfArgs ty [cfOp' ty I.ExpMul $ arg0 args : mkCfArgs ty [e']]
| otherwise -> goNum
-- Addition: drop a zero operand; x + (-y) ==> x - y.
I.ExpAdd
| isLitValue 0 $ arg0 args -> toExpr $ arg1 args
| isLitValue 0 $ arg1 args -> toExpr $ arg0 args
| CfExpr (I.ExpOp I.ExpNegate [e']) <- arg1 args -> cfOp' ty I.ExpSub $ arg0 args : mkCfArgs ty [e']
| otherwise -> goNum
-- Subtraction: 0 - x ==> -x; x - 0 ==> x; x - (-y) ==> x + y.
I.ExpSub
| isLitValue 0 $ arg0 args -> cfOp' ty I.ExpNegate [arg1 args]
| isLitValue 0 $ arg1 args -> toExpr $ arg0 args
| CfExpr (I.ExpOp I.ExpNegate [e']) <- arg1 args -> cfOp' ty I.ExpAdd $ arg0 args : mkCfArgs ty [e']
| otherwise -> goNum
-- Negation: -(-x) ==> x; -(a - b) ==> b - a.
I.ExpNegate -> case arg0 args of
CfExpr (I.ExpOp I.ExpNegate [e']) -> e'
CfExpr (I.ExpOp I.ExpSub [e1, e2]) -> cfOp' ty I.ExpSub $ mkCfArgs ty [e2, e1]
_ -> goNum
I.ExpAbs -> goNum
I.ExpSignum -> goNum
I.ExpDiv -> goI2
I.ExpMod -> goI2
I.ExpRecip -> goF
I.ExpIsNan _ -> goFB
I.ExpIsInf _ -> goFB
I.ExpFExp -> goF
I.ExpFSqrt -> goF
I.ExpFLog -> goF
I.ExpFPow -> goF
I.ExpFLogBase -> goF
I.ExpFSin -> goF
I.ExpFCos -> goF
I.ExpFTan -> goF
I.ExpFAsin -> goF
I.ExpFAcos -> goF
I.ExpFAtan -> goF
I.ExpFAtan2 -> goF
I.ExpFSinh -> goF
I.ExpFCosh -> goF
I.ExpFTanh -> goF
I.ExpFAsinh -> goF
I.ExpFAcosh -> goF
I.ExpFAtanh -> goF
I.ExpBitAnd -> toExpr (cfBitAnd ty args)
I.ExpBitOr -> toExpr (cfBitOr ty args)
-- The remaining operators are never simplified here.
I.ExpRoundF -> noop
I.ExpCeilF -> noop
I.ExpFloorF -> noop
I.ExpBitXor -> noop
I.ExpBitComplement -> noop
I.ExpBitShiftL -> noop
I.ExpBitShiftR -> noop
where
-- Rebuild the application unchanged from its (already folded) arguments.
noop = I.ExpOp op $ map toExpr args
-- Delegates to evaluators that are not defined in this part of the file
-- (presumably imported from Ivory.Opts.ConstFoldComp -- confirm).
goI2 = toExpr (cfIntOp2 ty op args)
goF = toExpr (cfFloating op args)
goFB = toExpr (cfFloatingB op args)
-- Evaluate a comparison on literal operands via 'cfOrd2'.
cfOrd = toExpr (cfOrd2 op args)
-- Use the shortcut check's verdict if it has one, else fall back to cfOrd.
goOrd ty' chk args' = fromOrdChecks cfOrd (chk ty' args')
goNum = toExpr (cfNum ty op args)
-- | Entry point for folding an operator application.
--
-- For the comparison, boolean and basic arithmetic operators listed below,
-- first try to simplify through a conditional operand ('unOpLift' for unary
-- operators, 'binOpLift' for binary ones). Every other operator goes
-- straight to 'cfOp'.
liftChoice :: CopyMap -> I.Type -> I.ExpOp -> [I.Expr] -> I.Expr
liftChoice copies ty op args = case op of
  -- Unary operators.
  I.ExpNot{}  -> viaUnary
  I.ExpNegate -> viaUnary
  I.ExpAbs    -> viaUnary
  I.ExpSignum -> viaUnary
  -- Binary operators.
  I.ExpEq{}   -> viaBinary
  I.ExpNeq{}  -> viaBinary
  I.ExpGt{}   -> viaBinary
  I.ExpLt{}   -> viaBinary
  I.ExpAnd{}  -> viaBinary
  I.ExpOr{}   -> viaBinary
  I.ExpMul    -> viaBinary
  I.ExpAdd    -> viaBinary
  I.ExpSub    -> viaBinary
  -- Everything else.
  _           -> cfOp copies ty op args
  where
    viaUnary  = unOpLift copies ty op args
    viaBinary = binOpLift copies ty op args
-- | Fold a unary operator whose operand may be a conditional.
--
-- If the operand is @c ? x1 : x2@ and applying the operator to @x1@ and to
-- @x2@ folds to the same expression, the condition is irrelevant and that
-- expression is the result. Otherwise the application is folded normally.
unOpLift :: CopyMap -> I.Type -> I.ExpOp -> [I.Expr] -> I.Expr
unOpLift copies ty op args
  | I.ExpOp I.ExpCond [_, onTrue, onFalse] <- arg0 args
  , let lifted = applyTo onTrue
  , lifted == applyTo onFalse
  = lifted
  | otherwise
  = cfOp copies ty op args
  where
    -- The operator applied to a single branch, folded recursively.
    applyTo x = liftChoice copies ty op [x]
-- | Fold a binary operator where one operand may be a conditional.
--
-- If the left operand is @c ? x1 : x2@ and the operator folds to the same
-- expression for both branches (with the right operand fixed), that
-- expression is the result. The right operand is examined the same way, but
-- only when the left one is not a conditional. In every other case the
-- application is folded normally.
binOpLift :: CopyMap -> I.Type -> I.ExpOp -> [I.Expr] -> I.Expr
binOpLift copies ty op args =
  case (lhs, rhs) of
    (I.ExpOp I.ExpCond [_, x1, x2], _) ->
      agreeOr (apply x1 rhs) (apply x2 rhs)
    (_, I.ExpOp I.ExpCond [_, x1, x2]) ->
      agreeOr (apply lhs x1) (apply lhs x2)
    _ -> plain
  where
    lhs = arg0 args
    rhs = arg1 args
    -- The operator applied to a concrete pair of operands, folded recursively.
    apply a b = liftChoice copies ty op [a, b]
    -- Ordinary folding of the original application.
    plain = cfOp copies ty op args
    -- Keep the lifted result only when both branches agree.
    agreeOr onTrue onFalse
      | onTrue == onFalse = onTrue
      | otherwise         = plain
-- | Try to decide @l >= r@ without knowing both operands.
--
-- Returns @Just True@ when the left operand is an integer literal equal to
-- @toMaxSize t@, or the right operand is an integer literal equal to
-- @toMinSize t@ (presumably the largest and smallest values of type @t@ --
-- see "Ivory.Language.Cast"). Returns 'Nothing' when neither shortcut
-- applies, including when the type has no such bound.
--
-- Calls 'err' unless given exactly two operands.
gteCheck :: I.Type -> [CfVal] -> Maybe Bool
gteCheck t [l,r]
  | CfInteger _ x <- l
  , Just s <- toMaxSize t
  , x == s
  = Just True
  | CfInteger _ y <- r
  , Just s <- toMinSize t
  , y == s
  = Just True
  | otherwise
  = Nothing
-- The message used to name gtCheck (copied from the sibling function), which
-- pointed diagnostics at the wrong function.
gteCheck _ _ = err "wrong number of args to gteCheck."
-- | Try to decide @l > r@ without knowing both operands.
--
-- Returns @Just False@ when the left operand is an integer literal equal to
-- @toMinSize t@, or the right operand is an integer literal equal to
-- @toMaxSize t@. Returns 'Nothing' when neither shortcut applies.
--
-- Calls 'err' unless given exactly two operands.
gtCheck :: I.Type -> [CfVal] -> Maybe Bool
gtCheck t [l, r]
  | l `isLiteralAt` toMinSize t = Just False
  | r `isLiteralAt` toMaxSize t = Just False
  | otherwise                   = Nothing
  where
    -- True when the value is an integer literal equal to the given bound.
    isLiteralAt (CfInteger _ x) (Just bound) = x == bound
    isLiteralAt _               _            = False
gtCheck _ _ = err "wrong number of args to gtCheck."
-- | Combine a shortcut verdict with a fallback expression: a decided verdict
-- becomes a boolean literal, an undecided one yields the fallback.
fromOrdChecks :: I.Expr -> Maybe Bool -> I.Expr
fromOrdChecks expr verdict =
  case verdict of
    Nothing -> expr
    Just b  -> toExpr (CfBool b)
-- | Evaluate an equality or ordering operator on two operands.
--
-- When both operands are literals of the same kind (boolean, integer, float
-- or double) the comparison is computed and returned as a boolean value.
-- Otherwise the application is rebuilt as an expression.
--
-- Calls 'err' unless given exactly two operands, or (when the comparison is
-- actually evaluated) if the operator is not one of Eq/Neq/Gt/Lt.
cfOrd2 :: I.ExpOp
       -> [CfVal]
       -> CfVal
cfOrd2 op [lhs, rhs] =
  case (lhs, rhs) of
    (CfBool x,      CfBool y)      -> CfBool (holds x y)
    (CfInteger _ x, CfInteger _ y) -> CfBool (holds x y)
    (CfFloat x,     CfFloat y)     -> CfBool (holds x y)
    (CfDouble x,    CfDouble y)    -> CfBool (holds x y)
    _ -> CfExpr (I.ExpOp op [toExpr lhs, toExpr rhs])
  where
    -- The comparison denoted by the operator.
    holds :: Ord a => a -> a -> Bool
    holds x y = case op of
      I.ExpEq _      -> x == y
      I.ExpNeq _     -> x /= y
      I.ExpGt orEq _ -> if orEq then x >= y else x > y
      I.ExpLt orEq _ -> if orEq then x <= y else x < y
      _              -> err "bad op to cfOrd2"
cfOrd2 _ _ = err "wrong number of args to cfOrd2"