module Herbie
    ( plugin
    , pass
    )
    where

import Class
import DsBinds
import DsMonad
import ErrUtils
import GhcPlugins
import Id
import Unique
import MkId
import PrelNames
import TcRnMonad
import TcSimplify

import Control.Monad
import Control.Monad.Except
import Data.Data
import Data.Maybe
import Data.Typeable

import Herbie.CoreManip
import Herbie.ForeignInterface
import Herbie.MathExpr
import Herbie.MathInfo
import Herbie.Options

import Debug.Trace
import Prelude
import Show
import Data.IORef

-- | The GHC plugin entry point; it installs the Herbie Core pass.
plugin :: Plugin
plugin = defaultPlugin
    { installCoreToDos = install
    }

-- | Put the Herbie pass in front of the existing Core-to-Core pipeline.
install :: [CommandLineOption] -> [CoreToDo] -> CoreM [CoreToDo]
install opts todo = do
    putMsgS "Compiling with Herbie floating point stabilization"
    reinitializeGlobals
    return (CoreDoPluginPass "MathInfo" (pass opts) : todo)

-- | The Core-to-Core pass itself.
--
-- Stores the current 'DynFlags' in 'dynFlags_ref' (presumably for Herbie
-- helpers that need them outside 'CoreM' -- NOTE(review): confirm), then
-- rewrites every top-level binding with 'modBind'.
pass :: [CommandLineOption] -> ModGuts -> CoreM ModGuts
pass opts guts = do
    dflags <- getDynFlags
    liftIO $ writeIORef dynFlags_ref dflags
    bindsOnlyPass (mapM (modBind opts guts)) guts

-- | This function gets run on each top-level binding of the Haskell source
-- file.
--
-- Both non-recursive bindings and every member of a recursive group are
-- processed, so math inside recursive top-level functions is examined too.
modBind :: [CommandLineOption] -> ModGuts -> CoreBind -> CoreM CoreBind
modBind opts guts (NonRec b e) = uncurry NonRec <$> modBindPair opts guts (b, e)
modBind opts guts (Rec pairs)  = Rec <$> mapM (modBindPair opts guts) pairs

-- | Run Herbie over the right-hand side of a single top-level binder.
--
-- A binder annotated with the string @"NoHerbie"@ is returned unchanged.
-- The binder itself is never modified; only its right-hand side is rewritten.
modBindPair
    :: [CommandLineOption]
    -> ModGuts
    -> (CoreBndr, CoreExpr)
    -> CoreM (CoreBndr, CoreExpr)
modBindPair opts guts (bndr, rhs) = do
    anns <- annotationsOn guts bndr :: CoreM [String]
    rhs' <- if "NoHerbie" `elem` anns
        then return rhs
        else go [] rhs
    return (bndr, rhs')
  where
    pluginOpts = parsePluginOpts opts

    -- Recursively descend into an expression.  Every subexpression that
    -- 'mkMathInfo' recognises as a math expression is handed to Herbie.
    -- Dictionaries bound along the way are accumulated in @dicts@, because
    -- the replacement expression may need them.
    go :: [Var] -> CoreExpr -> CoreM CoreExpr
    go dicts expr = do
        dflags <- getDynFlags
        case mkMathInfo dflags dicts (varType bndr) expr of
            Nothing       -> descend dicts expr
            Just mathInfo -> rewrite dflags mathInfo expr

    -- Not a math expression: rebuild the node, recursing into every child.
    --
    -- FIXME:
    -- Deadness annotations are removed from every binder that may supply a
    -- dictionary (lambda, let and case-alternative binders).  This lets us
    -- use all the dictionaries that the type signatures allow; otherwise
    -- Core lint complains about using dead variables.  The drawback is that
    -- it removes ALL such deadness annotations in the program.  A second
    -- pass could restore the ones that turn out not to be needed.
    descend :: [Var] -> CoreExpr -> CoreM CoreExpr
    descend dicts expr = case expr of
        -- Lambda: the binder may be a dictionary available to the body.
        Lam x body -> do
            let x' = undeadenId x
            body' <- go (extractDicts x' ++ dicts) body
            return $ Lam x' body'

        -- Non-recursive let: the binder is in scope only in the body.
        Let (NonRec x xRhs) body -> do
            let x' = undeadenId x
            xRhs' <- go dicts xRhs
            body' <- go (extractDicts x' ++ dicts) body
            return $ Let (NonRec x' xRhs') body'

        -- Recursive let: every binder of the group is in scope in every
        -- right-hand side and in the body, so all of them may supply
        -- dictionaries to both.
        Let (Rec pairs) body -> do
            let xs'    = map (undeadenId . fst) pairs
                dicts' = concatMap extractDicts xs' ++ dicts
            rhss' <- mapM (go dicts' . snd) pairs
            body' <- go dicts' body
            return $ Let (Rec (zip xs' rhss')) body'

        -- Application: math may appear in the function or in the argument.
        App f arg -> do
            f'   <- go dicts f
            arg' <- go dicts arg
            return $ App f' arg'

        -- Case: math may appear in the scrutinee or in any alternative.
        -- Pattern-bound variables are in scope in their own alternative and
        -- may include dictionaries (e.g. from an existential constructor).
        Case scrut caseBndr ty alts -> do
            scrut' <- go dicts scrut
            alts'  <- forM alts $ \(con, xs, altRhs) -> do
                let xs' = map undeadenId xs
                altRhs' <- go (concatMap extractDicts xs' ++ dicts) altRhs
                return (con, xs', altRhs')
            return $ Case scrut' caseBndr ty alts'

        -- Ticks and casts only annotate their subexpression; keep the
        -- annotation and recurse into the expression.
        Tick t body -> do
            body' <- go dicts body
            return $ Tick t body'
        Cast body co -> do
            body' <- go dicts body
            return $ Cast body' co

        -- Leaves: the recursion's base case, nothing to rewrite.
        Var _      -> return expr
        Lit _      -> return expr
        Type _     -> return expr
        Coercion _ -> return expr

    -- A math expression was found.  Report it, ask Herbie for a more stable
    -- form, and, when rewriting is enabled, splice the result back in.  If
    -- the improved expression cannot be converted back to Core, a warning is
    -- printed and the original expression is kept.
    rewrite dflags mathInfo expr = do
        let bndrName = showSDoc dflags (ppr bndr)
            bndrType = showSDoc dflags (ppr $ varType bndr)
        putMsgS $ "Found math expression within binding "
            ++ bndrName
            ++ " :: "
            ++ bndrType
        putMsgS $ " original expression = " ++ pprMathInfo mathInfo

        let dbgInfo = DbgInfo
                { dbgComments  = concat opts
                , modName      = showSDoc dflags (ppr $ moduleName $ mg_module guts)
                , functionName = bndrName
                , functionType = bndrType
                }
        res <- liftIO $ stabilizeMathExpr dbgInfo $ getMathExpr mathInfo
        let mathInfo' = mathInfo { getMathExpr = cmdout res }

            -- The tolerance check
            --   errin res - errout res > optsTol pluginOpts
            -- is currently disabled, so every Herbie result is treated as
            -- an improvement and the "could not improve" branch is unused.
            canRewrite = True

        -- Display the improved expression if found
        if canRewrite
            then do
                putMsgS $ " improved expression = " ++ pprMathInfo mathInfo'
                putMsgS $ " original error = " ++ show (errin res) ++ " bits"
                putMsgS $ " improved error = " ++ show (errout res) ++ " bits"
            else
                putMsgS " Herbie could not improve the stability of the original expression"

        -- Rewrite the expression only when the user asked for it
        if not (optsRewrite pluginOpts) || not canRewrite
            then return expr
            else do
                ret <- runExceptT $ mathInfo2expr guts mathInfo'
                case ret of
                    Left err -> do
                        putMsgS " WARNING: Not substituting the improved expression into your code"
                        putMsgS err
                        return expr
                    Right expr' -> return expr'

-- | Return a list with the given variable if the variable is a dictionary,
-- an equality constraint, or a tuple of dictionaries; otherwise return [].
extractDicts :: Var -> [Var]
extractDicts v = case classifyPredType ty of
    ClassPred _ _ -> [v]
    EqPred _ _ _  -> [v]
    -- NOTE(review): in GHC 7.10 'classifyPredType' reports 'TuplePred' for
    -- any tuple type constructor, so an ordinary value such as
    -- @p :: (Double, Double)@ would land here too.  Only accept the variable
    -- when its type is itself a constraint, i.e. a genuine tuple of
    -- dictionaries.
    TuplePred _
        | isPredTy ty -> [v]
        | otherwise   -> []
    IrredPred _   -> []
  where
    ty = varType v

-- | If a variable is marked as dead, remove the marking by resetting its
-- occurrence info to 'NoOccInfo'; any other variable is returned unchanged.
undeadenId :: Var -> Var
undeadenId a
    | isDeadBinder a = setIdOccInfo a NoOccInfo
    | otherwise      = a

-- | All annotations of type @a@ attached to the given binder, or @[]@ if it
-- has none.
--
-- Function taken from the docs:
-- https://downloads.haskell.org/~ghc/latest/docs/html/users_guide/compiler-plugins.html
annotationsOn :: Data a => ModGuts -> CoreBndr -> CoreM [a]
annotationsOn guts bndr = do
    anns <- getAnnotations deserializeWithData guts
    return $ lookupWithDefaultUFM anns [] (varUnique bndr)