module Herbie
( plugin
, pass
)
where
import Class
import DsBinds
import DsMonad
import ErrUtils
import GhcPlugins
import Id
import Unique
import MkId
import PrelNames
import TcRnMonad
import TcSimplify
import Control.Monad
import Control.Monad.Except
import Data.Data
import Data.Maybe
import Data.Typeable
import Herbie.CoreManip
import Herbie.ForeignInterface
import Herbie.MathExpr
import Herbie.MathInfo
import Herbie.Options
import Debug.Trace
import Prelude
import Show
import Data.IORef
-- | The plugin value exported to GHC: 'defaultPlugin' with its
-- 'installCoreToDos' hook replaced by 'install'.
plugin :: Plugin
plugin = defaultPlugin { installCoreToDos = install }
-- | Hook run when the plugin is installed.
--
-- Prints a banner message, calls 'reinitializeGlobals', and then returns the
-- given to-do list with a plugin pass named @"MathInfo"@ (running
-- @'pass' opts@) placed in front of it.
install :: [CommandLineOption] -> [CoreToDo] -> CoreM [CoreToDo]
install opts todo =
    putMsgS "Compiling with Herbie floating point stabilization"
        >> reinitializeGlobals
        >> return (herbiePass : todo)
  where
    -- The Core-to-Core pass contributed by this plugin.
    herbiePass = CoreDoPluginPass "MathInfo" (pass opts)
-- | The Core-to-Core pass itself.
--
-- Writes the current 'DynFlags' into @dynFlags_ref@ and then applies
-- 'modBind' to every binding of the module via 'bindsOnlyPass'.
pass :: [CommandLineOption] -> ModGuts -> CoreM ModGuts
pass opts guts =
    getDynFlags
        >>= liftIO . writeIORef dynFlags_ref
        >> bindsOnlyPass (mapM (modBind opts guts)) guts
-- | Process one top-level binding.
--
-- A recursive binding group is returned unchanged.  For a non-recursive
-- binding the right-hand side is traversed, unless the binder carries the
-- 'String' annotation @"NoHerbie"@, in which case it is also left alone.
--
-- During the traversal every sub-expression is offered to 'mkMathInfo'.
-- When that succeeds the expression is sent to 'stabilizeMathExpr' and,
-- if the plugin options allow rewriting, replaced by the result of
-- 'mathInfo2expr'.  When it fails the traversal recurses into the children
-- of the expression.
modBind :: [CommandLineOption] -> ModGuts -> CoreBind -> CoreM CoreBind
modBind _ _ recGroup@(Rec _) = return recGroup
modBind opts guts (NonRec topBndr topRhs) = do
    anns <- annotationsOn guts topBndr :: CoreM [String]
    newRhs <-
        if "NoHerbie" `elem` anns
            then return topRhs
            else rewrite [] topRhs
    return (NonRec topBndr newRhs)
  where
    pluginOpts = parsePluginOpts opts

    -- Try to interpret @expr@ as a math expression; stabilize it on
    -- success, otherwise recurse into its children.  @dicts@ is the list
    -- of variables collected by 'extractDicts' from enclosing binders.
    rewrite dicts expr = do
        dflags <- getDynFlags
        case mkMathInfo dflags dicts (varType topBndr) expr of
            Just info -> stabilize dflags expr info
            Nothing -> descend dicts expr

    -- Rebuild @expr@ from its rewritten children.  Children are visited
    -- left to right so that messages are emitted in source order.
    descend dicts expr = case expr of
        Lam v body ->
            let v' = undeadenId v
            in liftM (Lam v') (rewrite (extractDicts v' ++ dicts) body)
        Let (NonRec v rhs) body -> do
            let v' = undeadenId v
            rhs' <- rewrite dicts rhs
            body' <- rewrite (extractDicts v' ++ dicts) body
            return (Let (NonRec v' rhs') body')
        Let (Rec pairs) body -> do
            pairs' <- forM pairs $ \(v, rhs) ->
                liftM ((,) (undeadenId v)) (rewrite dicts rhs)
            body' <- rewrite dicts body
            return (Let (Rec pairs') body')
        App fun arg ->
            liftM2 App (rewrite dicts fun) (rewrite dicts arg)
        Case scrut caseBndr ty alts -> do
            scrut' <- rewrite dicts scrut
            alts' <- forM alts $ \(con, vars, rhs) ->
                liftM (\rhs' -> (con, vars, rhs')) (rewrite dicts rhs)
            return (Case scrut' caseBndr ty alts')
        Tick tickish body ->
            liftM (Tick tickish) (rewrite dicts body)
        Cast inner co ->
            liftM (\inner' -> Cast inner' co) (rewrite dicts inner)
        -- Leaves have no children to visit.
        Var _ -> return expr
        Lit _ -> return expr
        Type _ -> return expr
        Coercion _ -> return expr

    -- Report the math expression found in @expr@, ask Herbie for a
    -- replacement, and substitute it when rewriting is enabled.  Falls
    -- back to the original @expr@ if rewriting is disabled or if
    -- 'mathInfo2expr' reports an error.
    stabilize dflags expr mathInfo = do
        let render = showSDoc dflags
            bndrName = render (ppr topBndr)
            bndrType = render (ppr (varType topBndr))
        putMsgS $ "Found math expression within binding "
            ++ bndrName
            ++ " :: "
            ++ bndrType
        putMsgS $ " original expression = " ++ pprMathInfo mathInfo
        let dbgInfo = DbgInfo
                { dbgComments = concat opts
                , modName = render (ppr (moduleName (mg_module guts)))
                , functionName = bndrName
                , functionType = bndrType
                }
        res <- liftIO (stabilizeMathExpr dbgInfo (getMathExpr mathInfo))
        let improved = mathInfo { getMathExpr = cmdout res }
            -- Currently hard-wired, so the "could not improve" branch
            -- below is never taken.
            canRewrite = True
        if canRewrite
            then do
                putMsgS $ " improved expression = " ++ pprMathInfo improved
                putMsgS $ " original error = " ++ show (errin res) ++ " bits"
                putMsgS $ " improved error = " ++ show (errout res) ++ " bits"
            else
                putMsgS " Herbie could not improve the stability of the original expression"
        if optsRewrite pluginOpts && canRewrite
            then do
                outcome <- runExceptT (mathInfo2expr guts improved)
                case outcome of
                    Right expr' -> return expr'
                    Left msg -> do
                        putMsgS " WARNING: Not substituting the improved expression into your code"
                        putMsgS msg
                        return expr
            else return expr
-- | Return @[v]@ when the type of @v@ is classified by 'classifyPredType'
-- as a 'ClassPred', 'EqPred' or 'TuplePred'; return the empty list when it
-- is classified as an 'IrredPred'.
extractDicts :: Var -> [Var]
extractDicts v
    | keep = [v]
    | otherwise = []
  where
    keep = case classifyPredType (varType v) of
        IrredPred _ -> False
        ClassPred _ _ -> True
        EqPred _ _ _ -> True
        TuplePred _ -> True
-- | Reset the occurrence info of a binder to 'NoOccInfo' when
-- 'isDeadBinder' reports it as dead; any other binder is returned as is.
undeadenId :: Var -> Var
undeadenId a
    | isDeadBinder a = setIdOccInfo a NoOccInfo
    | otherwise = a
-- | Look up the annotations attached to a binder.
--
-- Fetches the module's annotations with 'getAnnotations' (deserialized via
-- 'deserializeWithData') and returns the entry stored under the binder's
-- unique, or the empty list when there is none.
annotationsOn :: Data a => ModGuts -> CoreBndr -> CoreM [a]
annotationsOn guts bndr =
    liftM forBinder (getAnnotations deserializeWithData guts)
  where
    forBinder anns = lookupWithDefaultUFM anns [] (varUnique bndr)