module Herbie.MathExpr
where
import Control.DeepSeq
import Data.List
import Data.List.Split
import Data.Maybe
import GHC.Generics
import Debug.Trace
import Prelude
-- | Function form of @if ... then ... else ...@: yields the second argument
-- when the condition is 'True' and the third argument when it is 'False'.
ifThenElse :: Bool -> a -> a -> a
ifThenElse cond onTrue onFalse = case cond of
    True  -> onTrue
    False -> onFalse
-- | Names of the unary operators known to this module.
--
-- 'lisp2varsNoNub' consults this list so that operator tokens are not
-- reported as variables.
monOpList :: [String]
monOpList = trig ++ inverseTrig ++ hyperbolic ++ others
  where
    trig        = ["cos", "sin", "tan"]
    inverseTrig = ["acos", "asin", "atan"]
    hyperbolic  = ["cosh", "sinh", "tanh"]
    others      = ["exp", "log", "sqrt", "abs", "size"]
-- | Names of the binary operators known to this module: the order-sensitive
-- ones first, followed by everything in 'commutativeOpList'.
binOpList :: [String]
binOpList = orderSensitiveOps ++ commutativeOpList
  where
    orderSensitiveOps = ["^", "**", "^^", "/", "-", "expt"]
-- | Binary operators whose two operands may be reordered;
-- 'toCanonicalMathExpr' sorts the operands of exactly these operators.
commutativeOpList :: [String]
commutativeOpList =
    [ "*"
    , "+"
    ]
-- | Syntax tree for the mathematical expressions this module converts to and
-- from lisp-style strings.
data MathExpr
    = EBinOp String MathExpr MathExpr
    -- ^ binary operator (by name) applied to a left and a right operand
    | EMonOp String MathExpr
    -- ^ unary operator (by name) applied to a single operand
    | EIf MathExpr MathExpr MathExpr
    -- ^ conditional: condition, then-branch, else-branch
    | ELit Rational
    -- ^ numeric literal
    | ELeaf String
    -- ^ any other atom, identified by name; 'toCanonicalMathExpr' treats
    -- these as variables and renames them
    deriving (Show,Eq,Generic,NFData)
-- | Structural ordering used to put the operands of commutative operators
-- into a canonical order (see 'toCanonicalMathExpr').
--
-- Expressions are ranked by constructor first:
--
-- > ELeaf < ELit < EMonOp < EBinOp < EIf
--
-- and expressions built with the same constructor are compared
-- lexicographically on their fields (operator name first, then operands from
-- left to right).
--
-- All leaves deliberately compare as 'EQ' regardless of their names, so the
-- canonical operand order does not depend on how variables happen to be
-- called.  As a consequence @compare a b == EQ@ does /not/ imply @a == b@.
--
-- Compared with the previous definition this one is a consistent total
-- preorder: differing binary operations no longer collapse to 'EQ', a binary
-- operation is 'GT' (not 'LT') against leaves, literals and unary operations,
-- and 'EIf' is handled instead of raising a pattern-match failure.
instance Ord MathExpr where
    compare = go
      where
        go :: MathExpr -> MathExpr -> Ordering
        go (ELeaf _)          (ELeaf _)          = EQ
        go (ELit r1)          (ELit r2)          = compare r1 r2
        go (EMonOp op1 e1)    (EMonOp op2 e2)    =
            compare op1 op2 `orElse` go e1 e2
        go (EBinOp op1 a1 b1) (EBinOp op2 a2 b2) =
            compare op1 op2 `orElse` go a1 a2 `orElse` go b1 b2
        go (EIf c1 t1 f1)     (EIf c2 t2 f2)     =
            go c1 c2 `orElse` go t1 t2 `orElse` go f1 f2
        -- Only reached when the two constructors differ.
        go x                  y                  = compare (rank x) (rank y)

        -- Position of each constructor in the overall ordering.
        rank :: MathExpr -> Int
        rank (ELeaf _)      = 0
        rank (ELit _)       = 1
        rank (EMonOp _ _)   = 2
        rank (EBinOp _ _ _) = 3
        rank (EIf _ _ _)    = 4

        -- Lexicographic combination: the second comparison only matters
        -- when the first one is a tie.
        orElse :: Ordering -> Ordering -> Ordering
        orElse EQ next = next
        orElse res _   = res
-- | Rename operators from their Haskell spelling to the spelling used on the
-- Herbie side, throughout the whole expression tree.
--
-- The binary operators @**@, @^^@ and @^@ all become @expt@ and the unary
-- operator @size@ becomes @abs@.  Every other operator, literal and leaf is
-- left as it is.
haskellOpsToHerbieOps :: MathExpr -> MathExpr
haskellOpsToHerbieOps expr = case expr of
    EBinOp op lhs rhs -> EBinOp (binName op) (recur lhs) (recur rhs)
    EMonOp op arg     -> EMonOp (monName op) (recur arg)
    EIf c t f         -> EIf (recur c) (recur t) (recur f)
    other             -> other
  where
    recur = haskellOpsToHerbieOps

    binName op
        | op `elem` ["**", "^^", "^"] = "expt"
        | otherwise                   = op

    monName "size" = "abs"
    monName op     = op
-- | Rename operators from the Herbie spelling back to Haskell names,
-- throughout the whole expression tree.
--
-- * binary @^@ and @expt@ become @**@
-- * unary @sqr x@ is expanded to the product @x * x@
-- * unary @-@ becomes @negate@ and unary @abs@ becomes @size@
--
-- Anything else is rebuilt unchanged.
herbieOpsToHaskellOps :: MathExpr -> MathExpr
herbieOpsToHaskellOps expr = case expr of
    EBinOp op lhs rhs -> EBinOp (binName op) (recur lhs) (recur rhs)
    EMonOp "sqr" arg  -> EBinOp "*" (recur arg) (recur arg)
    EMonOp op arg     -> EMonOp (monName op) (recur arg)
    EIf c t f         -> EIf (recur c) (recur t) (recur f)
    other             -> other
  where
    recur = herbieOpsToHaskellOps

    binName "^"    = "**"
    binName "expt" = "**"
    binName op     = op

    monName "-"   = "negate"
    monName "abs" = "size"
    monName op    = op
-- | Put an expression into canonical form.
--
-- Two things happen during a single left-to-right traversal:
--
-- * the operands of every operator in 'commutativeOpList' are ordered with
--   'min' / 'max' (using the 'Ord' instance of 'MathExpr');
-- * every leaf is renamed to @herbie0@, @herbie1@, ... in order of first
--   appearance.
--
-- The result is the rewritten expression together with the renaming that was
-- applied, as @(original name, canonical name)@ pairs with the most recently
-- introduced name first.  'fromCanonicalMathExpr' undoes the renaming.
--
-- Conditionals ('EIf') are traversed condition first, then the two branches;
-- previously they caused a pattern-match failure.
toCanonicalMathExpr :: MathExpr -> (MathExpr,[(String,String)])
toCanonicalMathExpr e = go [] e
  where
    -- The first argument is the renaming collected so far; it is threaded
    -- through the traversal so a repeated leaf reuses its canonical name.
    go :: [(String,String)] -> MathExpr -> (MathExpr,[(String,String)])
    go acc (EBinOp op e1 e2) = (EBinOp op e1' e2', acc2)
      where
        -- Operand order is fixed before renaming, so it is decided on the
        -- original (un-renamed) sub-expressions.
        (lo, hi) = if op `elem` commutativeOpList
            then (min e1 e2, max e1 e2)
            else (e1, e2)
        (e1', acc1) = go acc  lo
        (e2', acc2) = go acc1 hi

    go acc (EMonOp op e1) = (EMonOp op e1', acc1)
      where
        (e1', acc1) = go acc e1

    go acc (EIf cond e1 e2) = (EIf cond' e1' e2', acc3)
      where
        (cond', acc1) = go acc  cond
        (e1',   acc2) = go acc1 e1
        (e2',   acc3) = go acc2 e2

    go acc (ELit r) = (ELit r, acc)

    go acc (ELeaf str) = case lookup str acc of
        Just known -> (ELeaf known, acc)
        Nothing    -> (ELeaf fresh, (str, fresh) : acc)
      where
        -- The next unused canonical name is determined by how many leaves
        -- have been renamed so far.
        fresh = "herbie" ++ show (length acc)
-- | Undo the leaf renaming of a canonical expression.
--
-- The second component of the argument is a list of
-- @(original name, canonical name)@ pairs; every leaf of the expression is
-- looked up by its canonical name and replaced by the original one.  A leaf
-- without an entry is reported with 'error'.
--
-- A conditional whose test has the shape @(< _ -inf.0)@ is replaced by its
-- else-branch; all other nodes are rebuilt as they are.
fromCanonicalMathExpr :: (MathExpr,[(String,String)]) -> MathExpr
fromCanonicalMathExpr (e,xs) = restore e
  where
    -- Lookup table keyed by canonical name.
    byCanonical = [ (canon, orig) | (orig, canon) <- xs ]

    restore expr = case expr of
        EMonOp op arg                             -> EMonOp op (restore arg)
        EBinOp op lhs rhs                         -> EBinOp op (restore lhs) (restore rhs)
        EIf (EBinOp "<" _ (ELeaf "-inf.0")) _ alt -> restore alt
        EIf c t f                                 -> EIf (restore c) (restore t) (restore f)
        ELit r                                    -> ELit r
        ELeaf str                                 -> maybe (unknownLeaf str) ELeaf (lookup str byCanonical)

    unknownLeaf str =
        error $ "fromCanonicalMathExpr: str="++str++"; xs="++show byCanonical
-- | Nesting depth of unary and binary operators in an expression.
--
-- Each 'EBinOp' or 'EMonOp' adds one level on top of its deepest operand.
-- Every other node, including 'EIf', counts as depth 0 and is not descended
-- into.
mathExprDepth :: MathExpr -> Int
mathExprDepth expr = case expr of
    EBinOp _ lhs rhs -> 1 + max (mathExprDepth lhs) (mathExprDepth rhs)
    EMonOp _ arg     -> 1 + mathExprDepth arg
    _                -> 0
-- | Canonicalise an expression with 'toCanonicalMathExpr' and render the
-- result as a lisp string.
--
-- Returns the rendered string together with the variable renaming produced
-- by the canonicalisation.
getCanonicalLispCmd :: MathExpr -> (String,[(String,String)])
getCanonicalLispCmd me = (mathExpr2lisp (fst canonical), snd canonical)
  where
    canonical = toCanonicalMathExpr me
-- | Inverse direction of 'getCanonicalLispCmd': parse the lisp string with
-- 'lisp2mathExpr' and restore the original leaf names from the given
-- renaming via 'fromCanonicalMathExpr'.
fromCanonicalLispCmd :: (String,[(String,String)]) -> MathExpr
fromCanonicalLispCmd (lisp,varmap) = fromCanonicalMathExpr (parsed, varmap)
  where
    parsed = lisp2mathExpr lisp
-- | Render an expression as a lisp-style s-expression string.
--
-- Operators become parenthesised prefix forms, conditionals become
-- @(if cond then else)@, leaves are emitted by name.  A literal is printed as
-- an 'Integer' when it is a whole number and as a 'Double' otherwise.
mathExpr2lisp :: MathExpr -> String
mathExpr2lisp expr = case expr of
    EBinOp op lhs rhs -> form [op, mathExpr2lisp lhs, mathExpr2lisp rhs]
    EMonOp op arg     -> form [op, mathExpr2lisp arg]
    EIf c t f         -> form ["if", mathExpr2lisp c, mathExpr2lisp t, mathExpr2lisp f]
    ELeaf name        -> name
    ELit r            -> literal r
  where
    -- Space-separated parts wrapped in one pair of parentheses.
    form parts = "(" ++ unwords parts ++ ")"

    literal r
        | toRational whole == r = show whole
        | otherwise             = show (fromRational r :: Double)
      where
        whole = floor r :: Integer
-- | Parse a lisp-style string into a 'MathExpr'.
--
-- * A leading @-@ yields @negate@ applied to the parse of the rest.
-- * A parenthesised form is split with 'groupByParens': two parts give a
--   unary operation, three a binary operation, and four parts starting with
--   @if@ a conditional.  Any other shape, or a form that does not end in
--   @)@, is reported with 'error'.
-- * An atom of the shape @num/den@ is parsed as the division form
--   @(/ num den)@; otherwise an atom that reads as a 'Double' becomes a
--   literal, and anything else becomes a leaf.
lisp2mathExpr :: String -> MathExpr
lisp2mathExpr input = case input of
    '-':rest -> EMonOp "negate" (lisp2mathExpr rest)
    '(':rest
        | length rest > 1 && last rest == ')' -> parseForm rest
        | otherwise -> error $ "lisp2mathExpr: malformed input '("++rest++"'"
    _ -> parseAtom input
  where
    -- @body@ is the text after the opening parenthesis, still including the
    -- closing one.
    parseForm body = case groupByParens (init body) of
        [op,arg]         -> EMonOp op (lisp2mathExpr arg)
        [op,lhs,rhs]     -> EBinOp op (lisp2mathExpr lhs) (lisp2mathExpr rhs)
        ["if",cond,t,f]  -> EIf (lisp2mathExpr cond) (lisp2mathExpr t) (lisp2mathExpr f)
        _                -> error $ "lisp2mathExpr: "++body

    parseAtom atom = case splitOn "/" atom of
        [num,den] -> lisp2mathExpr $ "(/ "++num++" "++den++")"
        _         -> maybe (ELeaf atom) (ELit . toRational) (readMaybe atom :: Maybe Double)
-- | The distinct variable names occurring in a lisp string: the result of
-- 'lisp2varsNoNub' with duplicates removed.
lisp2vars :: String -> [String]
lisp2vars lisp = nub (lisp2varsNoNub lisp)
-- | All variable occurrences in a lisp string, sorted, duplicates kept.
--
-- The string is split into tokens on whitespace and parentheses.  A token
-- counts as a variable unless it is a parenthesis, is listed in 'binOpList'
-- or 'monOpList', or starts with a decimal digit.
lisp2varsNoNub :: String -> [String]
lisp2varsNoNub lisp = sort [ tok | tok <- tokens, isVariable tok ]
  where
    -- Parentheses are padded with spaces so that 'words' isolates them.
    tokens :: [String]
    tokens = words (concatMap padParen lisp)

    padParen '(' = " ( "
    padParen ')' = " ) "
    padParen c   = [c]

    isVariable tok = not
        (  tok `elem` ["(", ")"]
        || tok `elem` binOpList
        || tok `elem` monOpList
        || any (`elem` ("1234567890"::String)) (take 1 tok)
        )
-- | Whether some variable occurs more than once in a lisp string, i.e.
-- whether removing duplicates shortens the list of variable occurrences.
lispHasRepeatVars :: String -> Bool
lispHasRepeatVars lisp = distinctCount /= occurrenceCount
  where
    distinctCount   = length (lisp2vars lisp)
    occurrenceCount = length (lisp2varsNoNub lisp)
-- | Parse a value from a string, returning 'Nothing' when the string is not
-- a valid representation.
--
-- The whole input has to be consumed: only whitespace may follow the parsed
-- value.  Previously any successfully parsed prefix was accepted, so for
-- example @"12abc"@ was read as the number 12.
readMaybe :: Read a => String -> Maybe a
readMaybe str = listToMaybe
    [ val
    | (val, rest) <- reads str
    -- 'lex' yields an empty token with empty remainder exactly when only
    -- whitespace is left.
    , ("", "") <- lex rest
    ]
-- | Split the inside of a lisp form into its top-level parts.
--
-- Parts are separated by spaces at nesting depth 0; a parenthesised
-- sub-form, including its parentheses and inner spaces, stays together as a
-- single part.  For example @"* (+ a b) c"@ yields
-- @["*", "(+ a b)", "c"]@.
--
-- Runs of spaces between parts and spaces at the end of the input do not
-- produce empty parts.  An unmatched @)@ at depth 0 terminates the current
-- part.
groupByParens :: String -> [String]
groupByParens = go (0 :: Int) [] []
  where
    -- Arguments: current nesting depth, the part being built (reversed),
    -- the finished parts (reversed), and the remaining input.
    go _ part done []         = reverse (flush part done)
    go 0 part done (' ':rest) = go 0 [] (flush part done) rest
    go 0 part done (')':rest) = go 0 [] (reverse part : done) rest
    -- Closing a nested form moves one level back up.
    go d part done (')':rest) = go (d - 1) (')' : part) done rest
    go d part done ('(':rest) = go (d + 1) ('(' : part) done rest
    go d part done (c:rest)   = go d (c : part) done rest

    -- Record the current part unless it is empty.
    flush []   done = done
    flush part done = reverse part : done