{-# LANGUAGE ViewPatterns #-}

--
-- Constant folding computations in Haskell.
--
-- Copyright (C) 2014, Galois, Inc.
-- All rights reserved.
--

module Ivory.Opts.ConstFoldComp
  ( CfVal(..)
  , err
  , toExpr
  , destLit
  , destBoolLit
  , destIntegerLit
  , isLitValue
  , mkCfArgs
  , cfNum
  , cfBitAnd
  , cfBitOr
  , cfFloating
  , cfFloatingB
  , cfIntOp2
  ) where

import qualified Ivory.Language.Syntax.AST as I
import qualified Ivory.Language.Syntax.Type as I
import Ivory.Language.Cast (toMaxSize, toMinSize)

import Control.Monad (mzero,msum)
import Data.Bits
import Data.Maybe

import Data.Word
import Data.Int

--------------------------------------------------------------------------------

-- | Constant-folded values.
--
-- A value is either a literal the folder can compute with, or an arbitrary
-- expression ('CfExpr') that is carried along as-is.
data CfVal
  -- A boolean literal.
  = CfBool Bool
  -- An integer literal.  The 'MaxMin' tag answers: is this a max or min value
  -- for the size type?
  | CfInteger MaxMin Integer
  -- A single-precision floating-point literal.
  | CfFloat Float
  -- A double-precision floating-point literal.
  | CfDouble Double
  -- Any expression that is not one of the literals above.
  | CfExpr I.Expr
    deriving (Show, Eq)

-- | Turn a constant-folded value back into an Ivory expression.
--
-- Integers tagged 'Min' or 'Max' are emitted as the symbolic
-- 'I.ExpMaxMin' expression rather than as a numeric literal; the stored
-- number is only used for untagged ('None') integers.
toExpr :: CfVal -> I.Expr
toExpr (CfBool b)         = I.ExpLit (I.LitBool b)
toExpr (CfInteger Min _)  = I.ExpMaxMin False
toExpr (CfInteger Max _)  = I.ExpMaxMin True
toExpr (CfInteger None i) = I.ExpLit (I.LitInteger i)
toExpr (CfFloat f)        = I.ExpLit (I.LitFloat f)
toExpr (CfDouble d)       = I.ExpLit (I.LitDouble d)
toExpr (CfExpr ex)        = ex

--------------------------------------------------------------------------------

-- | Whether the bounded integer represents a max or min value for its size.
-- 'None' tags an integer that is neither of the two bounds.
data MaxMin = Max | Min | None deriving (Show, Read, Eq)

-- | Classify an integer against the bounds of the given type: 'Max' if it
-- equals the type's maximum, 'Min' if it equals the minimum, and 'None'
-- otherwise (including when the type has no bound available).  The maximum
-- is tested first.
isMaxMin :: I.Type -> Integer -> MaxMin
isMaxMin ty i
  | toMaxSize ty == Just i = Max
  | toMinSize ty == Just i = Min
  | otherwise              = None

-- | Tag a value of a bounded Haskell type: 'Max' when it is 'maxBound',
-- 'Min' when it is 'minBound', 'None' otherwise.  'maxBound' is tested first.
toMaxMin :: (Eq a, Bounded a) => a -> MaxMin
toMaxMin r =
  if r == maxBound
    then Max
    else if r == minBound
      then Min
      else None

--------------------------------------------------------------------------------

-- | Convert operator arguments into constant-folding values.
--
-- Boolean, float, double and integer literals become the matching 'CfVal'
-- constructor; integer literals are tagged relative to the bounds of @ty@.
-- A symbolic max/min expression is resolved to the concrete bound of @ty@
-- when one is available.  Everything else is wrapped unchanged in 'CfExpr'.
mkCfArgs :: I.Type -> [I.Expr] -> [CfVal]
mkCfArgs ty exps = map classify exps
  where
  classify :: I.Expr -> CfVal
  classify ex = case ex of
    I.ExpLit (I.LitBool b)    -> CfBool b
    I.ExpLit (I.LitFloat f)   -> CfFloat f
    I.ExpLit (I.LitDouble d)  -> CfDouble d
    I.ExpLit (I.LitInteger i) -> CfInteger (isMaxMin ty i) i
    -- Symbolic bounds: fall back to the raw expression if the type has no
    -- known bound.
    I.ExpMaxMin True          -> maybe (CfExpr ex) (CfInteger Max) (toMaxSize ty)
    I.ExpMaxMin False         -> maybe (CfExpr ex) (CfInteger Min) (toMinSize ty)
    _                         -> CfExpr ex


-- | Constant folding for bitwise AND.
--
-- For word types, an operand tagged 'Min' is returned as the result and an
-- operand tagged 'Max' yields the other operand.  All remaining cases are
-- handed to 'abc' with 'combineBits'.
cfBitAnd :: I.Type -> [CfVal] -> CfVal
cfBitAnd ty [l, r]
  | wordTy, tagged Min l = l
  | wordTy, tagged Max l = r
  | wordTy, tagged Min r = r
  | wordTy, tagged Max r = l
  | otherwise            = abc (combineBits (.&.)) ty I.ExpBitAnd l r
  where
  wordTy = case ty of
    I.TyWord _ -> True
    _          -> False
  tagged t v = case v of
    CfInteger t' _ -> t' == t
    _              -> False
cfBitAnd _ _ = err "Wrong number of args to cfBitAnd in constant folder."

-- | Constant folding for bitwise OR.
--
-- For word types, an operand tagged 'Min' yields the other operand and an
-- operand tagged 'Max' is returned as the result.  All remaining cases are
-- handed to 'abc' with 'combineBits'.
cfBitOr :: I.Type -> [CfVal] -> CfVal
cfBitOr ty [l, r]
  | wordTy, tagged Min l = r
  | wordTy, tagged Max l = l
  | wordTy, tagged Min r = l
  | wordTy, tagged Max r = r
  | otherwise            = abc (combineBits (.|.)) ty I.ExpBitOr l r
  where
  wordTy = case ty of
    I.TyWord _ -> True
    _          -> False
  tagged t v = case v of
    CfInteger t' _ -> t' == t
    _              -> False
cfBitOr _ _ = err "Wrong number of args to cfBitOr in constant folder."


-- | Combine two values with a bitwise function.  Two integer constants are
-- folded into an untagged ('None') integer; any other pair is rebuilt as an
-- application of the given operator.
combineBits :: (Integer -> Integer -> Integer) -> I.ExpOp -> CfVal -> CfVal -> CfVal
combineBits f op x y = case (x, y) of
  (CfInteger _ a, CfInteger _ b) -> CfInteger None (f a b)
  _                              -> CfExpr (I.ExpOp op [toExpr x, toExpr y])

--------------------------------------------------------------------------------

----------------------------------------
-- Gather constants from an associative/commutative tree of operators

{-
Rules for normalizing constants in Associative Binary Commutative operators:

op [const, const]: evaluate the op. (establishes that each op has at least one non-const child)
op [const, var] -> cf op [var, const] (allowed by commutativity; establishes that consts are only right-children)
op [a, op [b, c]] -> cf op [cf op [a, b], c] (allowed by associativity; establishes that right child is not this op, so can't contain more constants)
op [op [var, const], const] -> op [var, cf op [const, const]] (allowed by associativity; establishes that if right-child is const, left-child does not contain any constants)
op [op [var1, const], var2] -> op [op [var1, var2], const] (allowed by associativity and commutativity; establishes that left-child does not contain constants; note that var1 and var2 can't contain constants by these rules)
anything else: unchanged

These rules assume that the operands have already had these rules
applied bottom-up, and avoid re-doing any work in subtrees that haven't
changed.
-}

-- | Gather constants in a tree of applications of one binary operator.
--
-- Arguments: a @combine@ function used when two operands can be folded
-- directly, the type used to interpret literals (passed to 'mkCfArgs'), the
-- operator being applied, and the left and right operands.  The result is
-- either a folded value produced by @combine@ or a rebuilt 'CfExpr'.
abc :: (I.ExpOp -> CfVal -> CfVal -> CfVal) -> I.Type -> I.ExpOp -> CfVal -> CfVal -> CfVal
-- Left operand is a non-constant expression.
abc combine ty op (CfExpr lhs) rhs = case (lhs, rhs) of
  -- The right operand is itself @op [b, c]@ with the same operator:
  -- re-associate by first folding @lhs@ with @b@, then that result with @c@.
  (_, CfExpr (I.ExpOp op' (mkCfArgs ty -> [b, c]))) | op == op' -> abc combine ty op (abc combine ty op (CfExpr lhs) b) c
  -- The left operand is a two-argument operator application (any operator)
  -- whose second argument is not recognised as a literal: nothing to gather.
  (I.ExpOp _ (_ : (mkCfArgs ty -> [CfExpr _])), _) -> noop
  -- Left is @op [a, b]@ and right is a non-constant @c@.  Since the previous
  -- alternative did not match, @b@ is a literal; rebuild as @op [op [a, c], b]@
  -- so the literal ends up as the outer right operand.
  (I.ExpOp op' [a, b], CfExpr c) | op == op' -> CfExpr (I.ExpOp op [I.ExpOp op [a, c], b])
  -- Left is @op [a, b]@ with literal @b@ and the right operand is a constant
  -- (a non-constant right operand was handled above): fold the two constants
  -- with @combine@ and keep @a@ on the left.
  (I.ExpOp op' (a : (mkCfArgs ty -> [b])), c) | op == op' -> CfExpr (I.ExpOp op [a, toExpr $ combine op b c])
  _ -> noop
  where
  -- Rebuild the application unchanged.
  noop = CfExpr (I.ExpOp op [lhs, toExpr rhs])
-- Left operand is a constant and the right one is not: swap the operands and
-- retry, so the non-constant operand is on the left.
abc combine ty op lhs rhs@(CfExpr _) = abc combine ty op rhs lhs
-- Neither operand is a 'CfExpr': hand both to @combine@.
abc combine _ op lhs rhs = combine op lhs rhs

--------------------------------------------------------------------------------

----------------------------------------
-- Constant folded Haskell literals

-- | Extract the literal from a literal expression; 'Nothing' for any other
-- expression.
destLit :: I.Expr -> Maybe I.Literal
destLit (I.ExpLit lit) = Just lit
destLit _              = Nothing

-- | Extract the value of a boolean literal expression, if it is one.
destBoolLit :: I.Expr -> Maybe Bool
destBoolLit ex = case destLit ex of
  Just (I.LitBool b) -> Just b
  _                  -> Nothing

-- | Extract the value of an integer literal expression, if it is one.
destIntegerLit :: I.Expr -> Maybe Integer
destIntegerLit ex = case destLit ex of
  Just (I.LitInteger i) -> Just i
  _                     -> Nothing

-- | Extract the value of a float literal expression, if it is one.
destFloatLit :: I.Expr -> Maybe Float
destFloatLit ex = case destLit ex of
  Just (I.LitFloat f) -> Just f
  _                   -> Nothing

-- | Extract the value of a double literal expression, if it is one.
destDoubleLit :: I.Expr -> Maybe Double
destDoubleLit ex = case destLit ex of
  Just (I.LitDouble d) -> Just d
  _                    -> Nothing

-- | Test whether a folded value is a numeric constant equal to the given
-- integer.  Float and double constants are compared after converting the
-- integer with 'fromInteger'; booleans and expressions never match.
isLitValue :: Integer -> CfVal -> Bool
isLitValue v val = case val of
  CfInteger _ i -> v == i
  CfFloat f     -> fromInteger v == f
  CfDouble d    -> fromInteger v == d
  _             -> False
----------------------------------------


-- | Bounded integral Haskell types used to evaluate folded integer
-- operations.  Both methods run the operation at type @a@ and wrap the
-- result as a 'CfInteger', tagged by 'toMaxMin'.
class (Bounded a, Integral a) => IntegralOp a where
  -- | Apply a unary operation and wrap the result.
  appI1 :: (a -> a) -> a -> CfVal
  appI1 f x = CfInteger (toMaxMin res) (toInteger res)
    where
    res = f x

  -- | Apply a binary operation and wrap the result.
  appI2 :: (a -> a -> a) -> a -> a -> CfVal
  appI2 f x y = CfInteger (toMaxMin res) (toInteger res)
    where
    res = f x y

-- Every instance relies on the default method definitions.
instance IntegralOp Int8
instance IntegralOp Int16
instance IntegralOp Int32
instance IntegralOp Int64
instance IntegralOp Word8
instance IntegralOp Word16
instance IntegralOp Word32
instance IntegralOp Word64

--------------------------------------------------------------------------------

-- | Constant folding for the 'Num' operations.
--
-- With one argument the operator must be negate, abs or signum; with two it
-- must be multiply, add or subtract.  Integer constants are evaluated at the
-- Haskell type matching @ty@ (index types use 'Int32'); float and double
-- constants are evaluated directly.  Any other argument combination is
-- rebuilt as an unfolded operator application.
cfNum :: I.Type
      -> I.ExpOp
      -> [CfVal]
      -> CfVal
cfNum ty op args = case args of
  [x]    -> unary x
  [x, y] -> binary x y
  _      -> err "wrong num args to cfNum"
  where
  -- One-argument folding.
  unary :: CfVal -> CfVal
  unary v = case v of
    CfInteger _ i -> case ty of
      I.TyInt I.Int8     -> int1 (fromInteger i :: Int8)
      I.TyInt I.Int16    -> int1 (fromInteger i :: Int16)
      I.TyInt I.Int32    -> int1 (fromInteger i :: Int32)
      I.TyInt I.Int64    -> int1 (fromInteger i :: Int64)
      I.TyWord I.Word8   -> int1 (fromInteger i :: Word8)
      I.TyWord I.Word16  -> int1 (fromInteger i :: Word16)
      I.TyWord I.Word32  -> int1 (fromInteger i :: Word32)
      I.TyWord I.Word64  -> int1 (fromInteger i :: Word64)
      I.TyIndex _        -> int1 (fromInteger i :: Int32)
      _                  -> err $ "bad type to cfNum loc 1 "
    CfFloat f     -> CfFloat (unaryFn f)
    CfDouble d    -> CfDouble (unaryFn d)
    _             -> CfExpr (I.ExpOp op [toExpr v])

  -- Two-argument folding.
  binary :: CfVal -> CfVal -> CfVal
  binary a b = case (a, b) of
    (CfInteger _ l, CfInteger _ r) -> case ty of
      I.TyInt I.Int8     -> int2 (fromInteger l :: Int8)   (fromInteger r)
      I.TyInt I.Int16    -> int2 (fromInteger l :: Int16)  (fromInteger r)
      I.TyInt I.Int32    -> int2 (fromInteger l :: Int32)  (fromInteger r)
      I.TyInt I.Int64    -> int2 (fromInteger l :: Int64)  (fromInteger r)
      I.TyWord I.Word8   -> int2 (fromInteger l :: Word8)  (fromInteger r)
      I.TyWord I.Word16  -> int2 (fromInteger l :: Word16) (fromInteger r)
      I.TyWord I.Word32  -> int2 (fromInteger l :: Word32) (fromInteger r)
      I.TyWord I.Word64  -> int2 (fromInteger l :: Word64) (fromInteger r)
      I.TyIndex _        -> int2 (fromInteger l :: Int32)  (fromInteger r)
      _                  -> err "bad type to cfNum loc 2"
    (CfFloat l, CfFloat r)   -> CfFloat (binaryFn l r)
    (CfDouble l, CfDouble r) -> CfDouble (binaryFn l r)
    _                        -> CfExpr (I.ExpOp op [toExpr a, toExpr b])

  -- Run the selected operation at a concrete integral type.
  int1 :: IntegralOp a => a -> CfVal
  int1 = appI1 unaryFn

  int2 :: IntegralOp a => a -> a -> CfVal
  int2 = appI2 binaryFn

  -- The Haskell function for a two-argument operator.
  binaryFn :: Num a => a -> a -> a
  binaryFn = case op of
    I.ExpMul -> (*)
    I.ExpAdd -> (+)
    I.ExpSub -> (-)
    _        -> err "bad op to cfNum loc 3"

  -- The Haskell function for a one-argument operator.
  unaryFn :: Num a => a -> a
  unaryFn = case op of
    I.ExpNegate -> negate
    I.ExpAbs    -> abs
    I.ExpSignum -> signum
    _           -> err "bad op to cfNum loc 4"

-- | Constant folding for the binary integral operators 'I.ExpDiv' (folded
-- with 'quot') and 'I.ExpMod' (folded with 'rem').
--
-- Two integer constants are evaluated at the Haskell type matching @ty@
-- (index types use 'Int32').  Operand pairs for which the Haskell operation
-- would throw an arithmetic exception are NOT folded: a zero divisor for
-- either operator, and @minBound `quot` (-1)@ for signed types.  Those
-- expressions are rebuilt unchanged so that the constant folder itself never
-- crashes on them.
--
-- Any other pair of arguments is rebuilt as an unfolded operator
-- application.  A wrong argument count, an unsupported type, or an operator
-- other than div/mod is reported through 'err'.
cfIntOp2 :: I.Type -> I.ExpOp -> [CfVal] -> CfVal
cfIntOp2 ty iOp [x@(CfInteger _ l), y@(CfInteger _ r)] = case ty of
  I.TyInt isz -> case isz of
    I.Int8        -> foldInt (fromInteger l :: Int8)
                             (fromInteger r :: Int8)
    I.Int16       -> foldInt (fromInteger l :: Int16)
                             (fromInteger r :: Int16)
    I.Int32       -> foldInt (fromInteger l :: Int32)
                             (fromInteger r :: Int32)
    I.Int64       -> foldInt (fromInteger l :: Int64)
                             (fromInteger r :: Int64)
  I.TyWord isz -> case isz of
    I.Word8       -> foldInt (fromInteger l :: Word8)
                             (fromInteger r :: Word8)
    I.Word16      -> foldInt (fromInteger l :: Word16)
                             (fromInteger r :: Word16)
    I.Word32      -> foldInt (fromInteger l :: Word32)
                             (fromInteger r :: Word32)
    I.Word64      -> foldInt (fromInteger l :: Word64)
                             (fromInteger r :: Word64)
  I.TyIndex _n    -> foldInt (fromInteger l :: Int32)
                             (fromInteger r :: Int32)
  _ -> err "bad type to cfIntOp2 loc 1"

  where
  -- The original operation, left for later stages to deal with.
  unfolded :: CfVal
  unfolded = CfExpr (I.ExpOp iOp [toExpr x, toExpr y])

  -- Fold at a concrete type.  The checks run on the converted operands, so a
  -- divisor that only becomes zero after 'fromInteger' is also caught.
  foldInt :: IntegralOp a => a -> a -> CfVal
  foldInt a b = case iOp of
    I.ExpDiv
      | b == 0                               -> unfolded
      -- Only signed types can have a divisor of -1; the quotient would not
      -- fit in the type and 'quot' throws an overflow exception.
      | a == minBound, toInteger b == (-1)   -> unfolded
      | otherwise                            -> appI2 quot a b
    -- Haskell's `rem` matches C ISO 1999 semantics of the remainder having the
    -- same sign as the dividend.
    I.ExpMod
      | b == 0                               -> unfolded
      | otherwise                            -> appI2 rem a b
    _ -> err "bad op to cfIntOp2"

cfIntOp2 _ iOp [x, y] = CfExpr (I.ExpOp iOp [toExpr x, toExpr y])
cfIntOp2 _ _ _        = err "wrong number of args to cfOp2"

--------------------------------------------------------------------------------

-- | Constant folding for operations that require a floating instance.
--
-- One float or double constant is folded with the unary function selected
-- by the operator; two constants of the same precision are folded with the
-- binary function.  Other argument combinations are rebuilt as an unfolded
-- operator application.
cfFloating :: I.ExpOp
           -> [CfVal]
           -> CfVal
cfFloating op args = case args of
  [x]    -> unary x
  [x, y] -> binary x y
  _      -> err "wrong number of args to cfFloating"
  where
  unary :: CfVal -> CfVal
  unary v = case v of
    CfFloat f  -> CfFloat (fn1 f)
    CfDouble d -> CfDouble (fn1 d)
    _          -> CfExpr (I.ExpOp op [toExpr v])

  binary :: CfVal -> CfVal -> CfVal
  binary l r = case (l, r) of
    (CfFloat a, CfFloat b)   -> CfFloat (fn2 a b)
    (CfDouble a, CfDouble b) -> CfDouble (fn2 a b)
    _                        -> CfExpr (I.ExpOp op [toExpr l, toExpr r])

  -- The Haskell function for a one-argument operator.
  fn1 :: Floating a => a -> a
  fn1 = case op of
    I.ExpRecip  -> recip
    I.ExpFExp   -> exp
    I.ExpFSqrt  -> sqrt
    I.ExpFLog   -> log
    I.ExpFSin   -> sin
    I.ExpFCos   -> cos
    I.ExpFTan   -> tan
    I.ExpFAsin  -> asin
    I.ExpFAcos  -> acos
    I.ExpFAtan  -> atan
    I.ExpFSinh  -> sinh
    I.ExpFCosh  -> cosh
    I.ExpFTanh  -> tanh
    I.ExpFAsinh -> asinh
    I.ExpFAcosh -> acosh
    I.ExpFAtanh -> atanh
    _           -> err "wrong op1 to cfFloating"

  -- The Haskell function for a two-argument operator.
  fn2 :: RealFloat a => a -> a -> a
  fn2 = case op of
    I.ExpFPow     -> (**)
    I.ExpFLogBase -> logBase
    I.ExpFAtan2   -> atan2
    _             -> err "wrong op2 to cfFloating"

-- | Constant folding for the floating-point predicates (is-NaN and
-- is-infinite), which take one argument and produce a boolean.  A
-- non-constant argument is rebuilt as an unfolded operator application.
cfFloatingB :: I.ExpOp
            -> [CfVal]
            -> CfVal
cfFloatingB op args = case args of
  [CfFloat f]  -> CfBool (predicate f)
  [CfDouble d] -> CfBool (predicate d)
  [x]          -> CfExpr (I.ExpOp op [toExpr x])
  _            -> err "wrong number of args to cfFloatingB"
  where
  predicate :: RealFloat a => a -> Bool
  predicate = case op of
    I.ExpIsNan _ -> isNaN
    I.ExpIsInf _ -> isInfinite
    _            -> err "wrong op to cfFloatingB"

--------------------------------------------------------------------------------

-- | Abort with an internal-error message carrying the module's standard
-- prefix.
err :: String -> a
err = error . ("Ivory-Opts internal error: " ++)