-- 
-- This module provides functions for program transformations.
--
module CLasH.Normalize.NormalizeTools where

-- Standard modules
import qualified Data.Monoid as Monoid
import qualified Data.Either as Either
import qualified Control.Monad as Monad
import qualified Control.Monad.Trans.Writer as Writer
import qualified Control.Monad.Trans.Class as Trans
import qualified Data.Accessor.Monad.Trans.State as MonadState

-- GHC API
import CoreSyn
import qualified Name
import qualified Id
import qualified CoreSubst
import qualified Type
import qualified CoreUtils
import Outputable ( showSDoc, ppr, nest )

-- Local imports
import CLasH.Normalize.NormalizeTypes
import CLasH.Translator.TranslatorTypes
import CLasH.VHDL.Constants (builtinIds)
import CLasH.Utils
import qualified CLasH.Utils.Core.CoreTools as CoreTools
import qualified CLasH.VHDL.VHDLTools as VHDLTools

-- | Lift a transformation so it is applied to the expression and, recursively,
-- to every subexpression of it. Each node's children are processed first
-- (via 'subeverywhere'), then the node itself; 'applyboth' repeats this for a
-- node as long as the transformation keeps reporting a change there.
everywhere :: Transform -> Transform
everywhere trans = go
  where
    go = applyboth (subeverywhere go) trans

-- | Verbosity levels for the normalizer's debug output. The constructors are
-- listed in order of increasing output, and the derived 'Ord' instance follows
-- that order, so a check such as @level >= NormDbgApplied@ means "at least
-- NormDbgApplied". Do not reorder the constructors.
data NormDbgLevel
  = NormDbgNone     -- ^ No debugging
  | NormDbgFinal    -- ^ Print functions before / after normalization
  | NormDbgApplied  -- ^ Print expressions before / after applying transformations
  | NormDbgAll      -- ^ Print expressions when a transformation does not apply
  deriving (Eq, Ord)
-- | Current debug verbosity of the normalizer. 'apply' traces expressions a
-- transformation changed only when this is 'NormDbgApplied' or higher. It also
-- traces expressions a transformation left unchanged when this is 'NormDbgAll'.
normalize_debug :: NormDbgLevel
normalize_debug = NormDbgFinal

-- | Wrap a named transformation. Each time the wrapped transformation reports
-- a change, the session's transformation counter is incremented. Depending on
-- 'normalize_debug', the expression (and its type) before and after the
-- transformation is traced together with the context.
apply :: (String, Transform) -> Transform
apply (name, trans) ctx expr = do
    -- Run the transformation, observing whether it flagged a change
    (expr', anyChange) <- Writer.listen (trans ctx expr)
    let changed = Monoid.getAny anyChange
    Monad.when changed $
      Trans.lift (MonadState.modify tsTransformCounter (+1))
    -- Debug rendering of an expression followed by its type
    let showExpr e = showSDoc (nest 4 $ ppr e) ++ "\nType: \n" ++ showSDoc (nest 4 $ ppr $ CoreUtils.exprType e) ++ "\n"
        before = showExpr expr
        after = showExpr expr'
        context = "Context: " ++ show ctx ++ "\n"
        appliedMsg = "Changes when applying transform " ++ name ++ " to:\n" ++ before ++ context ++ "Result:\n" ++ after
        skippedMsg = "No changes when applying transform " ++ name ++ " to:\n" ++ before ++ context
    traceIf (normalize_debug >= NormDbgApplied && changed) appliedMsg $
      traceIf (normalize_debug >= NormDbgAll && not changed) skippedMsg $
        return expr'

-- | Apply the first transformation and then the second one. Whenever the
-- second transformation reports a change, start over with its result, so
-- both are reapplied until the second one leaves the expression unchanged.
applyboth :: Transform -> Transform -> Transform
applyboth first second context = loop
  where
    loop e = do
      e1 <- first context e
      (e2, flag) <- Writer.listen $ second context e1
      if Monoid.getAny flag then loop e2 else return e2

-- | Apply the given transformation to each direct subexpression only (not to
-- the expression itself). The context passed down is extended with an entry
-- describing where the subexpression sits.
--
-- Subexpressions are visited in this order:
--
-- * application: function, then argument
-- * non-recursive let: binding, then body
-- * recursive let: body, then each binding in order
-- * case: scrutinee, then each alternative
--
-- Variables, literals and types are rebuilt as-is. Any other constructor is
-- rejected with an error.
subeverywhere :: Transform -> Transform
subeverywhere trans c expression =
    case expression of
      App a b ->
        Monad.liftM2 App (trans (AppFirst:c) a) (trans (AppSecond:c) b)
      Let (NonRec b bexpr) body -> do
        bexpr' <- trans (LetBinding:c) bexpr
        body' <- trans (LetBody:c) body
        return $ Let (NonRec b bexpr') body'
      Let (Rec binds) body -> do
        -- The body is deliberately handled before the bindings
        body' <- trans (LetBody:c) body
        binds' <- mapM transBind binds
        return $ Let (Rec binds') body'
      Lam x body ->
        Monad.liftM (Lam x) (trans (LambdaBody:c) body)
      Case scrut b t alts -> do
        scrut' <- trans (Other:c) scrut
        alts' <- mapM transAlt alts
        return $ Case scrut' b t alts'
      Var x -> return $ Var x
      Lit x -> return $ Lit x
      Type x -> return $ Type x
      Cast inner ty ->
        Monad.liftM (\inner' -> Cast inner' ty) (trans (Other:c) inner)
      _ -> error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expression
  where
    -- Transform the right-hand side of one recursive binding
    transBind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
    transBind (b, e) = Monad.liftM ((,) b) (trans (LetBinding:c) e)

    -- Transform the right-hand side of one case alternative
    transAlt :: CoreAlt -> TransformMonad CoreAlt
    transAlt (con, binders, e) =
      Monad.liftM (\e' -> (con, binders, e')) (trans (Other:c) e)

-- | Run a normalization pass repeatedly until a pass changes nothing. A pass
-- applies each named transformation, in list order, everywhere in the
-- expression (see 'everywhere' and 'apply').
dotransforms :: [(String, Transform)] -> CoreExpr -> TranslatorSession CoreExpr
dotransforms transs = go
  where
    go e = do
      (e', flag) <- Writer.runWriterT (Monad.foldM runOne e transs)
      if Monoid.getAny flag then go e' else return e'
    -- One named transformation, applied to all subexpressions, empty context
    runOne e t = everywhere (apply t) [] e

-- | Inline the bindings of a recursive let that satisfy the given condition.
--
-- The condition is evaluated, in order, for every binding of the let. The
-- bindings it accepts are removed from the let. Each removed binder is then
-- substituted, with fresh uniques via 'substitute_clone', into:
--
-- * the reduced let (its remaining bindings and body), and
-- * the values of the removed bindings that have not been substituted yet.
--
-- When at least one binding is inlined, the changed flag is set. If no
-- binding satisfies the condition, the original expression is returned.
-- Expressions other than a recursive let are returned unchanged.
inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
inlinebind condition context expr@(Let (Rec binds) res) = do
    -- Tag each binding: Left = inline it, Right = keep it
    res_eithers <- mapM docond binds
    case Either.partitionEithers res_eithers of
      -- Nothing to inline? Then nothing changes
      ([], _) -> return expr
      (replace, others) -> do
        -- Substitute the inlined binders into the reduced let
        newexpr <- do_substitute replace (Let (Rec others) res)
        change newexpr
  where
    -- Apply the condition to a let binding and tag it as Left (inline)
    -- or Right (keep).
    docond :: (CoreBndr, CoreExpr) -> TransformMonad (Either (CoreBndr, CoreExpr) (CoreBndr, CoreExpr))
    docond bind = do
      inline <- condition bind
      return $ if inline then Left bind else Right bind

    -- Apply the given list of substitutions, in order, to the given expression
    do_substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> TransformMonad CoreExpr
    do_substitute [] target = return target
    do_substitute ((bndr, val):reps) target = do
      -- Perform this substitution in the expression...
      target' <- substitute_clone bndr val context target
      -- ...and in the values of the substitutions still to be performed,
      -- so that inlining one of them later does not bring bndr back
      reps' <- mapM (subs_bind bndr val) reps
      -- Then perform the remaining substitutions
      do_substitute reps' target'

    -- Replace the given binder with the given expression in the value of
    -- the given let binding
    subs_bind :: CoreBndr -> CoreExpr -> (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
    subs_bind bndr val (b, v) = do
      v' <- substitute_clone bndr val (LetBinding:context) v
      return (b, v')

-- Leave all other expressions unchanged
inlinebind _ _ expr = return expr

-- | Record in the TransformMonad's writer output that a transformation
-- changed its result. The flag is an 'Monoid.Any', so once set it stays set.
setChanged :: TransformMonad ()
setChanged = Writer.tell $ Monoid.Any True

-- | Return the given value after raising the changed flag.
change :: a -> TransformMonad a
change val = setChanged >> return val

-- | Return the given value, raising the changed flag only when the boolean is
-- True. A False argument never clears a flag that is already set.
changeif :: Bool -> a -> TransformMonad a
changeif cond val = if cond then change val else return val

-- | Creates a transformation that substitutes the given binder with the given
-- expression, using CoreSubst. The binder may also be a type variable, in
-- which case the replacement should be a Type expression. The context is
-- ignored. Does not set the changed flag.
substitute :: CoreBndr -> CoreExpr -> Transform
substitute find repl _ expr = return (CoreSubst.substExpr subst expr)
  where
    -- A substitution containing just find |-> repl
    subst = CoreSubst.extendSubst CoreSubst.emptySubst find repl

-- | Creates a transformation that substitutes the given binder with the given
-- expression. This only works for value expressions! Every occurrence is
-- replaced by a copy of the expression whose binders have been given fresh
-- uniques (via CoreTools.genUniques), and each replacement sets the changed
-- flag.
substitute_clone :: CoreBndr -> CoreExpr -> Transform
substitute_clone find repl context expr =
  case expr of
    -- An occurrence of the binder: replace it by a freshly uniqued copy
    Var var | find == var -> do
      repl' <- Trans.lift $ CoreTools.genUniques repl
      change repl'
    -- Anything else: continue into the direct subexpressions
    _ -> subeverywhere (substitute_clone find repl) context expr

-- | Is the given thing representable at runtime, judged by its type? This is
-- 'isRepr'' lifted into the TransformMonad.
isRepr :: (CoreTools.TypedThing t) => t -> TransformMonad Bool
isRepr = Trans.lift . isRepr'

-- | Is the given thing representable, judged by its type? Things without a
-- type are not representable. Otherwise the answer comes from
-- VHDLTools.isReprType, run on the session's type state.
isRepr' :: (CoreTools.TypedThing t) => t -> TranslatorSession Bool
isRepr' tything =
  maybe (return False)
        (MonadState.lift tsType . VHDLTools.isReprType)
        (CoreTools.getType tything)

-- | Is the expression a variable that is not one of the global binders? Any
-- non-variable expression yields False.
is_local_var :: CoreSyn.CoreExpr -> TranslatorSession Bool
is_local_var (CoreSyn.Var v) = Monad.liftM (notElem v) getGlobalBinders
is_local_var _ = return False

-- | Is the given binder defined by the user? A binder is not user-defined
-- when either of the following holds:
--
-- * it has a system name, or
-- * its occurrence name is one of the builtin function names. Builtins are
--   usually not user-defined, and things would currently break if they were.
isUserDefined :: CoreSyn.CoreBndr -> Bool
isUserDefined bndr =
  not (Name.isSystemName (Id.idName bndr))
    && Name.getOccString bndr `notElem` builtinIds

-- | Is the given binder normalizable? This means that its type signature can be
-- represented in hardware, which should (?) guarantee that it can be made
-- into hardware. The binder's type is split with Type.splitFunTys. Every
-- argument type must be representable, and so must the result type unless
-- the first argument allows it to be unrepresentable.
isNormalizeable :: 
  Bool -- ^ Allow the result to be unrepresentable?
  -> CoreBndr  -- ^ The binder to check
  -> TranslatorSession Bool  -- ^ Is it normalizeable?
isNormalizeable result_nonrep bndr =
    andM $ mapM isRepr' tys_to_check
  where
    (arg_tys, res_ty) = Type.splitFunTys (Id.idType bndr)
    tys_to_check
      | result_nonrep = arg_tys
      | otherwise     = res_ty : arg_tys