module Herbie
    ( plugin
    , pass
    )
    where

import Class
import DsBinds
import DsMonad
import ErrUtils
import GhcPlugins
import Id
import Unique
import MkId
import PrelNames
import TcRnMonad
import TcSimplify

import Control.Monad
import Control.Monad.Except
import Data.Data
import Data.Maybe
import Data.Typeable

import Herbie.CoreManip
import Herbie.ForeignInterface
import Herbie.MathExpr
import Herbie.MathInfo
import Herbie.Options

import Debug.Trace

import Prelude
import Show
import Data.IORef

-- | The GHC plugin entry point.  Only 'installCoreToDos' is customised (it is
-- set to 'install'); every other field is inherited from 'defaultPlugin'.
plugin :: Plugin
plugin = defaultPlugin { installCoreToDos = install }

-- | Print a banner, call 'reinitializeGlobals', and prepend the Herbie
-- Core-to-Core pass (named "MathInfo") to the existing pipeline.  The plugin's
-- command-line options are forwarded to 'pass'.
install :: [CommandLineOption] -> [CoreToDo] -> CoreM [CoreToDo]
install opts todo = do
    putMsgS "Compiling with Herbie floating point stabilization"
    reinitializeGlobals
    let herbiePass = CoreDoPluginPass "MathInfo" (pass opts)
    pure (herbiePass : todo)

-- | The Core-to-Core pass itself.
--
-- It first stores the current 'DynFlags' in the global 'dynFlags_ref'
-- (defined outside this module chunk; NOTE(review): presumably consulted by
-- pretty-printing helpers elsewhere in the plugin — confirm). It then
-- rewrites every top-level binding of the module with 'modBind'.
pass :: [CommandLineOption] -> ModGuts -> CoreM ModGuts
pass opts guts = do
    getDynFlags >>= liftIO . writeIORef dynFlags_ref
    bindsOnlyPass (mapM rewriteBind) guts
    where
        rewriteBind = modBind opts guts

-- | Process one top-level binding of the module being compiled.
--
-- For a non-recursive binding @b = e@, the local worker 'go' walks @e@ looking
-- for math expressions (anything 'mkMathInfo' accepts). For each one it:
--
--   * reports it;
--   * sends it to Herbie via 'stabilizeMathExpr';
--   * when the @optsRewrite@ plugin option is set, replaces it by the expression
--     rebuilt with 'mathInfo2expr'.
--
-- If rebuilding fails, a warning is printed and the original expression is kept.
--
-- A binding annotated with the string @"NoHerbie"@ is returned unchanged.
--
-- NOTE(review): recursive top-level groups ('Rec') are returned untouched, so
-- math inside recursive top-level functions is never analysed.
modBind :: [CommandLineOption] -> ModGuts -> CoreBind -> CoreM CoreBind
modBind _ _ bndr@(Rec _) = return bndr
modBind opts guts (NonRec b e) = do
    anns <- annotationsOn guts b :: CoreM [String]
    e' <- if "NoHerbie" `elem` anns
        then return e
        else go [] e
    return $ NonRec b e'
    where
        pluginOpts = parsePluginOpts opts

        -- go dicts expr:
        -- Recursively descend into expr and run Herbie on every math
        -- expression found.  `dicts` holds the dictionary variables in scope
        -- at expr; they are handed to mkMathInfo because they may be needed
        -- to build the replacement expressions.
        -- Throughout, `b` is always the top-level binder being processed.
        go dicts expr = do
            dflags <- getDynFlags
            case mkMathInfo dflags dicts (varType b) expr of

                -- not a math expression, so recurse into subexpressions
                Nothing -> case expr of

                    -- Lambda expression:
                    -- If the variable is a dictionary, add it to the list;
                    -- always recurse into the body.
                    --
                    -- FIXME:
                    -- We remove deadness annotations from every binder we pass
                    -- (lambda and let binders), so that every dictionary the
                    -- type signature allows can be used; Core Lint complains
                    -- about uses of variables marked dead otherwise.  The
                    -- drawback is that ALL such deadness annotations are lost.
                    -- A second pass that un-deadens only the variables actually
                    -- used by the rewritten code would fix this.
                    Lam x body -> do
                        let x' = undeadenId x
                        body' <- go (extractDicts x' ++ dicts) body
                        return $ Lam x' body'

                    -- Non-recursive let:
                    -- The binder is not in scope in its own right-hand side,
                    -- so only the body sees a dictionary it binds.
                    Let (NonRec x rhs) body -> do
                        let x' = undeadenId x
                        rhs' <- go dicts rhs
                        body' <- go (extractDicts x' ++ dicts) body
                        return $ Let (NonRec x' rhs') body'

                    -- Recursive let:
                    -- Every binder of the group scopes over all right-hand
                    -- sides and the body, so dictionaries bound by the group
                    -- are made available to both (consistent with the Lam and
                    -- non-recursive let cases; they used to be dropped here).
                    Let (Rec pairs) body -> do
                        let pairs0 = [ (undeadenId x, rhs) | (x, rhs) <- pairs ]
                            dicts' = concatMap (extractDicts . fst) pairs0 ++ dicts
                        pairs' <- forM pairs0 $ \ (x', rhs) -> do
                            rhs' <- go dicts' rhs
                            return (x', rhs')
                        body' <- go dicts' body
                        return $ Let (Rec pairs') body'

                    -- Function application:
                    -- Math expressions may appear on either side, so recurse on both.
                    App fun arg -> do
                        fun' <- go dicts fun
                        arg' <- go dicts arg
                        return $ App fun' arg'

                    -- Case expression:
                    -- Math expressions may appear in the scrutinee or in any alternative.
                    Case scrut caseBndr ty alts -> do
                        scrut' <- go dicts scrut
                        alts' <- forM alts $ \ (altcon, xs, rhs) -> do
                            rhs' <- go dicts rhs
                            return (altcon, xs, rhs')
                        return $ Case scrut' caseBndr ty alts'

                    -- Ticks and Casts only annotate an expression with extra
                    -- information; keep the annotation and recurse inside.
                    Tick tickish inner -> do
                        inner' <- go dicts inner
                        return $ Tick tickish inner'

                    Cast inner co -> do
                        inner' <- go dicts inner
                        return $ Cast inner' co

                    -- Nothing to do for these; they are the recursion's base case.
                    Var v      -> return $ Var v
                    Lit l      -> return $ Lit l
                    Type t     -> return $ Type t
                    Coercion c -> return $ Coercion c

                -- We found a math expression, so process it
                Just mathInfo -> do
                    putMsgS $ "Found math expression within binding "
                        ++ showSDoc dflags (ppr b)
                        ++ " :: "
                        ++ showSDoc dflags (ppr $ varType b)
                    putMsgS $ "  original expression = "++pprMathInfo mathInfo
                    let dbgInfo = DbgInfo
                            { dbgComments  = concat opts
                            , modName      = showSDoc dflags (ppr $ moduleName $ mg_module guts)
                            , functionName = showSDoc dflags (ppr b)
                            , functionType = showSDoc dflags (ppr $ varType b)
                            }
                    res <- liftIO $ stabilizeMathExpr dbgInfo $ getMathExpr mathInfo
                    let mathInfo' = mathInfo { getMathExpr = cmdout res }

                    -- Display the improved expression.
                    -- NOTE(review): the tolerance check is disabled, so
                    -- canRewrite is always True and the "could not improve"
                    -- branch below is currently unreachable.
                    let canRewrite = True -- errin res-errout res > optsTol pluginOpts
                    if canRewrite
                        then do
                            putMsgS $ "  improved expression = "++pprMathInfo mathInfo'
                            putMsgS $ "  original error = "++show (errin res)++" bits"
                            putMsgS $ "  improved error = "++show (errout res)++" bits"
                        else do
                            putMsgS $ "  Herbie could not improve the stability of the original expression"

                    -- Rewrite the expression only when the user asked for it;
                    -- on failure keep the original and warn.
                    if not (optsRewrite pluginOpts) || not canRewrite
                        then return expr
                        else do
                            ret <- runExceptT $ mathInfo2expr guts mathInfo'
                            case ret of
                                Left str -> do
                                    putMsgS "  WARNING: Not substituting the improved expression into your code"
                                    putMsgS str
                                    return expr
                                Right newExpr -> return newExpr

-- | Return @[v]@ if the variable is dictionary-like evidence, otherwise @[]@.
-- Dictionary-like evidence means one of:
--
--   * a class dictionary;
--   * equality evidence;
--   * a tuple of constraints.
--
-- 'classifyPredType' classifies a type by its head type constructor. Any tuple
-- constructor (including ordinary value tuples such as @(Double, Double)@ or
-- @()@) therefore comes back as 'TuplePred'. Only a tuple whose type is itself
-- a constraint ('isPredTy') is a tuple of dictionaries. Without this check,
-- plain tuple-typed variables would be mistaken for dictionaries and offered
-- to 'mkMathInfo'.
extractDicts :: Var -> [Var]
extractDicts v = case classifyPredType ty of
    ClassPred _ _ -> [v]
    EqPred _ _ _  -> [v]
    TuplePred _
        | isPredTy ty -> [v]
        | otherwise   -> []
    IrredPred _   -> []
    where
        ty = varType v

-- | Clear a dead-binder marking: a binder for which 'isDeadBinder' holds gets
-- its occurrence info reset to 'NoOccInfo'; any other variable (including
-- type variables, for which 'isDeadBinder' is False) is returned as is.
undeadenId :: Var -> Var
undeadenId a
    | isDeadBinder a = setIdOccInfo a NoOccInfo
    | otherwise      = a

-- | All annotations of type @a@ attached to the given binder in the current
-- module, or @[]@ when it has none.
--
-- Adapted from the GHC user's guide section on compiler plugins:
-- https://downloads.haskell.org/~ghc/latest/docs/html/users_guide/compiler-plugins.html
annotationsOn :: Data a => ModGuts -> CoreBndr -> CoreM [a]
annotationsOn guts bndr = fmap annsFor (getAnnotations deserializeWithData guts)
  where
    annsFor annMap = lookupWithDefaultUFM annMap [] (varUnique bndr)