module Herbie.MathInfo
where
import Class
import DsBinds
import DsMonad
import ErrUtils
import GhcPlugins hiding (trace)
import Unique
import MkId
import PrelNames
import UniqSupply
import TcRnMonad
import TcSimplify
import Type
import Control.Monad
import Control.Monad.Except
import Control.Monad.Trans
import Data.Char
import Data.List
import Data.Maybe
import Data.Ratio
import Herbie.CoreManip
import Herbie.MathExpr
import Prelude
import Show
-- | No-op stand-in for a tracing function: the first argument (the message)
-- is ignored and the second argument is returned unchanged.
-- The @trace@ from "GhcPlugins" is excluded in this module's import list, so
-- uses of @trace@ here refer to this definition.
trace :: a -> b -> b
trace _ b = b
-- | No-op monadic counterpart of 'trace': the argument is ignored and the
-- action simply returns @()@.
traceM :: Monad m => a -> m ()
traceM _ = return ()
-- | Type information for the parameter a math expression is computed over.
-- 'mkMathInfo' fills it in from the binder's type and the operator's type
-- argument; 'mathInfo2expr' reads it back when rebuilding Core.
data ParamType = ParamType
    { getQuantifier :: [Var]
      -- ^ quantified type variables, as split off by 'extractQuantifiers'
    , getCxt        :: [Type]
      -- ^ class context, as split off by 'extractContext'
    , getDicts      :: [CoreExpr]
      -- ^ dictionary arguments supplied to the operators when rebuilding Core
    , getParam      :: Type
      -- ^ the type argument the math operators are applied at
    }
-- | A math expression recovered from Core, together with everything needed
-- to print it or translate it back into Core.
data MathInfo = MathInfo
    { getMathExpr  :: MathExpr
      -- ^ the expression in 'MathExpr' form
    , getParamType :: ParamType
      -- ^ type information for the operators in the expression
    , getExprs     :: [(String,Expr Var)]
      -- ^ the original Core for each 'ELeaf', keyed by the leaf's string
    }
-- | Render the 'MathExpr' of a 'MathInfo' as a human-readable string.
--
-- * Binary operators are printed infix and unary operators prefix.
-- * @if@ expressions span several lines, indented four spaces per level.
-- * Literals that are whole numbers with fewer than 10 digits print as
--   integers; every other literal prints as a 'Double'.
-- * A leaf whose recorded Core is a plain variable prints as its name; any
--   other leaf prints as @???@.
--
-- An operand is parenthesised whenever dropping the parentheses could change
-- how the output reads. For example, @a - (b - c)@ keeps its parentheses, while
-- @a + (b + c)@ prints as @a + b + c@.
pprMathInfo :: MathInfo -> String
pprMathInfo mathInfo = go 1 False $ getMathExpr mathInfo
  where
    isLitOrLeaf :: MathExpr -> Bool
    isLitOrLeaf (ELit _)  = True
    isLitOrLeaf (ELeaf _) = True
    isLitOrLeaf _         = False

    -- Operators for which @a op (b op c)@ means the same as @(a op b) op c@,
    -- so a same-operator right operand needs no parentheses.
    associativeOps :: [String]
    associativeOps = ["+", "*", "max", "min"]

    -- Right-associative operators: for these, a same-operator *left* operand
    -- must keep its parentheses (@(a ** b) ** c@ /= @a ** b ** c@).
    rightAssocOps :: [String]
    rightAssocOps = ["**", "^", "^^"]

    -- Whether operand @child@ of binary operator @op@ should be wrapped.
    -- Literals and leaves are never wrapped by 'go', whatever this returns.
    needsParens :: Bool -> String -> MathExpr -> Bool
    needsParens isRightOperand op child = case child of
        EBinOp op' _ _
            | op /= op'      -> True
            | isRightOperand -> op `notElem` associativeOps
            | otherwise      -> op `elem` rightAssocOps
        _ -> True

    -- @go i b e@ renders @e@ at if-nesting level @i@, wrapping the result in
    -- parentheses when @b@ is set and @e@ is not a literal or leaf.
    go :: Int -> Bool -> MathExpr -> String
    go i b e = if b && not (isLitOrLeaf e)
        then "(" ++ str ++ ")"
        else str
      where
        str = case e of
            EMonOp op e1 -> op ++ " " ++ go i True e1
            EBinOp op e1 e2 ->
                go i (needsParens False op e1) e1
                ++ " " ++ op ++ " "
                ++ go i (needsParens True op e2) e2
            ELit l -> showLit l
            ELeaf l -> case lookup l $ getExprs mathInfo of
                Just (Var _) -> l
                _            -> "???"
            EIf cond e1 e2 -> "if " ++ go i False cond ++ "\n"
                ++ white ++ "then " ++ go (i+1) False e1 ++ "\n"
                ++ white ++ "else " ++ go (i+1) False e2
        white = replicate (4*i) ' '

    -- Short whole numbers print exactly; everything else goes through Double.
    showLit :: Rational -> String
    showLit l
        | toRational n == l && length (show n) < 10 = show n
        | otherwise = show (fromRational l :: Double)
      where
        n = floor l :: Integer
-- | If the expression is a known math operator applied to a type argument,
-- a dictionary and its value arguments, return that type argument.
--
-- * A binary operator (named in 'binOpList') must be applied to a type
--   argument followed by three further arguments.
-- * A unary operator (named in 'monOpList') must be applied to a type
--   argument followed by two further arguments.
--
-- The type argument is accepted only if it is a type variable, is not a type
-- constructor application, or is headed by 'Float' or 'Double'. Anything else
-- yields 'Nothing'.
varTypeIfValidExpr :: CoreExpr -> Maybe Type
varTypeIfValidExpr expr = case expr of
    App (App (App (App (Var op) (Type ty)) _) _) _ -> accept binOpList op ty
    App (App (App (Var op) (Type ty)) _) _         -> accept monOpList op ty
    _                                              -> Nothing
  where
    -- Name check first, then type check, short-circuiting on failure.
    accept :: [String] -> Var -> Type -> Maybe Type
    accept knownOps op ty = do
        guard (var2str op `elem` knownOps)
        guard (acceptableType ty)
        return ty

    acceptableType :: Type -> Bool
    acceptableType ty =
        isTyVarTy ty || maybe True (isFloatingTyCon . fst) (splitTyConApp_maybe ty)

    isFloatingTyCon :: TyCon -> Bool
    isFloatingTyCon tc = tc == floatTyCon || tc == doubleTyCon
-- | Try to turn a Core expression into a 'MathInfo'.
--
-- The expression must be accepted by 'varTypeIfValidExpr'; its type argument
-- becomes 'getParam'.
--
-- The expression is then translated into a 'MathExpr':
--
-- * Binary and unary operators named in 'binOpList' / 'monOpList' become
--   'EBinOp' / 'EMonOp'.
-- * Recognised literal shapes become 'ELit'.
-- * Applications of @$@ are looked through.
-- * Any other subterm becomes an 'ELeaf' named by 'expr2str'. Its original
--   Core is recorded in 'getExprs' under that name.
--
-- The binder type @bndType@ supplies 'getQuantifier' and 'getCxt', and the
-- @dicts@ become 'getDicts'.
--
-- Returns 'Nothing' unless the translated expression has depth greater than 1
-- and 'lispHasRepeatVars' holds for its lisp form.
mkMathInfo :: DynFlags -> [Var] -> Type -> Expr Var -> Maybe MathInfo
mkMathInfo dflags dicts bndType e = do
    t <- varTypeIfValidExpr e
    guard $ mathExprDepth translated > 1
         && lispHasRepeatVars (mathExpr2lisp translated)
    return MathInfo
        { getMathExpr  = translated
        , getParamType = ParamType
            { getQuantifier = quantifier
            , getCxt        = cxt
            , getDicts      = map Var dicts
            , getParam      = t
            }
        , getExprs     = exprs
        }
  where
    (translated, exprs) = go e
    (quantifier, unquantified) = extractQuantifiers bndType
    (cxt, _) = extractContext unquantified

    -- An untranslatable subterm: a leaf named by its pretty-printed form,
    -- remembering the Core so it can be substituted back later.
    leafOf :: Expr Var -> (MathExpr, [(String, Expr Var)])
    leafOf x = let name = expr2str dflags x in (ELeaf name, [(name, x)])

    -- Translate a subterm, returning it with the leaves it contains.
    -- Equation order matters: several patterns overlap.
    go :: Expr Var -> (MathExpr, [(String, Expr Var)])
    -- Two type arguments then two values: look through @f $ x@; any other
    -- such application is opaque.
    go x@(App (App (App (App (Var v) (Type _)) (Type _)) a1) a2)
        | var2str v == "$" = go (App a1 a2)
        | otherwise        = leafOf x
    -- A function with a type argument and one more argument, applied to a
    -- literal (presumably a conversion such as fromInteger; the head is not
    -- checked).
    go (App (App (App (Var _) (Type _)) _) (Lit l))
        = (ELit $ lit2rational l, [])
    -- The same shape applied to a two-literal constructor application,
    -- read as the ratio l1/l2.
    go (App (App (App (Var _) (Type _)) _)
            (App (App (App (Var _) (Type _)) (Lit l1)) (Lit l2)))
        = (ELit $ lit2rational l1 / lit2rational l2, [])
    -- A variable applied directly to a literal (e.g. a boxing constructor).
    go (App (Var _) (Lit l))
        = (ELit $ lit2rational l, [])
    go x@(App (App (App (App (Var v) (Type _)) _) a1) a2)
        | var2str v `elem` binOpList =
            let (a1', exprs1) = go a1
                (a2', exprs2) = go a2
            in (EBinOp (var2str v) a1' a2', exprs1 ++ exprs2)
        | otherwise = leafOf x
    go x@(App (App (App (Var v) (Type _)) _) a)
        | var2str v `elem` monOpList =
            let (a', exprs') = go a
            in (EMonOp (var2str v) a', exprs')
        | otherwise = leafOf x
    go x = leafOf x
-- | Translate a 'MathInfo' back into a Core expression.
--
-- * Operators are resolved with 'getDecoratedFunction' at the parameter type
--   and dictionaries stored in 'getParamType'. Unary results are passed
--   through 'castToType'.
-- * Literals become @fromRational (n :% d)@.
-- * @if@ becomes a 'Case' on the Bool condition.
-- * Each 'ELeaf' is replaced by the Core recorded for it in 'getExprs'.
--
-- A leaf with no recorded Core fails in the 'ExceptT' layer, with a message
-- listing the known leaf names, instead of crashing the compiler.
mathInfo2expr :: ModGuts -> MathInfo -> ExceptT String CoreM CoreExpr
mathInfo2expr guts herbie = go (getMathExpr herbie)
  where
    pt = getParamType herbie

    go (EBinOp opstr a1 a2) = do
        a1' <- go a1
        a2' <- go a2
        f <- getDecoratedFunction guts opstr (getParam pt) (getDicts pt)
        return $ App (App f a1') a2'

    go (EMonOp opstr a) = do
        a' <- go a
        f <- getDecoratedFunction guts opstr (getParam pt) (getDicts pt)
        castToType
            (getDicts pt)
            (getParam pt)
            $ App f a'

    go (EIf cond a1 a2) = do
        cond' <- go cond >>= castToType (getDicts pt) boolTy
        a1' <- go a1
        a2' <- go a2
        wildUniq <- getUniqueM
        let wildName = mkSystemName wildUniq (mkVarOcc "wild")
            wildVar = mkLocalVar VanillaId wildName boolTy vanillaIdInfo
        -- Alternatives are listed False then True, i.e. in the declaration
        -- order of Bool's constructors (Core orders DataAlts by tag).
        return $ Case
            cond'
            wildVar
            (getParam pt)
            [ (DataAlt falseDataCon, [], a2')
            , (DataAlt trueDataCon, [], a1')
            ]

    -- Build @fromRational ((:%) \@Integer numerator denominator)@.
    go (ELit r) = do
        fromRationalExpr <- getDecoratedFunction guts "fromRational" (getParam pt) (getDicts pt)
        integerTyCon <- lookupTyCon integerTyConName
        let integerTy = mkTyConTy integerTyCon
        ratioTyCon <- lookupTyCon ratioTyConName
        tmpUniq <- getUniqueM
        let tmpName = mkSystemName tmpUniq (mkVarOcc "a")
            tmpVar = mkTyVar tmpName liftedTypeKind
            tmpVarT = mkTyVarTy tmpVar
            ratioConTy = mkForAllTy tmpVar $ mkFunTys [tmpVarT,tmpVarT] $ mkAppTy (mkTyConTy ratioTyCon) tmpVarT
            ratioConVar = mkGlobalVar VanillaId ratioDataConName ratioConTy vanillaIdInfo
        return $ App
            fromRationalExpr
            (App
                (App
                    (App
                        (Var ratioConVar )
                        (Type integerTy)
                    )
                    (Lit $ LitInteger (numerator r) integerTy)
                )
                (Lit $ LitInteger (denominator r) integerTy)
            )

    -- Report a missing leaf through the ExceptT channel rather than 'error',
    -- which would only fire (and abort GHC) when the Core is later forced.
    go (ELeaf str) = case lookup str (getExprs herbie) of
        Just x  -> return x
        Nothing -> throwError $ "mathInfo2expr: var " ++ str ++ " not in scope"
            ++"; in scope vars="++show (nub $ map fst $ getExprs herbie)