{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE Rank2Types #-}

module Ivory.Opts.ConstFold
  ( constFold
  ) where

import qualified Ivory.Language.Syntax.AST as I
import qualified Ivory.Language.Syntax.Type as I
import Ivory.Language.Cast (toMaxSize, toMinSize)

import Control.Monad (mzero,msum)
import Data.Maybe
import Data.List
import Data.Word
import Data.Int
import qualified Data.DList as D

--------------------------------------------------------------------------------
-- Constant folding
--------------------------------------------------------------------------------

-- | Constant-fold every expression in the body of a procedure, using 'cf' as
-- the expression optimizer.  Statements whose conditions become literals are
-- simplified too (see 'stmtFold').
constFold :: I.Proc -> I.Proc
constFold p = procFold cf p

-- | Expression to expression optimization.  The 'I.Type' argument is the type
-- the caller evaluates the expression at (e.g. 'I.TyBool' for conditions, the
-- statement's type for assignments); the result replaces the expression.
type ExprOpt = I.Type -> I.Expr -> I.Expr

-- | Constant folding over a single expression.
--
-- @ty@ is the type the top-level expression is evaluated at; it is passed to
-- 'cfOp' so that operators are folded at the right width.  Sub-expressions
-- that carry their own type annotation (labels, indexing, casts) are folded
-- at that annotated type instead.
cf :: ExprOpt
cf ty e =
  case e of
    I.ExpSym{} -> e
    I.ExpVar{} -> e
    I.ExpLit{} -> e

    I.ExpOp op args       -> cfOp ty op args

    I.ExpLabel t e0 s     -> I.ExpLabel t (cf t e0) s

    -- @t@ annotates the indexed expression @e0@ and @t1@ annotates the index
    -- @e1@, so each is folded at its own type.  Folding the index at @t@
    -- would evaluate index arithmetic at the wrong type (and can hit the
    -- "bad type" internal errors in cfNum / cfIntOp2).
    I.ExpIndex t e0 t1 e1 -> I.ExpIndex t (cf t e0) t1 (cf t1 e1)

    -- If the cast operand folds to a literal, the cast is dropped and the
    -- literal is returned directly.
    I.ExpSafeCast t e0    ->
      let e0' = cf t e0
       in fromMaybe (I.ExpSafeCast t e0') $ do
            _ <- destLit e0'
            return e0'

    -- The operand is folded as an Int32; a literal operand is reduced with
    -- `rem` by the index bound.
    I.ExpToIx e0 maxSz    ->
      let ty' = I.TyInt I.Int32 in
      let e0' = cf ty' e0 in
      case destIntegerLit e0' of
        Just i  -> I.ExpLit $ I.LitInteger $ i `rem` maxSz
        Nothing -> I.ExpToIx e0' maxSz

    I.ExpAddrOfGlobal{}   -> e
    I.ExpMaxMin{}         -> e

-- | Fold every statement of a procedure's body with 'stmtFold', using @opt@
-- for expressions.  The procedure's symbol is passed along so that folding
-- errors can name the offending function.
procFold :: ExprOpt -> I.Proc -> I.Proc
procFold opt proc = proc { I.procBody = D.toList folded }
  where
  folded = foldl' step D.empty (I.procBody proc)
  step   = stmtFold (I.procSym proc) opt

-- | Fold one statement and append the result to the block built so far.
--
-- @cxt@ is the name of the enclosing procedure (used only in error
-- messages), @opt@ folds expressions, and @blk@ is the already-folded prefix
-- of the current block.  Conditions that fold to a boolean literal are
-- resolved statically:
--
--   * 'I.IfTE' is replaced by the folded statements of the taken branch;
--   * a true 'I.Assert', 'I.Assume' or 'I.CompilerAssert' is dropped;
--   * a false 'I.Assert' or 'I.Assume' aborts with 'error';
--   * a false 'I.CompilerAssert' is kept (it may be unreachable).
--
-- Nested blocks ('I.IfTE', 'I.Loop', 'I.Forever') are folded recursively.
stmtFold :: String -> ExprOpt -> D.DList I.Stmt -> I.Stmt -> D.DList I.Stmt
stmtFold cxt opt blk stmt =
  case stmt of
    I.IfTE e b0 b1       ->
      let e' = opt I.TyBool e in
      case e' of
        -- Statically-known condition: splice in only the taken branch.
        I.ExpLit (I.LitBool b) -> if b then blk `D.append` (newFold' b0)
                                    else blk `D.append` (newFold' b1)
        _                      -> snoc $ I.IfTE e' (newFold b0) (newFold b1)
    I.Assert e           ->
      let e' = opt I.TyBool e in
      case e' of
        I.ExpLit (I.LitBool b) ->
          if b then blk
            else error $ "Constant folding evaluated a False assert()"
                       ++ " in evaluating expression " ++ show e
                       ++ " of function " ++ cxt
        _                      -> snoc (I.Assert e')
    I.CompilerAssert e        ->
      let e' = opt I.TyBool e in
      let go = snoc (I.CompilerAssert e') in
      case e' of
        I.ExpLit (I.LitBool b) ->
          -- It's OK to have false but unreachable compiler asserts.
          if b then blk else go
        _                      -> go
    I.Assume e           ->
      let e' = opt I.TyBool e in
      case e' of
        I.ExpLit (I.LitBool b) ->
          if b then blk
            else error $ "Constant folding evaluated a False assume()"
                       ++ " in evaluating expression " ++ show e
                       ++ " of function " ++ cxt
        _                      -> snoc (I.Assume e')

    I.Return e           -> snoc $ I.Return (typedFold opt e)
    I.ReturnVoid         -> snoc I.ReturnVoid
    I.Deref t var e      -> snoc $ I.Deref t var (opt t e)
    I.Store t e0 e1      -> snoc $ I.Store t (opt t e0) (opt t e1)
    I.Assign t v e       -> snoc $ I.Assign t v (opt t e)
    I.Call t mv c tys    -> snoc $ I.Call t mv c (map (typedFold opt) tys)
    I.Local{}            -> snoc stmt
    I.RefCopy t e0 e1    -> snoc $ I.RefCopy t (opt t e0) (opt t e1)
    I.AllocRef{}         -> snoc stmt
    I.Loop v e incr blk' ->
      -- Fold the loop expression once; the same result is both checked for
      -- a boolean literal and stored in the rebuilt loop.
      let ty = I.TyInt I.Int32
          e' = opt ty e
       in case e' of
            I.ExpLit (I.LitBool b) ->
              if b then error $ "Constant folding evaluated True expression "
                              ++ "in a loop bound.  The loop will never terminate!"
                   else error $ "Constant folding evaluated False expression "
                              ++ "in a loop bound.  The loop will never execute!"
            _                      ->
              snoc $ I.Loop v e' (loopIncrFold (opt ty) incr)
                            (newFold blk')
    I.Break              -> snoc I.Break
    I.Forever b          -> snoc $ I.Forever (newFold b)
  where sf       = stmtFold cxt opt
        -- Fold a nested block from scratch (as a DList / as a list).
        newFold' = foldl' sf D.empty
        newFold  = D.toList . newFold'
        -- Append a single rebuilt statement to the current block.
        snoc     = (blk `D.snoc`)

-- | Apply an expression optimizer to the bound of a loop increment,
-- preserving its direction ('I.IncrTo' or 'I.DecrTo').
loopIncrFold :: (I.Expr -> I.Expr) -> I.LoopIncr -> I.LoopIncr
loopIncrFold f (I.IncrTo bound) = I.IncrTo (f bound)
loopIncrFold f (I.DecrTo bound) = I.DecrTo (f bound)

-- | Fold the value of a typed expression at the type it is tagged with.
typedFold :: ExprOpt -> I.Typed I.Expr -> I.Typed I.Expr
typedFold opt (I.Typed ty val) = I.Typed ty (opt ty val)

-- | First element of an operator's argument list (partial: fails on a list
-- that is too short).
arg0 :: [a] -> a
arg0 = (!! 0)

-- | Second element of an operator's argument list (partial).
arg1 :: [a] -> a
arg1 = (!! 1)

-- | Third element of an operator's argument list (partial).
arg2 :: [a] -> a
arg2 = (!! 2)

-- | Fold each argument expression at type @ty@.
mkArgs :: I.Type -> [I.Expr] -> [I.Expr]
mkArgs ty args = [ cf ty arg | arg <- args ]

-- | Classify each expression as a literal 'CfVal' or an opaque 'CfExpr'.
mkCfArgs :: [I.Expr] -> [CfVal]
mkCfArgs exprs = [ toCfVal ex | ex <- exprs ]

-- | Extract each expression's boolean literal, if it is one.
mkCfBool :: [I.Expr] -> [Maybe Bool]
mkCfBool exprs = [ destBoolLit ex | ex <- exprs ]

-- | Reconstruct an operator, folding away operations when possible.
--
-- @ty@ is the type the whole operator expression is evaluated at.  The
-- arguments are folded recursively (via 'mkArgs' / 'cf') before the operator
-- itself is simplified.  Operators with no folding rule are rebuilt from
-- their folded arguments by 'noop'.
cfOp :: I.Type -> I.ExpOp -> [I.Expr] -> I.Expr
cfOp ty op args =
  case op of
    -- Equality: operands are folded at their own type @t@, then compared.
    I.ExpEq t  -> cfOrd t
    I.ExpNeq t -> cfOrd t
    -- Conditional: a literal condition selects one of the folded branches.
    I.ExpCond
      | Just b <- arg0 goBoolArgs
      -> if b then arg1 (toExpr' ty) else arg2 (toExpr' ty)
      | otherwise -> noop ty
    -- Orderings first try the min/max-bound shortcuts (gteCheck/gtCheck),
    -- then fall back to comparing literals with cfOrd2 on the original args.
    I.ExpGt orEq t
      | orEq      -> goOrd t gteCheck args
      | otherwise -> goOrd t gtCheck args
    -- a < b is checked as b > a (and a <= b as b >= a) by reversing the
    -- arguments, so the same bound checks can be reused.
    I.ExpLt orEq t
      | orEq      -> goOrd t gteCheck (reverse args)
      | otherwise -> goOrd t gtCheck  (reverse args)
    I.ExpNot
      | Just b <- arg0 goBoolArgs
      -> I.ExpLit (I.LitBool (not b))
      | otherwise -> noop ty
    -- Conjunction: fold fully when both sides are literals; otherwise use
    -- True && x = x and False && x = False with the literal on either side.
    I.ExpAnd
      | Just lb <- arg0 goBoolArgs
      , Just rb <- arg1 goBoolArgs
      -> I.ExpLit (I.LitBool (lb && rb))
      | Just lb <- arg0 goBoolArgs
      -> if lb then arg1 (toExpr' ty) else I.ExpLit (I.LitBool False)
      | Just rb <- arg1 goBoolArgs
      -> if rb then arg0 (toExpr' ty) else I.ExpLit (I.LitBool False)
      | otherwise -> noop ty
    -- Disjunction: the dual identities True || x = True, False || x = x.
    I.ExpOr
      | Just lb <- arg0 goBoolArgs
      , Just rb <- arg1 goBoolArgs
      -> I.ExpLit (I.LitBool (lb || rb))
      | Just lb <- arg0 goBoolArgs
      -> if lb then I.ExpLit (I.LitBool True) else arg1 (toExpr' ty)
      | Just rb <- arg1 goBoolArgs
      -> if rb then I.ExpLit (I.LitBool True) else arg0 (toExpr' ty)
      | otherwise -> noop ty

    -- Num operators (see cfNum).
    I.ExpMul      -> goNum
    I.ExpAdd      -> goNum
    I.ExpSub      -> goNum
    I.ExpNegate   -> goNum
    I.ExpAbs      -> goNum
    I.ExpSignum   -> goNum

    -- Integral division / remainder (see cfIntOp2); recip is floating.
    I.ExpDiv      -> goI2
    I.ExpMod      -> goI2
    I.ExpRecip    -> goF

    -- Floating predicates, folded at the operand type @t@.
    I.ExpIsNan t  -> goFB t
    I.ExpIsInf t  -> goFB t

    -- Floating functions (see cfFloating).
    I.ExpFExp     -> goF
    I.ExpFSqrt    -> goF
    I.ExpFLog     -> goF
    I.ExpFPow     -> goF
    I.ExpFLogBase -> goF
    I.ExpFSin     -> goF
    I.ExpFCos     -> goF
    I.ExpFTan     -> goF
    I.ExpFAsin    -> goF
    I.ExpFAcos    -> goF
    I.ExpFAtan    -> goF
    I.ExpFSinh    -> goF
    I.ExpFCosh    -> goF
    I.ExpFTanh    -> goF
    I.ExpFAsinh   -> goF
    I.ExpFAcosh   -> goF
    I.ExpFAtanh   -> goF

    -- Only identity/annihilator rules for unsigned extremes (see cfBitAnd,
    -- cfBitOr); two arbitrary literals are not combined.
    I.ExpBitAnd        -> toExpr (cfBitAnd ty $ goArgs ty)
    I.ExpBitOr         -> toExpr (cfBitOr ty  $ goArgs ty)

    -- Unimplemented right now
    I.ExpToFloat t     -> noop t
    I.ExpFromFloat t   -> noop t
    I.ExpRoundF        -> noop ty
    I.ExpCeilF         -> noop ty
    I.ExpFloorF        -> noop ty
    I.ExpBitXor        -> noop ty
    I.ExpBitComplement -> noop ty
    I.ExpBitShiftL     -> noop ty
    I.ExpBitShiftR     -> noop ty

  where
  -- Arguments folded at type ty' and classified as CfVal.
  goArgs ty'    = mkCfArgs $ mkArgs ty' args
  -- Arguments folded at type ty', converted back to expressions.
  toExpr'       = map toExpr . goArgs
  -- Arguments folded at TyBool; Just b for each boolean literal.
  goBoolArgs    = mkCfBool $ mkArgs I.TyBool args
  -- No folding rule: rebuild the operator from arguments folded at a type.
  noop          = I.ExpOp op . map toExpr . goArgs
  goI2          = toExpr (cfIntOp2 ty op $ goArgs ty)
  goF           = toExpr (cfFloating op $ goArgs ty)
  goFB ty'      = toExpr (cfFloatingB op $ goArgs ty')
  cfOrd ty'     = toExpr (cfOrd2 op $ goArgs ty')
  -- Use the bound check chk on args' when it decides the comparison,
  -- otherwise fall back to cfOrd.
  goOrd ty' chk args' =
    let args0 = mkCfArgs $ mkArgs ty' args' in
    fromOrdChecks (cfOrd ty') (chk ty' args0)
  goNum         = toExpr (cfNum ty op $ goArgs ty)


-- | Simplify a bitwise AND of two operands of type @ty@.
--
-- AND with the all-ones value of an unsigned type yields the other operand,
-- and AND with zero yields zero.  Anything else is rebuilt as an
-- 'I.ExpBitAnd' expression.
cfBitAnd :: I.Type -> [CfVal] -> CfVal
cfBitAnd ty args = case args of
  [l, r]
    | ones ty l                -> r
    | ones ty r                -> l
    | zeros ty l || zeros ty r -> CfInteger 0
    | otherwise                -> CfExpr (I.ExpOp I.ExpBitAnd (map toExpr [l, r]))
  _ -> err "Wrong number of args to cfBitAnd in constant folder."

-- | Simplify a bitwise OR of two operands of type @ty@.
--
-- OR with zero yields the other operand.  OR with the all-ones (maximum)
-- value of an unsigned type yields that all-ones value, i.e. the operand
-- itself.  Anything else is rebuilt as an 'I.ExpBitOr' expression.
cfBitOr :: I.Type -> [CfVal] -> CfVal
cfBitOr ty [l,r]
  | zeros ty l = r
  | zeros ty r = l
  -- x | max == max: return the all-ones operand itself, not the literal 1.
  | ones ty  l = l
  | ones ty  r = r
  | otherwise  = CfExpr (I.ExpOp I.ExpBitOr [toExpr l, toExpr r])
cfBitOr _ _ = err "Wrong number of args to cfBitOr in constant folder."

-- | True when the value is the integer literal 0 at an unsigned ('I.TyWord')
-- type, i.e. that type's minimum.  False for every other type or value.
zeros :: I.Type -> CfVal -> Bool
zeros ty val = case (ty, val) of
  (I.TyWord{}, CfInteger i) -> i == 0
  _                         -> False

-- | True when the value is an integer literal equal to the maximum of an
-- unsigned ('I.TyWord') type as reported by 'toMaxSize'.  False for every
-- other type (including signed ones) or value.
ones :: I.Type -> CfVal -> Bool
ones ty@I.TyWord{} (CfInteger i) = toMaxSize ty == Just i
ones _             _             = False

-- | Literal expression destructor: the literal inside an 'I.ExpLit', if any.
destLit :: I.Expr -> Maybe I.Literal
destLit (I.ExpLit lit) = return lit
destLit _              = mzero

-- | Boolean literal destructor.
destBoolLit :: I.Expr -> Maybe Bool
destBoolLit ex = case destLit ex of
  Just (I.LitBool b) -> Just b
  _                  -> Nothing

-- | Integer literal destructor.
destIntegerLit :: I.Expr -> Maybe Integer
destIntegerLit ex = case destLit ex of
  Just (I.LitInteger n) -> Just n
  _                     -> Nothing

-- | Float literal destructor.
destFloatLit :: I.Expr -> Maybe Float
destFloatLit ex = case destLit ex of
  Just (I.LitFloat f) -> Just f
  _                   -> Nothing

-- | Double literal destructor.
destDoubleLit :: I.Expr -> Maybe Double
destDoubleLit ex = case destLit ex of
  Just (I.LitDouble d) -> Just d
  _                    -> Nothing

-- Constant-folded Values ------------------------------------------------------

-- | Constant-folded values: a literal lifted into a Haskell value that can be
-- computed on directly, or an expression that was not reduced to a literal.
data CfVal
  = CfBool Bool         -- ^ from an 'I.LitBool' literal
  | CfInteger Integer   -- ^ from an 'I.LitInteger' literal
  | CfFloat Float       -- ^ from an 'I.LitFloat' literal
  | CfDouble Double     -- ^ from an 'I.LitDouble' literal
  | CfExpr I.Expr       -- ^ any expression not reduced to one of the above
    deriving (Show)

-- | Classify an expression: a bool, integer, float or double literal becomes
-- the matching 'CfVal' constructor; anything else is wrapped unchanged in
-- 'CfExpr'.
toCfVal :: I.Expr -> CfVal
toCfVal ex = fromMaybe (CfExpr ex) (msum candidates)
  where
  -- At most one candidate can succeed: each matches a distinct literal kind.
  candidates =
    [ fmap CfBool    (destBoolLit    ex)
    , fmap CfInteger (destIntegerLit ex)
    , fmap CfFloat   (destFloatLit   ex)
    , fmap CfDouble  (destDoubleLit  ex)
    ]

-- | Convert a folded value back to an expression; @toExpr . toCfVal@ is the
-- identity.
toExpr :: CfVal -> I.Expr
toExpr (CfBool b)    = I.ExpLit (I.LitBool b)
toExpr (CfInteger i) = I.ExpLit (I.LitInteger i)
toExpr (CfFloat f)   = I.ExpLit (I.LitFloat f)
toExpr (CfDouble d)  = I.ExpLit (I.LitDouble d)
toExpr (CfExpr ex)   = ex

-- | Check if we're comparing the max or min bound for >= and optimize.
--
-- Returns @Just True@ when the left operand is the literal maximum of @t@ or
-- the right operand is the literal minimum of @t@ (the comparison then holds
-- for any other operand).  Returns 'Nothing' when this check cannot decide.
gteCheck :: I.Type -> [CfVal] -> Maybe Bool
gteCheck t [l,r]
  -- forall a. max >= a
  | CfInteger x <- l
  , Just s <- toMaxSize t
  , x == s = Just True
  -- forall a. a >= min
  | CfInteger y <- r
  , Just s <- toMinSize t
  , y == s = Just True
  | otherwise = Nothing
gteCheck _ _ = err "wrong number of args to gteCheck."

-- | Decide @l > r@ statically when one side is an extreme of type @t@.
--
-- Returns @Just False@ when @l@ is the literal minimum of @t@ (nothing is
-- below it) or @r@ is the literal maximum of @t@ (nothing is above it).
-- Returns 'Nothing' when this check cannot decide.
gtCheck :: I.Type -> [CfVal] -> Maybe Bool
gtCheck t [l, r]
  | CfInteger lo <- l, toMinSize t == Just lo = Just False
  | CfInteger hi <- r, toMaxSize t == Just hi = Just False
  | otherwise                                 = Nothing
gtCheck _ _ = err "wrong number of args to gtCheck."

-- | Use a decided comparison result as a boolean literal, or fall back to the
-- given expression when the check was inconclusive.
fromOrdChecks :: I.Expr -> Maybe Bool -> I.Expr
fromOrdChecks expr decided = case decided of
  Nothing -> expr
  Just b  -> I.ExpLit (I.LitBool b)

-- | Apply a binary operation that requires an ord instance.
--
-- The comparison is evaluated (giving a 'CfBool') only when both operands
-- are literals of the same kind.  Otherwise it is rebuilt from the operands.
-- Only 'I.ExpEq', 'I.ExpNeq', 'I.ExpGt' and 'I.ExpLt' are supported; any
-- other operator is an internal error, but only once a literal comparison
-- actually needs @op'@.
cfOrd2 :: I.ExpOp
       -> [CfVal]
       -> CfVal
cfOrd2 op [l,r] = case (l,r) of
  (CfBool x,   CfBool y)    -> CfBool (op' x y)
  (CfInteger x,CfInteger y) -> CfBool (op' x y)
  (CfFloat x,  CfFloat y)   -> CfBool (op' x y)
  (CfDouble x, CfDouble y)  -> CfBool (op' x y)
  _                         -> CfExpr (I.ExpOp op [toExpr l, toExpr r])
  where
  -- Polymorphic so it can be used at Bool, Integer, Float and Double.
  op' :: Ord a => a -> a -> Bool
  op' = case op of
    I.ExpEq _     -> (==)
    I.ExpNeq _    -> (/=)
    I.ExpGt orEq _
      | orEq      -> (>=)
      | otherwise -> (>)
    I.ExpLt orEq _
      | orEq      -> (<=)
      | otherwise -> (<)
    _ -> err "bad op to cfOrd2"
cfOrd2 _ _ = err "wrong number of args to cfOrd2"

--------------------------------------------------------------------------------

-- | Fixed-width integral types that operators can be evaluated at.  The
-- result is widened back to an unbounded 'CfInteger', so it carries whatever
-- wrap-around the fixed-width computation produced.
class Integral a => IntegralOp a where
  -- | Apply a unary operator at type @a@ and wrap the result.
  appI1 :: (a -> a) -> a -> CfVal
  appI1 f = CfInteger . toInteger . f

  -- | Apply a binary operator at type @a@ and wrap the result.
  appI2 :: (a -> a -> a) -> a -> a -> CfVal
  appI2 f x = CfInteger . toInteger . f x

instance IntegralOp Int8
instance IntegralOp Int16
instance IntegralOp Int32
instance IntegralOp Int64
instance IntegralOp Word8
instance IntegralOp Word16
instance IntegralOp Word32
instance IntegralOp Word64

--------------------------------------------------------------------------------

-- | Fold the 'Num' operators: unary 'I.ExpNegate' / 'I.ExpAbs' /
-- 'I.ExpSignum' and binary 'I.ExpMul' / 'I.ExpAdd' / 'I.ExpSub'.
--
-- Integer literals are computed at the fixed-width Haskell type selected by
-- @ty@ (so results wrap at that width) and widened back via 'appI1' /
-- 'appI2'.  Float and double literals are computed directly.  Operands that
-- are not literals of a matching kind cause the operation to be rebuilt
-- unchanged.  The following are internal errors ('err'):
--
--   * an integer literal at a non-'I.TyInt'/'I.TyWord' type;
--   * a wrong argument count;
--   * an operator other than the six above, once a literal computation
--     needs it.
cfNum :: I.Type
      -> I.ExpOp
      -> [CfVal]
      -> CfVal
cfNum ty op args = case args of
  -- Unary operators.
  [x]   -> case x of
    CfInteger l -> case ty of
      I.TyInt isz -> case isz of
        I.Int8        -> appI1 op1 (fromInteger l :: Int8)
        I.Int16       -> appI1 op1 (fromInteger l :: Int16)
        I.Int32       -> appI1 op1 (fromInteger l :: Int32)
        I.Int64       -> appI1 op1 (fromInteger l :: Int64)
      I.TyWord isz -> case isz of
        I.Word8       -> appI1 op1 (fromInteger l :: Word8)
        I.Word16      -> appI1 op1 (fromInteger l :: Word16)
        I.Word32      -> appI1 op1 (fromInteger l :: Word32)
        I.Word64      -> appI1 op1 (fromInteger l :: Word64)
      _ -> err $ "bad type to cfNum loc 1 "
    CfFloat   l -> CfFloat   (op1 l)
    CfDouble  l -> CfDouble  (op1 l)
    _           -> CfExpr    (I.ExpOp op [toExpr x])

  -- Binary operators.
  [x,y] -> case (x,y) of
    (CfInteger l, CfInteger r) -> case ty of
      I.TyInt isz -> case isz of
        I.Int8        -> appI2 op2 (fromInteger l :: Int8)
                                   (fromInteger r :: Int8)
        I.Int16       -> appI2 op2 (fromInteger l :: Int16)
                                   (fromInteger r :: Int16)
        I.Int32       -> appI2 op2 (fromInteger l :: Int32)
                                   (fromInteger r :: Int32)
        I.Int64       -> appI2 op2 (fromInteger l :: Int64)
                                   (fromInteger r :: Int64)
      I.TyWord isz -> case isz of
        I.Word8       -> appI2 op2 (fromInteger l :: Word8)
                                   (fromInteger r :: Word8)
        I.Word16      -> appI2 op2 (fromInteger l :: Word16)
                                   (fromInteger r :: Word16)
        I.Word32      -> appI2 op2 (fromInteger l :: Word32)
                                   (fromInteger r :: Word32)
        I.Word64      -> appI2 op2 (fromInteger l :: Word64)
                                   (fromInteger r :: Word64)
      _ -> err "bad type to cfNum loc 2"
    (CfFloat   l, CfFloat r)   -> CfFloat   (op2 l r)
    (CfDouble l,  CfDouble r)  -> CfDouble  (op2 l r)
    _                          -> CfExpr    (I.ExpOp op [toExpr x, toExpr y])

  _ -> err "wrong num args to cfNum"
  where
  -- The binary operator, usable at any Num type.
  op2 :: Num a => a -> a -> a
  op2 = case op of
    I.ExpMul    -> (*)
    I.ExpAdd    -> (+)
    I.ExpSub    -> (-)
    _ -> err "bad op to cfNum loc 3"
  -- The unary operator, usable at any Num type.
  op1 :: Num a => a -> a
  op1 = case op of
    I.ExpNegate -> negate
    I.ExpAbs    -> abs
    I.ExpSignum -> signum
    _ -> err "bad op to cfNum loc 4"

-- | Fold an integral division ('I.ExpDiv', via 'quot') or remainder
-- ('I.ExpMod', via 'rem') of two integer literals.  The operation is
-- evaluated at the fixed-width type selected by @ty@, so the result wraps at
-- that width.
--
-- The operation is left unfolded when C gives it no defined value: a zero
-- divisor, or a signed @minBound@ divided by @-1@.  Evaluating those cases
-- would make Haskell's 'quot' throw (divide by zero / arithmetic overflow)
-- and crash the compiler on a program that is merely undefined at runtime.
-- Non-literal operands are rebuilt unchanged.
cfIntOp2 :: I.Type -> I.ExpOp -> [CfVal] -> CfVal
cfIntOp2 ty iOp [CfInteger l, CfInteger r] = case ty of
  I.TyInt isz -> case isz of
    I.Int8        -> fold2 (fromInteger l :: Int8)   (fromInteger r)
    I.Int16       -> fold2 (fromInteger l :: Int16)  (fromInteger r)
    I.Int32       -> fold2 (fromInteger l :: Int32)  (fromInteger r)
    I.Int64       -> fold2 (fromInteger l :: Int64)  (fromInteger r)
  I.TyWord isz -> case isz of
    I.Word8       -> fold2 (fromInteger l :: Word8)  (fromInteger r)
    I.Word16      -> fold2 (fromInteger l :: Word16) (fromInteger r)
    I.Word32      -> fold2 (fromInteger l :: Word32) (fromInteger r)
    I.Word64      -> fold2 (fromInteger l :: Word64) (fromInteger r)
  _ -> err "bad type to cfIntOp2 loc 1"

  where
  -- The operation rebuilt from the original literals.
  unfolded = CfExpr (I.ExpOp iOp [toExpr (CfInteger l), toExpr (CfInteger r)])

  -- Evaluate at the concrete type @a@ unless the C operation is undefined.
  -- The divisor is checked after truncation to @a@, since that is the value
  -- the operation would actually see.
  fold2 :: (Bounded a, IntegralOp a) => a -> a -> CfVal
  fold2 x y
    | y == 0                 = unfolded
    | x == minBound
    , y == (-1)
    , minBound < (0 `asTypeOf` x) = unfolded  -- signed overflow
    | otherwise              = appI2 op2 x y

  op2 :: Integral a => a -> a -> a
  op2 = case iOp of
    I.ExpDiv -> quot
    -- Haskell's `rem` matches C ISO 1999 semantics of the remainder having the
    -- same sign as the dividend.
    I.ExpMod -> rem
    _ -> err "bad op to cfIntOp2"

cfIntOp2 _ iOp [x, y] = CfExpr (I.ExpOp iOp [toExpr x, toExpr y])
cfIntOp2 _ _ _        = err "wrong number of args to cfIntOp2"

--------------------------------------------------------------------------------

-- | Constant folding for unary and binary operations that require a
-- 'Floating' instance.
--
-- Unary operators ('I.ExpRecip', exp, sqrt, log, and the trigonometric and
-- hyperbolic functions) take one argument.  'I.ExpFPow' and 'I.ExpFLogBase'
-- take two.  The operation is evaluated when its operands are all float
-- literals or all double literals; otherwise it is rebuilt unchanged.
cfFloating :: I.ExpOp
           -> [CfVal]
           -> CfVal
cfFloating op args = case args of
  [x]   -> case x of
             CfFloat f  -> CfFloat  (op1 f)
             CfDouble d -> CfDouble (op1 d)
             _          -> CfExpr   (I.ExpOp op [toExpr x])
  [x,y] -> case (x,y) of
             (CfFloat l,  CfFloat r)  -> CfFloat (op2 l r)
             (CfDouble l, CfDouble r) -> CfDouble (op2 l r)
             _                        -> CfExpr   (I.ExpOp op [toExpr x
                                                              , toExpr y])
  _     -> err "wrong number of args to cfFloating"
  where
  -- Unary operator, polymorphic so it serves both Float and Double.
  op1 :: Floating a => a -> a
  op1 = case op of
    I.ExpRecip   -> recip
    I.ExpFExp    -> exp
    I.ExpFSqrt   -> sqrt
    I.ExpFLog    -> log
    I.ExpFSin    -> sin
    I.ExpFCos    -> cos
    I.ExpFTan    -> tan
    I.ExpFAsin   -> asin
    I.ExpFAcos   -> acos
    I.ExpFAtan   -> atan
    I.ExpFSinh   -> sinh
    I.ExpFCosh   -> cosh
    I.ExpFTanh   -> tanh
    I.ExpFAsinh  -> asinh
    I.ExpFAcosh  -> acosh
    I.ExpFAtanh  -> atanh
    _            -> err "wrong op1 to cfFloating"

  -- Binary operator, polymorphic so it serves both Float and Double.
  op2 :: Floating a => a -> a -> a
  op2 = case op of
    I.ExpFPow     -> (**)
    I.ExpFLogBase -> logBase
    _            -> err "wrong op2 to cfFloating"

-- | Fold a floating-point predicate ('I.ExpIsNan' or 'I.ExpIsInf') applied to
-- a single float or double literal.  A non-literal operand is rebuilt
-- unchanged.
cfFloatingB :: I.ExpOp
            -> [CfVal]
            -> CfVal
cfFloatingB op args = case args of
  [CfFloat f]  -> CfBool (test f)
  [CfDouble d] -> CfBool (test d)
  [x]          -> CfExpr (I.ExpOp op [toExpr x])
  _            -> err "wrong number of args to cfFloatingB"
  where
  test :: RealFloat a => a -> Bool
  test = case op of
    I.ExpIsNan _ -> isNaN
    I.ExpIsInf _ -> isInfinite
    _            -> err "wrong op to cfFloatingB"

--------------------------------------------------------------------------------

-- | Abort with an internal-error message.  The folder uses this for states
-- it treats as impossible (wrong arity, unexpected operator or type).
err :: String -> a
err = error . ("Ivory-Opts internal error: " ++)