{-# LANGUAGE CPP #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TupleSections #-}

{-# OPTIONS_GHC -Wall #-}

-- | A Core-to-Core pass that walks every binding of a module and replaces
-- applications of known math functions to literal arguments with the
-- literal result, computed at compile time.
--
-- The table of recognised functions is 'subs'; each entry pairs a
-- (module-qualified) function name with a rewriting action.
module ConstMath.Pass (
      constMathProgram
) where

import ConstMath.Types
import ConstMath.PrimRules

import Control.Applicative ((<$>))
import Control.Monad ((<=<))

import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe

import GhcPlugins

-- | Entry point of the pass.
--
-- The 'Int' is only used to label the trace output; every top-level
-- binding is then rewritten with 'subBind'.
constMathProgram :: Int -> Opts -> [CoreBind] -> CoreM [CoreBind]
constMathProgram n opts binds = do
    traceMsg opts $ "\nStarting ConstMath pass - " ++ show n
    mapM (subBind opts "") binds

-- | Rewrite the right-hand side(s) of a binding.
--
-- The 'String' is the indentation prefix used for trace output.
-- Recursive groups are traversed exactly like non-recursive bindings:
-- each binder is reported with 'printRecBind' and its right-hand side is
-- rewritten with 'subExpr', so constants inside recursive definitions are
-- folded too.
subBind :: Opts -> String -> CoreBind -> CoreM CoreBind
subBind opts tab (NonRec b rhs) = do
    tracePretty opts (tab ++ "Non-recursive binding named ") b
    rhs' <- subExpr opts tab rhs
    return (NonRec b rhs')
subBind opts tab (Rec pairs) = Rec <$> mapM subPair pairs
  where
    subPair (b, rhs) = do
        printRecBind opts b rhs
        rhs' <- subExpr opts tab rhs
        return (b, rhs')

-- | Trace the name of one binder of a recursive group.  The expression
-- argument is ignored.
printRecBind :: Opts -> CoreBndr -> Expr CoreBndr -> CoreM ()
printRecBind opts b _e = tracePretty opts "Recursive binding " b

-- | Bottom-up traversal of a Core expression.
--
-- Leaves ('Type', 'Coercion', 'Lit', 'Var') are returned unchanged.  For
-- an application both sides are rewritten first and the rebuilt node is
-- then offered to 'collapse'.  All other constructors are rebuilt around
-- their rewritten sub-expressions.  The 'String' is the trace indentation.
subExpr :: Opts -> String -> CoreExpr -> CoreM CoreExpr
subExpr opts tab expr@(Type t) = do
    tracePretty opts (tab ++ "Type ") t
    return expr
subExpr opts tab expr@(Coercion _co) = do
    traceMsg opts (tab ++ "Coercion")
    return expr
subExpr opts tab expr@(Lit lit) = do
    tracePretty opts (tab ++ "Lit ") lit
    return expr
subExpr opts tab expr@(Var v) = do
    tracePretty opts (tab ++ "Var ") v
    return expr
subExpr opts tab (App f a) = do
    tracePretty opts (tab ++ "App ") f
    f' <- subExpr opts (tab ++ "< ") f
    a' <- subExpr opts (tab ++ "> ") a
    -- Only applications can be replaced by a literal result.
    collapse opts (App f' a')
subExpr opts tab (Tick t e) = do
    traceMsg opts (tab ++ "Tick")
    e' <- subExpr opts (tab ++ " ") e
    return (Tick t e')
subExpr opts tab (Cast e co) = do
    traceMsg opts (tab ++ "Cast")
    e' <- subExpr opts (tab ++ " ") e
    return (Cast e' co)
subExpr opts tab (Lam b e) = do
    traceMsg opts (tab ++ "Lam")
    e' <- subExpr opts (tab ++ " ") e
    return (Lam b e')
subExpr opts tab (Let bind e) = do
    traceMsg opts (tab ++ "Let")
    bind' <- subBind opts tab bind
    e' <- subExpr opts (tab ++ " ") e
    return (Let bind' e')
subExpr opts tab (Case scrut bndr ty alts) = do
    traceMsg opts (tab ++ "Case")
    let subAlt (ac,bs,eB) = (ac,bs,) <$> subExpr opts (tab ++ " ") eB
    scrut' <- subExpr opts (tab ++ " ") scrut
    alts' <- mapM subAlt alts
    return (Case scrut' bndr ty alts')

----------------------------------------------------------------------

-- | Try to replace an application by a constant.
--
-- The head of the function part is looked up with 'findSub'; when a
-- substitution is registered its action is run on the whole expression,
-- otherwise the expression is returned unchanged.
collapse :: Opts -> CoreExpr -> CoreM CoreExpr
collapse opts expr@(App f1 _)
  | Just f <- cmSubst <$> findSub f1
  = f opts expr
collapse _ expr = return expr

-- | Build a rewriter for a unary 'RealFloat' function applied to a
-- @D#@- or @F#@-wrapped literal.  Anything else is left untouched.
mkUnaryCollapseIEEE :: (forall a. RealFloat a => (a -> a))
                    -> Opts
                    -> CoreExpr
                    -> CoreM CoreExpr
mkUnaryCollapseIEEE fnE opts expr@(App f1 (App f2 (Lit lit)))
    | isDHash f2, MachDouble d <- lit = e d mkDoubleLitDouble
    | isFHash f2, MachFloat d  <- lit = e d mkFloatLitFloat
  where
    e d = evalUnaryIEEE opts fnE f1 f2 d expr
mkUnaryCollapseIEEE _ _ expr = return expr

-- | Evaluate a unary function on a literal given as a 'Rational' and, if
-- 'maybeIEEE' accepts the result, rebuild the expression as the wrapper
-- @f2@ applied to the new literal.  When the result is rejected the
-- original expression is returned.
evalUnaryIEEE :: (Fractional a, RealFloat b)
              => Opts
              -> (a -> b)
              -> CoreExpr
              -> CoreExpr
              -> Rational
              -> CoreExpr
              -> (b -> Arg Var)
              -> CoreM (CoreExpr)
evalUnaryIEEE opts fnE f1 f2 d expr mkLit = do
    let sub = fnE (fromRational d)
    maybe (return expr) (return . App f2 . mkLit)
      =<< maybeIEEE opts (displayFuncName f1) sub

-- | Build a rewriter for a unary 'Num' function.
--
-- Floating-point literals go through 'evalUnaryIEEE' (and therefore
-- through the 'maybeIEEE' checks); @I#@- and @W#@-wrapped literals are
-- evaluated directly.  In dry-run mode nothing is replaced.
mkUnaryCollapseNum :: (forall a . Num a => (a -> a))
                   -> Opts
                   -> CoreExpr
                   -> CoreM CoreExpr
mkUnaryCollapseNum fnE opts expr@(App f1 (App f2 (Lit lit)))
    | isDHash f2, MachDouble d <- lit =
        evalUnaryIEEE opts fnE f1 f2 d expr mkDoubleLitDouble
    | isFHash f2, MachFloat d  <- lit =
        evalUnaryIEEE opts fnE f1 f2 d expr mkFloatLitFloat
    | isIHash f2, MachInt d    <- lit =
        evalUnaryNum fromIntegral d mkIntLitInt
    | isWHash f2, MachWord d   <- lit =
        evalUnaryNum fromIntegral d mkWordLitWord
  where
    msgResult = msg opts $ "Result of replacing " ++ displayFuncName f1 ++ " is ok"
    -- Used at two result types (via 'mkIntLitInt' and 'mkWordLitWord'),
    -- so it must stay a generalised local function binding.
    evalUnaryNum from d mkLit
        | dry opts = do
            msgResult
            msg opts "dry running, skipping replacement"
            return expr
        | otherwise = do
            let sub = fnE (from d)
            msgResult
            return (App f2 (mkLit sub))
mkUnaryCollapseNum _ _ expr = return expr

-- | Build a rewriter for a binary 'RealFloat' function whose two
-- arguments are literals wrapped by the same constructor (both @D#@ or
-- both @F#@).  The result is re-wrapped with the first argument's
-- wrapper, which the guards ensure matches the second.
mkBinaryCollapse :: (forall a. RealFloat a => (a -> a -> a))
                 -> Opts
                 -> CoreExpr
                 -> CoreM CoreExpr
mkBinaryCollapse fnE opts expr@(App (App f1 (App f2 (Lit lit1))) (App f3 (Lit lit2)))
    | isDHash f2 && isDHash f3
    , MachDouble d1 <- lit1, MachDouble d2 <- lit2 =
        evalBinaryIEEE d1 d2 mkDoubleLitDouble
    | isFHash f2 && isFHash f3
    , MachFloat d1 <- lit1, MachFloat d2 <- lit2 =
        evalBinaryIEEE d1 d2 mkFloatLitFloat
  where
    evalBinaryIEEE d1 d2 mkLit = do
        let sub = fnE (fromRational d1) (fromRational d2)
        maybe (return expr) (\x -> return (App f2 (mkLit x)))
          =<< maybeIEEE opts (displayFuncName f1) sub
mkBinaryCollapse _ _ expr = return expr

----------------------------------------------------------------------
-- primop collapsing functions

-- | Build a rewriter for a unary function applied directly to a
-- floating-point literal (no wrapper constructor).  On success the whole
-- application is replaced by the new literal.
mkUnaryCollapsePrimIEEE :: (forall a. RealFloat a => (a -> a))
                        -> Opts
                        -> CoreExpr
                        -> CoreM CoreExpr
mkUnaryCollapsePrimIEEE fnE opts expr@(App f1 (Lit lit))
    | MachDouble d <- lit = e d mkDoubleLitDouble
    | MachFloat d  <- lit = e d mkFloatLitFloat
  where
    e d mkLit =
        let sub = fnE (fromRational d)
        in maybe (return expr) (return . mkLit)
             =<< maybeIEEE opts (displayFuncName f1) sub
mkUnaryCollapsePrimIEEE _ _ expr = return expr

-- | Binary counterpart of 'mkUnaryCollapsePrimIEEE': both arguments must
-- be bare literals of the same floating-point kind.
mkBinaryCollapsePrimIEEE :: (forall a. RealFloat a => (a -> a -> a))
                         -> Opts
                         -> CoreExpr
                         -> CoreM CoreExpr
mkBinaryCollapsePrimIEEE fnE opts expr@(App (App primVar (Lit lit1)) (Lit lit2))
    | MachDouble d1 <- lit1
    , MachDouble d2 <- lit2 = e d1 d2 mkDoubleLitDouble
    | MachFloat d1 <- lit1
    , MachFloat d2 <- lit2 = e d1 d2 mkFloatLitFloat
  where
    e d1 d2 mkLit =
        let sub = fnE (fromRational d1) (fromRational d2)
        in maybe (return expr) (return . mkLit)
             =<< maybeIEEE opts (displayFuncName primVar) sub
mkBinaryCollapsePrimIEEE _ _ expr = return expr

----------------------------------------------------------------------
-- specialized collapsing functions

-- | Collapse @fromRational (n :% d)@ with literal numerator and
-- denominator into a 'Float' or 'Double' expression, chosen from the
-- result type of the applied function.
--
-- A zero denominator is deliberately not matched: building the
-- 'Rational' would raise a divide-by-zero exception inside the compiler,
-- so such expressions are left for run time.
fromRationalCollapse :: Opts -> CoreExpr -> CoreM CoreExpr
fromRationalCollapse opts expr@(App f1@(Var frFn) (App (App f2 (Lit (LitInteger n _))) (Lit (LitInteger d _))))
    | d /= 0
    , Just (_arg,res) <- splitFunTy_maybe $ varType frFn
    , Just "GHC.Real.:%" <- funcName f2
    , Just fnNm <- funcName f1
    = case () of
        _ | res `eqType` floatTy -> do
              let sub = fromRational $ (fromInteger n) / (fromInteger d)
              maybe (return expr) (\x -> return (mkFloatExpr x))
                =<< maybeIEEE opts (fnNm) sub
          | res `eqType` doubleTy -> do
              let sub = fromRational $ (fromInteger n) / (fromInteger d)
              maybe (return expr) (\x -> return (mkDoubleExpr x))
                =<< maybeIEEE opts (fnNm) sub
          | otherwise -> return expr
fromRationalCollapse _opts expr = return expr

-- | Decide whether a computed value may replace the original expression.
--
-- NaN, infinite, denormalized and negative-zero results are reported with
-- 'errorMsgS' and rejected ('Nothing').  Any other value is accepted,
-- unless the options request a dry run, in which case 'Nothing' is
-- returned after the messages.  The 'String' names the function being
-- replaced and is used in messages only.
maybeIEEE :: RealFloat a => Opts -> String -> a -> CoreM (Maybe a)
maybeIEEE opts s d
    | isNaN d = do
        err "NaN"
        return Nothing
    | isInfinite d = do
        err "infinite"
        return Nothing
    | isDenormalized d = do
        err "denormalized"
        return Nothing
    | isNegativeZero d = do
        err "negative zero"
        return Nothing
    | otherwise = do
        msg opts $ "Result of replacing " ++ s ++ " is ok"
        if (dry opts)
            then msg opts "Dry run, skipping replacement" >> return Nothing
            else return (Just d)
  where
    err v = errorMsgS $ "Skipping replacement of " ++ s ++ " result " ++ v

----------------------------------------------------------------------

-- | A registered substitution: the name to match (in the format produced
-- by 'funcName') and the action that attempts the rewrite.
data CMSub = CMSub
    { cmFuncName :: String
    , cmSubst    :: Opts -> CoreExpr -> CoreM CoreExpr
    }

-- | Substitution for a unary 'RealFloat' function on a wrapped literal.
unarySubIEEE :: String -> (forall a. RealFloat a => a -> a) -> CMSub
unarySubIEEE nm fn = CMSub nm (mkUnaryCollapseIEEE fn)

-- | Substitution for a unary 'Num' function on a wrapped literal.
unarySubNum :: String -> (forall a . Num a => (a -> a)) -> CMSub
unarySubNum nm fn = CMSub nm (mkUnaryCollapseNum fn)

-- | Substitution for a binary 'RealFloat' function on wrapped literals.
-- Not referenced by 'subs' (hence the leading underscore).
_binarySub :: String -> (forall a. RealFloat a => a -> a -> a) -> CMSub
_binarySub nm fn = CMSub nm (mkBinaryCollapse fn)

-- | Substitution for a unary function applied to a bare literal.
unaryPrimIEEE :: String -> (forall a. RealFloat a => a -> a) -> CMSub
unaryPrimIEEE nm fn = CMSub nm (mkUnaryCollapsePrimIEEE fn)

-- | Substitution for a binary function applied to bare literals.
binaryPrimIEEE :: String -> (forall a. RealFloat a => a -> a -> a) -> CMSub
binaryPrimIEEE nm fn = CMSub nm (mkBinaryCollapsePrimIEEE fn)

----------------------------------------------------------------------

-- | Is the head of the expression @GHC.Types.F#@?
isFHash :: CoreExpr -> Bool
isFHash = funcIs "GHC.Types.F#"

-- | Is the head of the expression @GHC.Types.D#@?
isDHash :: CoreExpr -> Bool
isDHash = funcIs "GHC.Types.D#"

-- | Is the head of the expression @GHC.Types.I#@?
isIHash :: CoreExpr -> Bool
isIHash = funcIs "GHC.Types.I#"

-- | Is the head of the expression @GHC.Word.W#@?
isWHash :: CoreExpr -> Bool
isWHash = funcIs "GHC.Word.W#"

-- | Does 'funcName' of the expression equal the given string?
funcIs :: String -> CoreExpr -> Bool
funcIs s = maybe False (== s) . funcName

-- | Name of the variable at the head of an expression.
--
-- External names are prefixed with their module name and a dot; other
-- names are returned bare.  Applications are unwound to their function
-- part.  Any other expression form yields 'Nothing'.
funcName :: CoreExpr -> Maybe String
funcName (Var var) = Just $ m ++ (unpackFS . occNameFS . nameOccName $ n)
  where
    n = varName var
    m | isExternalName n = (moduleNameString . moduleName . nameModule $ n) ++ "."
      | otherwise = ""
funcName (App f _) = funcName f
funcName _ = Nothing

-- | Total variant of 'funcName' for building diagnostic messages: an
-- expression without a variable at its head is shown as a placeholder
-- instead of crashing the compiler.
displayFuncName :: CoreExpr -> String
displayFuncName = fromMaybe "<unknown>" . funcName

-- | Look up the substitution registered for the head of an expression.
findSub :: CoreExpr -> Maybe CMSub
findSub = flip Map.lookup subFunc <=< funcName

-- | 'subs' indexed by 'cmFuncName'.
subFunc :: Map String CMSub
subFunc = Map.fromList $ zip (map cmFuncName subs) subs

-- | Every substitution known to the pass.
subs :: [CMSub]
subs =
    [ unarySubIEEE "GHC.Float.exp"    exp
    , unarySubIEEE "GHC.Float.log"    log
    , unarySubIEEE "GHC.Float.sqrt"   sqrt
    , unarySubIEEE "GHC.Float.sin"    sin
    , unarySubIEEE "GHC.Float.cos"    cos
    , unarySubIEEE "GHC.Float.tan"    tan
    , unarySubIEEE "GHC.Float.asin"   asin
    , unarySubIEEE "GHC.Float.acos"   acos
    , unarySubIEEE "GHC.Float.atan"   atan
    , unarySubIEEE "GHC.Float.sinh"   sinh
    , unarySubIEEE "GHC.Float.cosh"   cosh
    , unarySubIEEE "GHC.Float.tanh"   tanh
    , unarySubIEEE "GHC.Float.asinh"  asinh
    , unarySubIEEE "GHC.Float.acosh"  acosh
    , unarySubIEEE "GHC.Float.atanh"  atanh
    , unarySubNum  "GHC.Num.negate"   negate
    , unarySubNum  "GHC.Num.abs"      abs
    , unarySubNum  "GHC.Num.signum"   signum
    -- Specialized substitutions
    , CMSub "GHC.Real.fromRational" fromRationalCollapse
    , CMSub "GHC.Float.$fFractionalFloat_$cfromRational" fromRationalCollapse
    , CMSub "GHC.Float.$fFractionalDouble_$cfromRational" fromRationalCollapse
    ]
    -- PrimOp substitutions
    ++ map (uncurry unaryPrimIEEE) unaryPrimRules
    ++ map (uncurry binaryPrimIEEE) binaryPrimRules

----------------------------------------------------------------------

-- | Print a message unless the options ask for quiet operation.
msg :: Opts -> String -> CoreM ()
msg opts s
    | not (quiet opts) = putMsgS s
    | otherwise = return ()

-- | Print a message only in verbose mode.  Currently unused.
_vMsg :: Opts -> String -> CoreM ()
_vMsg opts s
    | verbose opts = putMsgS s
    | otherwise = return ()

-- | Print a message only when tracing is enabled.
traceMsg :: Opts -> String -> CoreM ()
traceMsg opts s
    | traced opts = putMsgS s
    | otherwise = return ()

-- | Trace a label followed by a pretty-printed value.  The value is only
-- rendered when tracing is enabled.
tracePretty :: Outputable a => Opts -> String -> a -> CoreM ()
tracePretty opts s x
    | traced opts = do
        p <- pretty x
        traceMsg opts (s ++ p)
    | otherwise = return ()

----------------------------------------------------------------------

-- | Render an 'Outputable' value to a 'String'.  From GHC 7.6 on,
-- 'showSDoc' needs the session's dynamic flags.
pretty :: Outputable a => a -> CoreM String
#if __GLASGOW_HASKELL__ >= 706
pretty x = do
    dflags <- getDynFlags
    return $ showSDoc dflags (ppr x)
#else
pretty = return . showSDoc . ppr
#endif