{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE Rank2Types    #-}

--
-- Constant folder.
--
-- Copyright (C) 2014, Galois, Inc.
-- All rights reserved.
--

module Ivory.Opts.ConstFold
  ( constFold
  ) where

import           Ivory.Opts.ConstFoldComp

import qualified Ivory.Language.Array     as I
import           Ivory.Language.Cast      (toMaxSize, toMinSize)
import qualified Ivory.Language.Syntax    as I

import           Control.Arrow            (second)
import qualified Data.DList               as D
import           Data.Map                 (Map)
import qualified Data.Map                 as Map
import           Data.Maybe
import           MonadLib                 (Id, StateM (..), StateT, runId,
                                           runStateT)

--------------------------------------------------------------------------------
-- Constant folding

-- | Copy-propagation environment: maps a variable to the folded expression it
-- was assigned, so that later occurrences of the variable can be replaced by
-- that expression instead of being read.
type CopyMap = Map I.Var I.Expr

-- | Expression to expression optimization.  Given the current copy map and
-- the type the expression is used at, return an optimized expression.
type ExprOpt = CopyMap -> I.Type -> I.Expr -> I.Expr

-- | Constant-fold (and copy-propagate) the body of a procedure, using 'cf'
-- as the expression optimizer.
constFold :: I.Proc -> I.Proc
constFold p = procFold cf p

-- | Run the expression optimizer @opt@ over a procedure's body, starting
-- from an empty copy map.  The procedure's symbol is passed down as context
-- for error messages.
procFold :: ExprOpt -> I.Proc -> I.Proc
procFold opt proc = proc { I.procBody = D.toList folded }
  where
  folded = blockFold (I.procSym proc) opt Map.empty (I.procBody proc)

-- | Fold every statement of a block in order, threading the copy map from
-- one statement to the next.  The final copy map is discarded, so copies
-- recorded inside the block never escape it.
blockFold :: String -> ExprOpt -> CopyMap -> I.Block -> D.DList I.Stmt
blockFold cxt opt copies stmts =
  let (chunks, _finalCopies) =
        runId (runStateT copies (mapM (stmtFold cxt opt) stmts))
   in D.concat chunks

-- | Fold a single statement, threading the copy-propagation map through the
-- enclosing block.
--
-- * Assignments whose folded right-hand side is atomic (symbol, variable,
--   literal, global address, max/min constant) are dropped and recorded in
--   the copy map, so later uses are substituted instead.
-- * Nested blocks (non-constant 'I.IfTE' branches, loop bodies, 'I.Forever')
--   are folded with a snapshot of the current map; copies recorded inside
--   them do not escape.
-- * A conditional whose guard folds to a literal is replaced by the taken
--   branch, folded in the current scope.
--
-- Calls 'error' if an @assert@ or @assume@ folds to False, or if a loop's
-- start expression folds to a boolean literal.  @cxt@ is the procedure name
-- used in those messages.
stmtFold :: String -> ExprOpt -> I.Stmt -> StateT CopyMap Id (D.DList I.Stmt)
stmtFold cxt opt stmt =
  case stmt of
    I.IfTE _ [] []       -> return D.empty
    -- Empty then-branch: negate the condition and swap the branches.
    I.IfTE e [] b1       -> stmtFold cxt opt $ I.IfTE (I.ExpOp I.ExpNot [e]) b1 []
    I.IfTE e b0 b1       -> do
      copies <- get
      case opt copies I.TyBool e of
        -- Statically known condition: splice in the taken branch.
        I.ExpLit (I.LitBool b) -> fmap D.concat $ mapM (stmtFold cxt opt) $ if b then b0 else b1
        e'                     -> return $ D.singleton $ I.IfTE e' (newFold copies b0) (newFold copies b1)
    I.Assert e           -> foldCheck "assert()" I.Assert e
    I.CompilerAssert e   -> do
      copies <- get
      case opt copies I.TyBool e of
        -- It's OK to have false but unreachable compiler asserts.
        I.ExpLit (I.LitBool b) | b -> return D.empty
        e'                         -> return $ D.singleton (I.CompilerAssert e')
    I.Assume e           -> foldCheck "assume()" I.Assume e

    I.Return e           -> do
      copies <- get
      return $ D.singleton $ I.Return (typedFold opt copies e)
    I.ReturnVoid         -> return $ D.singleton stmt
    I.Deref t var e      -> do
      copies <- get
      return $ D.singleton $ I.Deref t var (opt copies t e)
    I.Store t e0 e1      -> do
      copies <- get
      return $ D.singleton $ I.Store t (opt copies t e0) (opt copies t e1)

    I.Assign t v e       -> do
      copies <- get
      let e' = opt copies t e
      let copyProp = set (Map.insert v e' copies) >> return D.empty
      -- Atomic results are substituted into later uses instead of assigned.
      case e' of
        I.ExpSym{}          -> copyProp
        I.ExpVar{}          -> copyProp
        I.ExpLit{}          -> copyProp
        I.ExpAddrOfGlobal{} -> copyProp
        I.ExpMaxMin{}       -> copyProp
        _                   -> return $ D.singleton $ I.Assign t v e'

    I.Call t mv c tys    -> do
      copies <- get
      return $ D.singleton $ I.Call t mv c (map (typedFold opt copies) tys)
    I.Local t var i      -> do
      copies <- get
      return $ D.singleton $ I.Local t var $ constFoldInits copies i
    I.RefCopy t e0 e1    -> do
      copies <- get
      return $ D.singleton $ I.RefCopy t (opt copies t e0) (opt copies t e1)
    I.RefZero t e0       -> do
      copies <- get
      return $ D.singleton $ I.RefZero t (opt copies t e0)
    I.AllocRef{}         -> return $ D.singleton stmt
    I.Loop m v e incr blk' -> do
      copies <- get
      let ty    = I.ixRep
          -- Fold the start expression once and reuse it in the result.
          start = opt copies ty e
      case start of
        I.ExpLit (I.LitBool b) ->
          if b then error $ "Constant folding evaluated True expression "
                          ++ "in a loop bound.  The loop will never terminate!"
               else error $ "Constant folding evaluated False expression "
                          ++ "in a loop bound.  The loop will never execute!"
        _                      ->
          return $ D.singleton $ I.Loop m v start (loopIncrFold (opt copies ty) incr)
                        (newFold copies blk')
    I.Break              -> return $ D.singleton stmt
    I.Forever b          -> do
      copies <- get
      return $ D.singleton $ I.Forever (newFold copies b)
    I.Comment{}          -> return $ D.singleton stmt
  where
  -- Fold a nested block with a snapshot of the current copy map.
  newFold copies = D.toList . blockFold cxt opt copies

  -- Shared handling of 'I.Assert' and 'I.Assume': drop the check when its
  -- condition folds to True, abort when it folds to False, and otherwise
  -- rebuild the statement (with @mk@) around the folded condition.  @what@
  -- names the construct in the error message.
  foldCheck :: String -> (I.Expr -> I.Stmt) -> I.Expr -> StateT CopyMap Id (D.DList I.Stmt)
  foldCheck what mk e = do
    copies <- get
    case opt copies I.TyBool e of
      I.ExpLit (I.LitBool b) ->
        if b then return D.empty
          else error $ "Constant folding evaluated a False " ++ what
                     ++ " in evaluating expression " ++ show e
                     ++ " of function " ++ cxt
      e'                     -> return $ D.singleton (mk e')

-- | Constant-fold every expression inside a local-variable initializer,
-- recursing through struct fields and array elements.
constFoldInits :: CopyMap -> I.Init -> I.Init
constFoldInits copies = go
  where
  go ini = case ini of
    I.InitExpr ty expr   -> I.InitExpr ty (cf copies ty expr)
    I.InitStruct fields  -> I.InitStruct [ (name, go sub) | (name, sub) <- fields ]
    I.InitArray elems b  -> I.InitArray (map go elems) b
    I.InitZero           -> I.InitZero

--------------------------------------------------------------------------------
-- Expressions

-- | Constant folding over expressions.
cf :: ExprOpt
cf copies ty e = case e of
  -- Copy propagation: substitute a recorded copy, if any.
  I.ExpVar v            -> Map.findWithDefault e v copies
  I.ExpOp op args       -> liftChoice copies ty op args
  I.ExpLabel t e0 s     -> I.ExpLabel t (recur t e0) s
  I.ExpIndex t e0 t1 e1 -> I.ExpIndex t (recur t e0) t1 (recur t1 e1)
  I.ExpSafeCast t e0    -> foldCast t (recur t e0)
  I.ExpToIx e0 maxSz    -> foldToIx (recur I.ixRep e0) maxSz
  -- Forms with nothing to fold are returned untouched.
  I.ExpSym{}            -> e
  I.ExpExtern{}         -> e
  I.ExpLit{}            -> e
  I.ExpAddrOfGlobal{}   -> e
  I.ExpMaxMin{}         -> e
  I.ExpSizeOf{}         -> e
  where
  recur = cf copies

  -- A literal operand is returned as-is, dropping the cast.
  foldCast t e0' = case destLit e0' of
    Just _  -> e0'
    Nothing -> I.ExpSafeCast t e0'

  -- An integer literal is reduced modulo the index bound with `rem`.
  foldToIx e0' maxSz = case destIntegerLit e0' of
    Just i  -> I.ExpLit (I.LitInteger (i `rem` maxSz))
    Nothing -> I.ExpToIx e0' maxSz
    I.ExpSizeOf{}         -> e

-- | Apply an expression optimizer to the bound of a loop increment,
-- preserving its direction.
loopIncrFold :: (I.Expr -> I.Expr) -> I.LoopIncr -> I.LoopIncr
loopIncrFold opt (I.IncrTo bound) = I.IncrTo (opt bound)
loopIncrFold opt (I.DecrTo bound) = I.DecrTo (opt bound)

--------------------------------------------------------------------------------

-- | Optimize the value of a typed expression at its recorded type, keeping
-- the type annotation unchanged.
typedFold :: ExprOpt -> CopyMap -> I.Typed I.Expr -> I.Typed I.Expr
typedFold opt copies (I.Typed ty val) = I.Typed ty (opt copies ty val)

-- | First element of an operator's argument list (partial: errors if empty).
arg0 :: [a] -> a
arg0 xs = xs !! 0

-- | Second element of an operator's argument list (partial).
arg1 :: [a] -> a
arg1 xs = xs !! 1

-- | Third element of an operator's argument list (partial).
arg2 :: [a] -> a
arg2 xs = xs !! 2

-- | Reconstruct an operator, folding away operations when possible.
cfOp :: CopyMap -> I.Type -> I.ExpOp -> [I.Expr] -> I.Expr
cfOp copies ty op args = cfOp' ty op foldedArgs
  where
  -- A conditional folds its guard as a boolean and its branches at the
  -- result type; every other operator folds all arguments at one type.
  foldedArgs
    | I.ExpCond <- op
    , (cond, branches) <- splitAt 1 args
    = foldAt I.TyBool cond ++ foldAt ty branches
    | otherwise
    = foldAt operandTy args

  -- Comparisons and NaN/Inf tests record their operand type in the
  -- operator; everything else uses the result type.
  operandTy = case op of
    I.ExpEq t    -> t
    I.ExpNeq t   -> t
    I.ExpGt _ t  -> t
    I.ExpLt _ t  -> t
    I.ExpIsNan t -> t
    I.ExpIsInf t -> t
    _            -> ty

  foldAt ty' es = mkCfArgs ty' [ cf copies ty' x | x <- es ]

-- | Simplify an operator applied to already-folded arguments.
--
-- Applies boolean simplifications of conditionals, algebraic identities
-- (@x*1@, @x+0@, @-(-x)@, ...) and bound checks on comparisons, and
-- otherwise delegates to the literal evaluators ('cfNum', 'cfIntOp2',
-- 'cfFloating', ...).  When nothing applies, the operator is rebuilt
-- unchanged ('noop').
--
-- Rewrites that only hold for totally ordered, finite values (@0*x = 0@ and
-- negating an ordering comparison) are skipped for 'I.TyFloat' and
-- 'I.TyDouble', where NaN and infinities make them unsound under IEEE 754.
cfOp' :: I.Type -> I.ExpOp -> [CfVal] -> I.Expr
cfOp' ty op args = case op of
  I.ExpEq _  -> cfOrd
  I.ExpNeq _ -> cfOrd
  I.ExpCond
    | CfBool b <- arg0 args
    -> if b then a1 else a2
    -- If either branch is a boolean literal, reduce to logical AND or OR.
    | ty == I.TyBool && arg1 args == CfBool True -> cfOp' ty I.ExpOr [arg0 args, arg2 args]
    | ty == I.TyBool && arg1 args == CfBool False -> cfOp' ty I.ExpAnd $ mkCfArgs ty [cfOp' ty I.ExpNot [arg0 args]] ++ [arg2 args]
    | ty == I.TyBool && arg2 args == CfBool True -> cfOp' ty I.ExpOr $ mkCfArgs ty [cfOp' ty I.ExpNot [arg0 args]] ++ [arg1 args]
    | ty == I.TyBool && arg2 args == CfBool False -> cfOp' ty I.ExpAnd [arg0 args, arg1 args]
    -- If both branches have the same result, we don't care about the branch
    -- condition.  XXX This can be expensive
    | a1 == a2
    -> a1
    | otherwise -> noop
    where a1 = toExpr $ arg1 args
          a2 = toExpr $ arg2 args
  I.ExpGt orEq t
    | orEq      -> goOrd t gteCheck args
    | otherwise -> goOrd t gtCheck args
  I.ExpLt orEq t
    | orEq      -> goOrd t gteCheck (reverse args)
    | otherwise -> goOrd t gtCheck  (reverse args)
  I.ExpNot -> case arg0 args of
    CfBool b -> I.ExpLit (I.LitBool (not b))
    -- !(a == b) is (a != b) and vice versa, even for IEEE floats.
    CfExpr (I.ExpOp (I.ExpEq t) args') -> I.ExpOp (I.ExpNeq t) args'
    CfExpr (I.ExpOp (I.ExpNeq t) args') -> I.ExpOp (I.ExpEq t) args'
    -- Flipping an ordering comparison requires a total order: for floats,
    -- !(x > NaN) is true but (x <= NaN) is false.
    CfExpr (I.ExpOp (I.ExpGt orEq t) args')
      | not (isFloatTy t) -> I.ExpOp (I.ExpLt (not orEq) t) args'
    CfExpr (I.ExpOp (I.ExpLt orEq t) args')
      | not (isFloatTy t) -> I.ExpOp (I.ExpGt (not orEq) t) args'
    _ -> noop
  I.ExpAnd
    | CfBool lb <- arg0 args
    , CfBool rb <- arg1 args
    -> I.ExpLit (I.LitBool (lb && rb))
    | CfBool lb <- arg0 args
    -> if lb then toExpr $ arg1 args else I.ExpLit (I.LitBool False)
    | CfBool rb <- arg1 args
    -> if rb then toExpr $ arg0 args else I.ExpLit (I.LitBool False)
    | otherwise -> noop
  I.ExpOr
    | CfBool lb <- arg0 args
    , CfBool rb <- arg1 args
    -> I.ExpLit (I.LitBool (lb || rb))
    | CfBool lb <- arg0 args
    -> if lb then I.ExpLit (I.LitBool True) else toExpr $ arg1 args
    | CfBool rb <- arg1 args
    -> if rb then I.ExpLit (I.LitBool True) else toExpr $ arg0 args
    | otherwise -> noop

  I.ExpMul
    -- 0 * x = 0 does not hold for floats (0 * NaN and 0 * Inf are NaN), so
    -- the zero rules are restricted to non-floating types.
    | not floatTy, isLitValue 0 $ arg0 args -> toExpr $ arg0 args
    | isLitValue 1 $ arg0 args -> toExpr $ arg1 args
    | isLitValue (-1) $ arg0 args -> cfOp' ty I.ExpNegate [arg1 args]
    | CfExpr (I.ExpOp I.ExpNegate [e']) <- arg0 args -> cfOp' ty I.ExpNegate $ mkCfArgs ty [cfOp' ty I.ExpMul $ mkCfArgs ty [e'] ++ [arg1 args]]
    | not floatTy, isLitValue 0 $ arg1 args -> toExpr $ arg1 args
    | isLitValue 1 $ arg1 args -> toExpr $ arg0 args
    | isLitValue (-1) $ arg1 args -> cfOp' ty I.ExpNegate [arg0 args]
    | CfExpr (I.ExpOp I.ExpNegate [e']) <- arg1 args -> cfOp' ty I.ExpNegate $ mkCfArgs ty [cfOp' ty I.ExpMul $ arg0 args : mkCfArgs ty [e']]
    | otherwise -> goNum

  I.ExpAdd
    | isLitValue 0 $ arg0 args -> toExpr $ arg1 args
    | isLitValue 0 $ arg1 args -> toExpr $ arg0 args
    | CfExpr (I.ExpOp I.ExpNegate [e']) <- arg1 args -> cfOp' ty I.ExpSub $ arg0 args : mkCfArgs ty [e']
    | otherwise -> goNum

  I.ExpSub
    | isLitValue 0 $ arg0 args -> cfOp' ty I.ExpNegate [arg1 args]
    | isLitValue 0 $ arg1 args -> toExpr $ arg0 args
    | CfExpr (I.ExpOp I.ExpNegate [e']) <- arg1 args -> cfOp' ty I.ExpAdd $ arg0 args : mkCfArgs ty [e']
    | otherwise -> goNum

  I.ExpNegate   -> case arg0 args of
    CfExpr (I.ExpOp I.ExpNegate [e']) -> e'
    CfExpr (I.ExpOp I.ExpSub [e1, e2]) -> cfOp' ty I.ExpSub $ mkCfArgs ty [e2, e1]
    _ -> goNum

  I.ExpAbs      -> goNum
  I.ExpSignum   -> goNum

  I.ExpDiv      -> goI2
  I.ExpMod      -> goI2
  I.ExpRecip    -> goF

  I.ExpIsNan _  -> goFB
  I.ExpIsInf _  -> goFB

  I.ExpFExp     -> goF
  I.ExpFSqrt    -> goF
  I.ExpFLog     -> goF
  I.ExpFPow     -> goF
  I.ExpFLogBase -> goF
  I.ExpFSin     -> goF
  I.ExpFCos     -> goF
  I.ExpFTan     -> goF
  I.ExpFAsin    -> goF
  I.ExpFAcos    -> goF
  I.ExpFAtan    -> goF
  I.ExpFAtan2   -> goF
  I.ExpFSinh    -> goF
  I.ExpFCosh    -> goF
  I.ExpFTanh    -> goF
  I.ExpFAsinh   -> goF
  I.ExpFAcosh   -> goF
  I.ExpFAtanh   -> goF

  I.ExpBitAnd        -> toExpr (cfBitAnd ty args)
  I.ExpBitOr         -> toExpr (cfBitOr ty args)

  -- Unimplemented right now
  I.ExpRoundF        -> noop
  I.ExpCeilF         -> noop
  I.ExpFloorF        -> noop
  I.ExpBitXor        -> noop
  I.ExpBitComplement -> noop
  I.ExpBitShiftL     -> noop
  I.ExpBitShiftR     -> noop

  where
  noop          = I.ExpOp op $ map toExpr args
  goI2          = toExpr (cfIntOp2 ty op args)
  goF           = toExpr (cfFloating op args)
  goFB          = toExpr (cfFloatingB op args)
  cfOrd         = toExpr (cfOrd2 op args)
  goOrd ty' chk args' = fromOrdChecks cfOrd (chk ty' args')
  goNum         = toExpr (cfNum ty op args)
  -- Floating-point types, where NaN/infinities break integer identities.
  isFloatTy t   = t == I.TyFloat || t == I.TyDouble
  floatTy       = isFloatTy ty

--------------------------------------------------------------------------------

-- | Lift a conditional (@ExpCond@) choice out of an operator's operands to see if we can further optimize.
liftChoice :: CopyMap -> I.Type -> I.ExpOp -> [I.Expr] -> I.Expr
liftChoice copies ty op args
  | unary     = unOpLift  copies ty op args
  | binary    = binOpLift copies ty op args
  | otherwise = cfOp copies ty op args
  where
  -- Unary operators whose operand may be a conditional worth lifting.
  unary = case op of
    I.ExpNot    -> True
    I.ExpNegate -> True
    I.ExpAbs    -> True
    I.ExpSignum -> True
    _           -> False

  -- Binary operators handled by lifting.  I.ExpCond itself needs no
  -- lifting, and I.ExpDiv is deliberately absent: it is NOT SAFE TO LIFT.
  -- The remaining operators (ExpMod, ExpRecip, ExpIsNan/ExpIsInf, the
  -- floating-point functions, the bitwise and rounding operators) are not
  -- handled yet and fall back to plain folding; add them as needed.
  binary = case op of
    I.ExpEq{}  -> True
    I.ExpNeq{} -> True
    I.ExpGt{}  -> True
    I.ExpLt{}  -> True
    I.ExpAnd   -> True
    I.ExpOr    -> True
    I.ExpMul   -> True
    I.ExpAdd   -> True
    I.ExpSub   -> True
    _          -> False


--XXX the equality comparisons below can be expensive.  Hashmap?  Also, awkward
-- style, but I want sharing of (liftChoice ...) expression in branch condition
-- and result.
-- | If the sole operand is a conditional @c ? x1 : x2@ and applying @op@ to
-- each branch folds to the same expression, return that expression (the
-- condition no longer matters).  Otherwise fold normally with 'cfOp'.
unOpLift :: CopyMap -> I.Type -> I.ExpOp -> [I.Expr] -> I.Expr
unOpLift copies ty op args
  | I.ExpOp I.ExpCond [_, x1, x2] <- arg0 args
  , let lifted = applyTo x1
  , lifted == applyTo x2
  = lifted
  | otherwise
  = cfOp copies ty op args
  where
  applyTo x = liftChoice copies ty op [x]

-- | Binary counterpart of 'unOpLift'.  If the left operand is a conditional,
-- try to lift it; only when the left operand is not a conditional is the
-- right operand tried.  Falls back to 'cfOp' when the two lifted branches
-- differ.
binOpLift :: CopyMap -> I.Type -> I.ExpOp -> [I.Expr] -> I.Expr
binOpLift copies ty op args
  | I.ExpOp I.ExpCond [_, x1, x2] <- lhs = pickSame (\x -> apply2 x rhs) x1 x2
  | I.ExpOp I.ExpCond [_, x1, x2] <- rhs = pickSame (apply2 lhs) x1 x2
  | otherwise                            = fallback
  where
  lhs        = arg0 args
  rhs        = arg1 args
  apply2 a b = liftChoice copies ty op [a, b]
  fallback   = cfOp copies ty op args
  -- Share the first branch's result between the comparison and the answer.
  pickSame f x1 x2 =
    let lifted = f x1
     in if lifted == f x2 then lifted else fallback

--------------------------------------------------------------------------------
-- Constant-folded values

-- | Check if we're comparing the max or min bound for >= and optimize.
-- Assumes args are already folded.
gteCheck :: I.Type -> [CfVal] -> Maybe Bool
gteCheck t [l,r]
  -- forall a. max >= a
  | CfInteger _ x <- l
  , Just s <- toMaxSize t
  , x == s
  = Just True
  -- forall a. a >= min
  | CfInteger _ y <- r
  , Just s <- toMinSize t
  , y == s
  = Just True
  -- Neither side is a type bound: nothing can be concluded.
  | otherwise
  = Nothing
-- Report the arity error under this function's own name.
gteCheck _ _ = err "wrong number of args to gteCheck."

-- | Check if we're comparing the max or min bound for > and optimize.
-- Assumes args are already folded.
gtCheck :: I.Type -> [CfVal] -> Maybe Bool
gtCheck t [l, r]
  -- not (min > a) and not (a > max), for every a.
  | isBound toMinSize l || isBound toMaxSize r = Just False
  | otherwise                                   = Nothing
  where
  -- Is the value an integer literal equal to the given bound of type t?
  isBound bound (CfInteger _ v) = bound t == Just v
  isBound _     _               = False
gtCheck _ _ = err "wrong number of args to gtCheck."

-- | Use a decided bound check as a boolean literal, or fall back to the
-- given expression when the check was inconclusive.
fromOrdChecks :: I.Expr -> Maybe Bool -> I.Expr
fromOrdChecks expr check = case check of
  Nothing      -> expr
  Just decided -> toExpr (CfBool decided)

-- | Apply a binary operation that requires an ord instance.
cfOrd2 :: I.ExpOp
       -> [CfVal]
       -> CfVal
cfOrd2 op [l, r] =
  case (l, r) of
    (CfBool a,      CfBool b)      -> CfBool (compareWith a b)
    (CfInteger _ a, CfInteger _ b) -> CfBool (compareWith a b)
    (CfFloat a,     CfFloat b)     -> CfBool (compareWith a b)
    (CfDouble a,    CfDouble b)    -> CfBool (compareWith a b)
    -- Mixed or non-literal operands: rebuild the comparison.
    _                              -> CfExpr (I.ExpOp op (map toExpr [l, r]))
  where
  -- Polymorphic so it can be reused at each literal type above.
  compareWith :: Ord a => a -> a -> Bool
  compareWith = case op of
    I.ExpEq _       -> (==)
    I.ExpNeq _      -> (/=)
    I.ExpGt True  _ -> (>=)
    I.ExpGt False _ -> (>)
    I.ExpLt True  _ -> (<=)
    I.ExpLt False _ -> (<)
    _               -> err "bad op to cfOrd2"
cfOrd2 _ _ = err "wrong number of args to cfOrd2"