{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE DeriveFoldable       #-}
{-# LANGUAGE DeriveFunctor        #-}
{-# LANGUAGE DeriveTraversable    #-}
{-# LANGUAGE FlexibleContexts     #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE UndecidableInstances #-}

module Ivory.Opts.CSE (cseFold) where

import           Prelude               ()
import           Prelude.Compat

import           Control.Applicative   (liftA2)
import qualified Data.DList            as D
import           Data.IntMap.Strict    (IntMap)
import qualified Data.IntMap.Strict    as IntMap
import           Data.IntSet           (IntSet)
import qualified Data.IntSet           as IntSet
import           Data.List             (sort)
import           Data.Map.Strict       (Map)
import qualified Data.Map.Strict       as Map
import           Data.Reify
import           Ivory.Language.Array  (ixRep)
import qualified Ivory.Language.Syntax as AST
import           MonadLib              (Id, StateT, WriterT, collect, get, lift,
                                        put, runM, set, sets, sets_)
import           System.IO.Unsafe      (unsafePerformIO)

-- | Find each common sub-expression and extract it to a new variable,
-- making any sharing explicit. However, this function should never move
-- evaluation of an expression earlier than it would have occurred in
-- the source program, which means that sometimes an expression must be
-- re-computed on each of several execution paths.
cseFold :: AST.Proc -> AST.Proc
cseFold def = def
  { AST.procBody = reconstruct $ unsafePerformIO $ reifyGraph $ AST.procBody def }

-- | Variable assignments emitted so far.
data Bindings = Bindings
  { availableBindings :: (Map (Unique, AST.Type) Int)
  , unusedBindings    :: IntSet
  , totalBindings     :: Int
  }

-- | A monad for emitting both source-level statements as well as
-- assignments that capture common subexpressions.
--
-- Note that the StateT is outside the WriterT so that we can first run
-- the StateT, getting a set of expressions which shouldn't be assigned
-- to fresh names, and only then decide whether to write out Assign
-- statements. See the comment in `updateFacts`.
type BlockM a = StateT Bindings (WriterT (D.DList AST.Stmt) Id) a

-- | We perform CSE on expressions but also across all the blocks in a
-- procedure.
data CSE t
  = CSEExpr (ExprF t)
  | CSEBlock (BlockF t)
  deriving (Show, Eq, Ord, Functor)

-- | During CSE, we replace recursive references to an expression with a
-- unique ID for that expression.
data ExprF t
  = ExpSimpleF AST.Expr
    -- ^ For expressions that cannot contain any expressions recursively.
  | ExpLabelF AST.Type t String
  | ExpIndexF AST.Type t AST.Type t
  | ExpToIxF t Integer
  | ExpSafeCastF AST.Type t
  | ExpOpF AST.ExpOp [t]
  deriving (Show, Eq, Ord, Foldable, Functor, Traversable)

instance MuRef AST.Expr where
  type DeRef AST.Expr = CSE
  mapDeRef child e = CSEExpr <$> case e of
    AST.ExpSym{}    -> pure $ ExpSimpleF e
    AST.ExpExtern{} -> pure $ ExpSimpleF e
    AST.ExpVar{}    -> pure $ ExpSimpleF e
    AST.ExpLit{}    -> pure $ ExpSimpleF e
    AST.ExpLabel ty ex nm -> ExpLabelF <$> pure ty <*> child ex <*> pure nm
    AST.ExpIndex ty1 ex1 ty2 ex2 -> ExpIndexF <$> pure ty1 <*> child ex1 <*> pure ty2 <*> child ex2
    AST.ExpToIx ex bound -> ExpToIxF <$> child ex <*> pure bound
    AST.ExpSafeCast ty ex -> ExpSafeCastF ty <$> child ex
    AST.ExpOp op args -> ExpOpF op <$> traverse child args
    AST.ExpAddrOfGlobal{} -> pure $ ExpSimpleF e
    AST.ExpMaxMin{} -> pure $ ExpSimpleF e
    AST.ExpSizeOf{} -> pure $ ExpSimpleF e

-- | Convert a flattened expression back to a real expression.
toExpr :: ExprF AST.Expr -> AST.Expr
toExpr (ExpSimpleF ex) = ex
toExpr (ExpLabelF ty ex nm) = AST.ExpLabel ty ex nm
toExpr (ExpIndexF ty1 ex1 ty2 ex2) = AST.ExpIndex ty1 ex1 ty2 ex2
toExpr (ExpToIxF ex bound) = AST.ExpToIx ex bound
toExpr (ExpSafeCastF ty ex) = AST.ExpSafeCast ty ex
toExpr (ExpOpF op args) = AST.ExpOp op args

-- | Wrap the second type in either TyRef or TyConstRef, according to
-- whether the first argument was a constant ref.
copyConst :: AST.Type -> AST.Type -> AST.Type
copyConst (AST.TyRef _) ty = AST.TyRef ty
copyConst (AST.TyConstRef _) ty = AST.TyConstRef ty
copyConst ty _ = error $ "Ivory.Opts.CSE.copyConst: expected a Ref type but got " ++ show ty

-- | Label all sub-expressions with the type at which they're used,
-- assuming that this expression is used at the given type.
labelTypes :: AST.Type -> ExprF k -> ExprF (k, AST.Type)
labelTypes _ (ExpSimpleF e) = ExpSimpleF e
labelTypes resty (ExpLabelF ty ex nm) = ExpLabelF ty (ex, copyConst resty ty) nm
labelTypes resty (ExpIndexF ty1 ex1 ty2 ex2) = ExpIndexF ty1 (ex1, copyConst resty ty1) ty2 (ex2, ty2)
labelTypes _ (ExpToIxF ex bd) = ExpToIxF (ex, ixRep) bd
labelTypes _ (ExpSafeCastF ty ex) = ExpSafeCastF ty (ex, ty)
labelTypes ty (ExpOpF op args) = ExpOpF op $ case op of
  AST.ExpEq t -> map (`atType` t) args
  AST.ExpNeq t -> map (`atType` t) args
  AST.ExpCond -> let (cond, rest) = splitAt 1 args in map (`atType` AST.TyBool) cond ++ map (`atType` ty) rest
  AST.ExpGt _ t -> map (`atType` t) args
  AST.ExpLt _ t -> map (`atType` t) args
  AST.ExpIsNan t  -> map (`atType` t) args
  AST.ExpIsInf t  -> map (`atType` t) args
  _ -> map (`atType` ty) args
  where
  atType = (,)

-- | Like ExprF, we replace recursive references to
-- blocks/statements/expressions with unique IDs.
--
-- Note that we treat statements as a kind of block, because extracting
-- assignments for the common subexpressions in a statement can result
-- in multiple statements, which looks much like a block.
data BlockF t
  = StmtSimple AST.Stmt
    -- ^ For statements that cannot contain any other statements or expressions.
  | StmtIfTE t t t
  | StmtAssert t
  | StmtCompilerAssert t
  | StmtAssume t
  | StmtReturn (AST.Typed t)
  | StmtDeref AST.Type AST.Var t
  | StmtStore AST.Type t t
  | StmtAssign AST.Type AST.Var t
  | StmtCall AST.Type (Maybe AST.Var) AST.Name [AST.Typed t]
  | StmtLocal AST.Type AST.Var (InitF t)
  | StmtRefCopy AST.Type t t
  | StmtRefZero AST.Type t
  | StmtLoop Integer AST.Var t (LoopIncrF t) t
  | StmtForever t
  | Block [t]
  deriving (Show, Eq, Ord, Functor)

data LoopIncrF t
  = IncrTo t
  | DecrTo t
  deriving (Show, Eq, Ord, Functor)

data InitF t
  = InitZero
  | InitExpr AST.Type t
  | InitStruct [(String, InitF t)]
  | InitArray [InitF t] Bool
  deriving (Show, Eq, Ord, Functor)

instance MuRef AST.Stmt where
  type DeRef AST.Stmt = CSE
  mapDeRef child stmt = CSEBlock <$> case stmt of
    AST.IfTE cond tb fb -> StmtIfTE <$> child cond <*> child tb <*> child fb
    AST.Assert cond -> StmtAssert <$> child cond
    AST.CompilerAssert cond -> StmtCompilerAssert <$> child cond
    AST.Assume cond -> StmtAssume <$> child cond
    AST.Return (AST.Typed ty ex) -> StmtReturn <$> (AST.Typed ty <$> child ex)
    AST.Deref ty var ex -> StmtDeref ty var <$> child ex
    AST.Store ty lhs rhs -> StmtStore ty <$> child lhs <*> child rhs
    AST.Assign ty var ex -> StmtAssign ty var <$> child ex
    AST.Call ty mv nm args -> StmtCall ty mv nm <$> traverse (\ (AST.Typed argTy argEx) -> AST.Typed argTy <$> child argEx) args
    AST.Local ty var initex -> StmtLocal ty var <$> mapInit initex
    AST.RefCopy ty dst src -> StmtRefCopy ty <$> child dst <*> child src
    AST.RefZero ty dst -> StmtRefZero ty <$> child dst
    AST.Loop m var ex incr lb -> StmtLoop m var <$> child ex <*> mapIncr incr <*> child lb
    AST.Forever lb -> StmtForever <$> child lb
    -- These kinds of statements can't contain other statements or expressions.
    AST.ReturnVoid -> pure $ StmtSimple stmt
    AST.AllocRef{} -> pure $ StmtSimple stmt
    AST.Break -> pure $ StmtSimple stmt
    AST.Comment{} -> pure $ StmtSimple stmt
    where
    mapInit AST.InitZero = pure InitZero
    mapInit (AST.InitExpr ty ex) = InitExpr ty <$> child ex
    mapInit (AST.InitStruct fields) = InitStruct <$> traverse (\ (nm, i) -> (,) nm <$> mapInit i) fields
    mapInit (AST.InitArray elements b) = liftA2 InitArray (traverse mapInit elements) (pure b)
    mapIncr (AST.IncrTo ex) = IncrTo <$> child ex
    mapIncr (AST.DecrTo ex) = DecrTo <$> child ex

instance (MuRef a, DeRef [a] ~ DeRef a) => MuRef [a] where
  type DeRef [a] = CSE
  mapDeRef child xs = CSEBlock <$> Block <$> traverse child xs

-- | Convert a flattened statement or block back to a real block.
toBlock :: (k -> AST.Type -> BlockM AST.Expr) -> (k -> BlockM ()) -> BlockF k -> BlockM ()
toBlock expr block b = case b of
  StmtSimple s -> stmt $ return s
  StmtIfTE ex tb fb -> stmt $ AST.IfTE <$> expr ex AST.TyBool <*> genBlock (block tb) <*> genBlock (block fb)
  StmtAssert cond -> stmt $ AST.Assert <$> expr cond AST.TyBool
  StmtCompilerAssert cond -> stmt $ AST.CompilerAssert <$> expr cond AST.TyBool
  StmtAssume cond -> stmt $ AST.Assume <$> expr cond AST.TyBool
  StmtReturn (AST.Typed ty ex) -> stmt $ AST.Return <$> (AST.Typed ty <$> expr ex ty)
  -- XXX: The AST does not preserve whether the RHS of a deref was for a
  -- const ref, but it's safe to assume it's const.
  StmtDeref ty var ex -> stmt $ AST.Deref ty var <$> expr ex (AST.TyConstRef ty)
  -- XXX: The LHS of a store must not have been const.
  StmtStore ty lhs rhs -> stmt $ AST.Store ty <$> expr lhs (AST.TyRef ty) <*> expr rhs ty
  StmtAssign ty var ex -> stmt $ AST.Assign ty var <$> expr ex ty
  StmtCall ty mv nm args -> stmt $ AST.Call ty mv nm <$> mapM (\ (AST.Typed argTy argEx) -> AST.Typed argTy <$> expr argEx argTy) args
  StmtLocal ty var initex -> stmt $ AST.Local ty var <$> toInit initex
  -- XXX: See deref and store comments above.
  StmtRefCopy ty dst src -> stmt $ AST.RefCopy ty <$> expr dst (AST.TyRef ty) <*> expr src (AST.TyConstRef ty)
  StmtRefZero ty dst -> stmt $ AST.RefZero ty <$> expr dst (AST.TyRef ty)
  StmtLoop m var ex incr lb -> stmt $ AST.Loop m var <$> expr ex ixRep <*> toIncr incr <*> genBlock (block lb)
  StmtForever lb -> stmt $ AST.Forever <$> genBlock (block lb)
  Block stmts -> mapM_ block stmts
  where
  stmt stmtM = fmap D.singleton stmtM >>= put
  toInit InitZero = pure AST.InitZero
  toInit (InitExpr ty ex) = AST.InitExpr ty <$> expr ex ty
  toInit (InitStruct fields) = AST.InitStruct <$> traverse (\ (nm, i) -> (,) nm <$> toInit i) fields
  toInit (InitArray elements b') = liftA2 AST.InitArray (traverse toInit elements) (pure b')
  toIncr (IncrTo ex) = AST.IncrTo <$> expr ex ixRep
  toIncr (DecrTo ex) = AST.DecrTo <$> expr ex ixRep

-- | When a statement contains a block, we need to propagate the
-- available expressions into that block. However, on exit from that
-- block, the expressions it made newly-available go out of scope, so we
-- remove them from the available set for subsequent statements.
genBlock :: BlockM () -> BlockM AST.Block
genBlock gen = do
  oldBindings <- get
  ((), stmts) <- collect gen
  sets_ $ \ newBindings -> newBindings { availableBindings = availableBindings oldBindings }
  return $ D.toList stmts

-- | Data to accumulate as we analyze each expression and each
-- block/statement.
type Facts = (IntMap (AST.Type -> BlockM AST.Expr), IntMap (BlockM ()))

-- | We can only generate code from a DAG, so this function calls
-- `error` if the reified graph has cycles. Because we walk the AST in
-- topo-sorted order, if we haven't already computed the desired fact,
-- then we're trying to follow a back-edge in the graph, and that means
-- the graph has cycles.
getFact :: IntMap v -> Unique -> v
getFact m k = case IntMap.lookup k m of
  Nothing -> error "IvoryCSE: cycle detected in expression graph"
  Just v -> v

-- | Walk a reified AST in topo-sorted order, accumulating analysis
-- results.
--
-- `usedOnce` must be the final value of `unusedBindings` after analysis
-- is complete.
updateFacts :: IntSet -> (Unique, CSE Unique) -> Facts -> Facts
updateFacts _ (ident, CSEBlock block) (exprFacts, blockFacts) = (exprFacts, IntMap.insert ident (toBlock (getFact exprFacts) (getFact blockFacts) block) blockFacts)
updateFacts usedOnce (ident, CSEExpr expr) (exprFacts, blockFacts) = (IntMap.insert ident fact exprFacts, blockFacts)
  where
  nameOf var = AST.VarName $ "cse" ++ show var
  fact = case expr of
    ExpSimpleF e -> const $ return e
    ex -> \ ty -> do
      bindings <- get
      case Map.lookup (ident, ty) $ availableBindings bindings of
        Just var -> do
          set $ bindings { unusedBindings = IntSet.delete var $ unusedBindings bindings }
          return $ AST.ExpVar $ nameOf var
        Nothing -> do
          ex' <- fmap toExpr $ mapM (uncurry $ getFact exprFacts) $ labelTypes ty ex
          var <- sets $ \ (Bindings { availableBindings = avail, unusedBindings = unused, totalBindings = maxId}) ->
            (maxId, Bindings
              { availableBindings = Map.insert (ident, ty) maxId avail
              , unusedBindings = IntSet.insert maxId unused
              , totalBindings = maxId + 1
              })
          -- Defer a final decision on whether to inline this expression
          -- or allocate a variable for it until we've finished running
          -- the State monad and can extract the unusedBindings set from
          -- there. After that the Writer monad can make decisions based
          -- on usedOnce without throwing a <<loop>> exception.
          lift $ if var `IntSet.member` usedOnce
            then return ex'
            else do
              put $ D.singleton $ AST.Assign ty (nameOf var) ex'
              return $ AST.ExpVar $ nameOf var

-- | Values that we may generate by simplification rules on the reified
-- representation of the graph.
data Constant
  = ConstFalse
  | ConstTrue
  | ConstZero
  | ConstTwo
  deriving (Bounded, Enum)

-- | AST implementation for each constant value.
constExpr :: Constant -> CSE Unique
constExpr ConstFalse = CSEExpr $ ExpSimpleF $ AST.ExpLit $ AST.LitBool False
constExpr ConstTrue = CSEExpr $ ExpSimpleF $ AST.ExpLit $ AST.LitBool True
constExpr ConstZero = CSEExpr $ ExpSimpleF $ AST.ExpLit $ AST.LitInteger 0
constExpr ConstTwo = CSEExpr $ ExpSimpleF $ AST.ExpLit $ AST.LitInteger 2

-- | Generate a unique integer for each constant which doesn't collide
-- with any IDs that reifyGraph may generate.
constUnique :: Constant -> Unique
constUnique c = negate $ 1 + fromEnum c

-- | Wrapper around Facts to track unshared duplicates.
type Dupes = (Map (CSE Unique) Unique, IntMap Unique, Facts)

-- | Wrapper around updateFacts to remove unshared duplicates. Also,
-- checking for equality of statements or expressions is constant-time
-- in this representation, so apply any simplifications that rely on
-- equality of subtrees here.
dedup :: IntSet -> (Unique, CSE Unique) -> Dupes -> Dupes
dedup usedOnce (ident, expr) (seen, remap, facts) = case expr' of
  -- If this operator yields a constant on equal operands, we can
  -- rewrite it to that constant.
  CSEExpr (ExpOpF (AST.ExpEq ty) [a, b]) | not (isFloat ty) && a == b -> remapTo $ constUnique ConstTrue
  CSEExpr (ExpOpF (AST.ExpNeq ty) [a, b]) | not (isFloat ty) && a == b -> remapTo $ constUnique ConstFalse
  CSEExpr (ExpOpF (AST.ExpGt isEq ty) [a, b]) | not (isFloat ty) && a == b -> remapTo $ if isEq then constUnique ConstTrue else constUnique ConstFalse
  CSEExpr (ExpOpF (AST.ExpLt isEq ty) [a, b]) | not (isFloat ty) && a == b -> remapTo $ if isEq then constUnique ConstTrue else constUnique ConstFalse
  CSEExpr (ExpOpF AST.ExpBitXor [a, b]) | a == b -> remapTo $ constUnique ConstZero
  -- NOTE: This transformation is not safe for ExpSub on floating-point
  -- values, which could be NaN.

  -- If this operator is idempotent and its operands are equal, we can
  -- replace it with either operand without changing its meaning.
  CSEExpr (ExpOpF AST.ExpAnd [a, b]) | a == b -> remapTo a
  CSEExpr (ExpOpF AST.ExpOr [a, b]) | a == b -> remapTo a
  CSEExpr (ExpOpF AST.ExpBitAnd [a, b]) | a == b -> remapTo a
  CSEExpr (ExpOpF AST.ExpBitOr [a, b]) | a == b -> remapTo a

  -- If both branches of a conditional expression or statement have the
  -- same effect, then we don't need to evaluate the condition; we can
  -- just replace it with either branch. This is not safe in C because
  -- the condition might have side effects, but Ivory expressions never
  -- have side effects.
  CSEExpr (ExpOpF AST.ExpCond [_, t, f]) | t == f -> remapTo t
  -- NOTE: This results in inserting a Block directly into another
  -- Block, which can't happen any other way.
  CSEBlock (StmtIfTE _ t f) | t == f -> remapTo t

  -- Single-statement blocks generate the same code as the statement.
  CSEBlock (Block [s]) -> remapTo s

  -- No equal subtrees, so run with it.
  _ -> case Map.lookup expr' seen of
    Just ident' -> remapTo ident'
    Nothing -> (Map.insert expr' ident seen, remap, updateFacts usedOnce (ident, expr') facts)
  where
  remapTo ident' = (seen, IntMap.insert ident ident' remap, facts)
  expr' = case fmap (\ k -> IntMap.findWithDefault k k remap) expr of
    -- Perhaps this operator can be replaced by a simpler one when its
    -- operands are equal.
    CSEExpr (ExpOpF AST.ExpAdd [a, b]) | a == b -> CSEExpr $ ExpOpF AST.ExpMul $ sort [constUnique ConstTwo, a]

    -- If this operator is commutative, we can put its arguments in any
    -- order we want. If we choose the same order every time, more
    -- semantically equivalent subexpressions will be factored out.
    CSEExpr (ExpOpF op args) | isCommutative op -> CSEExpr $ ExpOpF op $ sort args
    asis -> asis

  isFloat AST.TyFloat = True
  isFloat AST.TyDouble = True
  isFloat _ = False

  isCommutative (AST.ExpEq _) = True
  isCommutative (AST.ExpNeq _) = True
  isCommutative AST.ExpMul = True
  isCommutative AST.ExpAdd = True
  isCommutative AST.ExpBitAnd = True
  isCommutative AST.ExpBitOr = True
  isCommutative AST.ExpBitXor = True
  isCommutative _ = False

-- | Given a reified AST, reconstruct an Ivory AST with all sharing made
-- explicit.
reconstruct :: Graph CSE -> AST.Block
reconstruct (Graph subexprs root) = D.toList rootBlock
  where
  -- NOTE: `dedup` needs to merge the constants in first, which means
  -- that as long as this is a `foldr`, they need to be appended after
  -- `subexprs`. Don't try to optimize this by re-ordering the list.
  (_, remap, (_, blockFacts)) =
    foldr (dedup usedOnce) mempty
      $ subexprs ++ [ (constUnique c, constExpr c) | c <- [minBound..maxBound] ]
  Just rootGen = IntMap.lookup (IntMap.findWithDefault root root remap) blockFacts
  (((), Bindings { unusedBindings = usedOnce }), rootBlock) =
    runM rootGen $ Bindings Map.empty IntSet.empty 0