{-# LANGUAGE DeriveAnyClass,DeriveGeneric #-}
module Herbie.MathExpr
    where

import Control.DeepSeq
import Data.List
import Data.List.Split
import Data.Maybe
import GHC.Generics

import Debug.Trace
import Prelude
-- | Function form of @if@ \/ @then@ \/ @else@: yields the second argument
-- when the condition is 'True' and the third one when it is 'False'.
ifThenElse :: Bool -> a -> a -> a
ifThenElse cond onTrue onFalse =
    case cond of
        True  -> onTrue
        False -> onFalse

-------------------------------------------------------------------------------
-- constants that define valid math expressions

-- | Names of the unary operators.
-- Tokens listed here are not reported as variables by 'lisp2varsNoNub'.
monOpList :: [String]
monOpList = concat
    [ [ "cos",  "sin",  "tan"  ]
    , [ "acos", "asin", "atan" ]
    , [ "cosh", "sinh", "tanh" ]
    , [ "exp", "log", "sqrt", "abs", "size" ]
    ]

-- | Names of every binary operator: the non-commutative ones first,
-- followed by everything in 'commutativeOpList'.
binOpList :: [String]
binOpList = nonCommutativeOps ++ commutativeOpList
    where
        nonCommutativeOps = [ "^", "**", "^^", "/", "-", "expt" ]

-- | Binary operators whose operands may be reordered during canonicalisation.
commutativeOpList :: [String]
commutativeOpList = [ "*", "+" ] -- "max" and "min" are currently disabled

--------------------------------------------------------------------------------

-- | Stores the AST for a math expression in a generic form that requires no knowledge of Core syntax.
--
-- Operators are identified by name only; see 'monOpList' and 'binOpList'
-- for the names this module knows about.
data MathExpr
    = EBinOp String MathExpr MathExpr
    -- ^ Binary operator: the operator name followed by its first and second operand.
    | EMonOp String MathExpr
    -- ^ Unary operator: the operator name followed by its operand.
    | EIf MathExpr MathExpr MathExpr
    -- ^ Conditional, printed by 'mathExpr2lisp' as @(if cond e1 e2)@:
    -- the condition followed by the two branch expressions.
    | ELit Rational
    -- ^ Numeric literal, stored exactly as a 'Rational'.
    | ELeaf String
    -- ^ Leaf identified by a string; the canonicalisation functions treat
    -- these as variable names.
    deriving (Show,Eq,Generic,NFData)

-- | Structural ordering, used by 'toCanonicalMathExpr' (through 'min' and
-- 'max') to put the operands of commutative operators into a fixed order.
--
-- Nodes built with different constructors are ordered
-- @ELeaf < ELit < EMonOp < EBinOp < EIf@.
-- Nodes built with the same constructor are compared field by field from
-- left to right; the first field that differs decides the result.
--
-- Every 'ELeaf' compares 'EQ' to every other 'ELeaf', whatever its name, so
-- this ordering is coarser than the derived 'Eq' instance.
-- NOTE(review): presumably this keeps the ordering independent of variable
-- names -- confirm before relying on it.
--
-- Earlier versions returned 'EQ' for two 'EBinOp' nodes whose operators or
-- first operands differed, returned 'LT' for an 'EBinOp' compared against any
-- other constructor (breaking antisymmetry), and had no equation at all for
-- 'EIf' on the left-hand side.
instance Ord MathExpr where
    compare (ELeaf _) (ELeaf _) = EQ

    compare (ELit r1) (ELit r2) = compare r1 r2

    compare (EMonOp op1 e1) (EMonOp op2 e2) = mconcat
        [ compare op1 op2
        , compare e1 e2
        ]

    compare (EBinOp op1 e1a e1b) (EBinOp op2 e2a e2b) = mconcat
        [ compare op1 op2
        , compare e1a e2a
        , compare e1b e2b
        ]

    compare (EIf c1 t1 f1) (EIf c2 t2 f2) = mconcat
        [ compare c1 c2
        , compare t1 t2
        , compare f1 f2
        ]

    -- Different constructors: fall back to the fixed constructor ranking.
    compare x y = compare (rank x) (rank y)
        where
            rank :: MathExpr -> Int
            rank (ELeaf  _    ) = 0
            rank (ELit   _    ) = 1
            rank (EMonOp _ _  ) = 2
            rank (EBinOp _ _ _) = 3
            rank (EIf    _ _ _) = 4

-- | Rewrites every operator name in the expression from its Haskell spelling
-- to the spelling used by Herbie.
--
-- The binary operators @**@, @^^@ and @^@ all become @expt@, and the unary
-- operator @size@ becomes @abs@.  Every other name, and every literal or
-- leaf, is left untouched.
haskellOpsToHerbieOps :: MathExpr -> MathExpr
haskellOpsToHerbieOps expr =
    case expr of
        EBinOp op lhs rhs   -> EBinOp (herbieBinOp op) (haskellOpsToHerbieOps lhs) (haskellOpsToHerbieOps rhs)
        EMonOp op arg       -> EMonOp (herbieMonOp op) (haskellOpsToHerbieOps arg)
        EIf cond yes no     -> EIf (haskellOpsToHerbieOps cond) (haskellOpsToHerbieOps yes) (haskellOpsToHerbieOps no)
        other               -> other
    where
        herbieBinOp name
            | name `elem` ["**", "^^", "^"] = "expt"
            | otherwise                     = name

        herbieMonOp "size" = "abs"
        herbieMonOp name   = name

-- | Rewrites every operator name in the expression from the spelling used by
-- Herbie to its Haskell spelling.
--
-- * binary @^@ and @expt@ become @**@
-- * unary @sqr e@ becomes the product @e * e@
-- * unary @-@ becomes @negate@
-- * unary @abs@ becomes @size@
--
-- Every other name, and every literal or leaf, is left untouched.
herbieOpsToHaskellOps :: MathExpr -> MathExpr
herbieOpsToHaskellOps expr =
    case expr of
        EBinOp op lhs rhs   -> EBinOp (haskellBinOp op) (recur lhs) (recur rhs)
        EMonOp "sqr" arg    -> EBinOp "*" (recur arg) (recur arg)
        EMonOp op arg       -> EMonOp (haskellMonOp op) (recur arg)
        EIf cond yes no     -> EIf (recur cond) (recur yes) (recur no)
        other               -> other
    where
        recur = herbieOpsToHaskellOps

        haskellBinOp name
            | name `elem` ["^", "expt"] = "**"
            | otherwise                 = name

        haskellMonOp "-"   = "negate"
        haskellMonOp "abs" = "size"
        haskellMonOp name  = name

-- | Replace all the variables in the MathExpr with canonical names
-- (herbie0, herbie1, herbie2, ...) and reorder the operands of commutative
-- binary operations (those in 'commutativeOpList') using the 'Ord' instance.
-- This lets us more easily compare MathExpr's based on their structure.
--
-- The returned association list maps each original variable name to its
-- canonical name; it lets 'fromCanonicalMathExpr' convert the canonical
-- MathExpr back into the original.
--
-- Names are assigned in the order the variables are first met during a
-- left-to-right traversal that happens /after/ operands have been reordered.
-- 'EIf' nodes are traversed condition first, then both branches; earlier
-- versions had no equation for 'EIf' and failed with a pattern-match error.
toCanonicalMathExpr :: MathExpr -> (MathExpr,[(String,String)])
toCanonicalMathExpr e = go [] e
    where
        -- The first argument is the renaming built so far; it is threaded
        -- through every sub-expression so a variable keeps a single name.
        go :: [(String,String)] -> MathExpr -> (MathExpr,[(String,String)])
        go acc (EBinOp op e1 e2) = (EBinOp op e1' e2',acc2)
            where
                (lo,hi)
                    | op `elem` commutativeOpList = (min e1 e2,max e1 e2)
                    | otherwise                   = (e1,e2)

                (e1',acc1) = go acc  lo
                (e2',acc2) = go acc1 hi

        go acc (EMonOp op e1) = (EMonOp op e1',acc1)
            where
                (e1',acc1) = go acc e1

        go acc (EIf cond e1 e2) = (EIf cond' e1' e2',acc3)
            where
                (cond',acc1) = go acc  cond
                (e1'  ,acc2) = go acc1 e1
                (e2'  ,acc3) = go acc2 e2

        go acc (ELit r) = (ELit r,acc)

        go acc (ELeaf str) =
            case lookup str acc of
                Just name -> (ELeaf name,acc)
                Nothing   -> (ELeaf fresh,(str,fresh):acc)
            where
                -- The index of a new variable is the number of variables seen so far.
                fresh = "herbie"++show (length acc)

-- | Convert a canonical MathExpr into its original form.
--
-- The second component is the renaming produced by 'toCanonicalMathExpr'
-- (original name, canonical name); every leaf is looked up by its canonical
-- name and replaced with the original one.  A leaf that is not in the
-- renaming makes this function call 'error'.
--
-- FIXME:
-- A bug in Herbie causes it to sometimes output infinities,
-- which break this function and cause it to error.
-- As a workaround, a conditional whose test is @(< _ -inf.0)@ is replaced by
-- its second branch.  Negative infinity is recognised both as the leaf
-- @-inf.0@ and as @negate@ applied to the leaf @inf.0@: 'lisp2mathExpr'
-- turns a leading @-@ into @EMonOp "negate"@, so parsed Lisp only ever
-- contains the latter form.
fromCanonicalMathExpr :: (MathExpr,[(String,String)]) -> MathExpr
fromCanonicalMathExpr (e,xs) = go e
    where
        -- Inverted renaming: canonical name -> original name.
        xs' = map (\(a,b) -> (b,a)) xs

        isNegInf (ELeaf "-inf.0")                  = True
        isNegInf (EMonOp "negate" (ELeaf "inf.0")) = True
        isNegInf _                                 = False

        go (EMonOp op e1) = EMonOp op (go e1)
        go (EBinOp op e1 e2) = EBinOp op (go e1) (go e2)
        go (EIf (EBinOp "<" _ bound) _ e2)
            | isNegInf bound = go e2 -- FIXME: added due to bug above
        go (EIf cond e1 e2) = EIf (go cond) (go e1) (go e2)
        go (ELit r) = ELit r
        go (ELeaf str) = case lookup str xs' of
            Just x -> ELeaf x
            Nothing -> error $ "fromCanonicalMathExpr: str="++str++"; xs="++show xs'

-- | Calculates the maximum depth of the AST.
--
-- Literals and leaves have depth 0; every operator or conditional node adds
-- one to the depth of its deepest child.  Earlier versions reported depth 0
-- for any 'EIf' node, ignoring its sub-expressions.
mathExprDepth :: MathExpr -> Int
mathExprDepth (EBinOp _ e1 e2) = 1+max (mathExprDepth e1) (mathExprDepth e2)
mathExprDepth (EMonOp _ e1   ) = 1+mathExprDepth e1
mathExprDepth (EIf cond e1 e2) = 1+maximum [mathExprDepth cond, mathExprDepth e1, mathExprDepth e2]
mathExprDepth (ELit _        ) = 0
mathExprDepth (ELeaf _       ) = 0

--------------------------------------------------------------------------------
-- functions for manipulating math expressions in lisp form

-- | Canonicalises the expression with 'toCanonicalMathExpr' and prints the
-- result with 'mathExpr2lisp'.  The variable renaming is returned unchanged
-- next to the Lisp text.
getCanonicalLispCmd :: MathExpr -> (String,[(String,String)])
getCanonicalLispCmd me =
    let (canonical,renaming) = toCanonicalMathExpr me
    in  (mathExpr2lisp canonical,renaming)

-- | Parses the Lisp text with 'lisp2mathExpr' and then restores the original
-- variable names with 'fromCanonicalMathExpr'.
fromCanonicalLispCmd :: (String,[(String,String)]) -> MathExpr
fromCanonicalLispCmd (lisp,varmap) = fromCanonicalMathExpr (parsed,varmap)
    where
        parsed = lisp2mathExpr lisp

-- | Converts MathExpr into a lisp command suitable for passing to Herbie.
--
-- Operator and conditional nodes become parenthesised, space-separated
-- forms; leaves are printed as their name.  A literal with an integral value
-- is printed as an 'Integer', any other literal is converted to 'Double'
-- and printed with 'show'.
mathExpr2lisp :: MathExpr -> String
mathExpr2lisp expr =
    case expr of
        EBinOp op a1 a2 -> form [op, mathExpr2lisp a1, mathExpr2lisp a2]
        EMonOp op a     -> form [op, mathExpr2lisp a]
        EIf cond e1 e2  -> form ["if", mathExpr2lisp cond, mathExpr2lisp e1, mathExpr2lisp e2]
        ELeaf name      -> name
        ELit r
            | toRational whole == r -> show whole
            | otherwise             -> show (fromRational r :: Double)
            where
                whole = floor r :: Integer
    where
        -- Wrap the space-separated parts in one pair of parentheses.
        form parts = "(" ++ unwords parts ++ ")"

-- | Converts a lisp command into a MathExpr.
--
-- * Input starting with @-@ is parsed as @negate@ applied to the rest.
-- * Input starting with @(@ must end with @)@; its contents are split with
--   'groupByParens' and must be a unary form, a binary form, or an @if@
--   form with three arguments.  Anything else calls 'error'.
-- * Any other input containing exactly one @/@ is parsed as a division of
--   the two sides; otherwise it is a literal if it reads as a 'Double' and
--   a leaf if it does not.
lisp2mathExpr :: String -> MathExpr
lisp2mathExpr str =
    case str of
        '-':rest -> EMonOp "negate" (lisp2mathExpr rest)
        '(':rest
            | isClosed rest -> parseForm rest
            | otherwise     -> error $ "lisp2mathExpr: malformed input '("++rest++"'"
        _ -> parseAtom str
    where
        isClosed body = length body > 1 && last body==')'

        -- @body@ is everything after the opening paren, closing paren included.
        parseForm body =
            case groupByParens (init body) of
                [op,arg]            -> EMonOp op (lisp2mathExpr arg)
                [op,lhs,rhs]        -> EBinOp op (lisp2mathExpr lhs) (lisp2mathExpr rhs)
                ["if",cond,yes,no]  -> EIf (lisp2mathExpr cond) (lisp2mathExpr yes) (lisp2mathExpr no)
                _                   -> error $ "lisp2mathExpr: "++body

        parseAtom tok =
            case splitOn "/" tok of
                [num,den] -> lisp2mathExpr $ "(/ "++num++" "++den++")"
                _         -> maybe (ELeaf tok) (ELit . toRational) (readMaybe tok :: Maybe Double)

-- | Extracts the variables of a lisp command, each listed exactly once.
-- This is 'lisp2varsNoNub' with the duplicates removed.
lisp2vars :: String -> [String]
lisp2vars lisp = nub (lisp2varsNoNub lisp)

-- | Extracts all the variables from the lisp commands.
-- Each variable occurs once in the output for each time it occurs in the input.
-- The result is sorted.
--
-- A token counts as a variable unless it is a parenthesis, a name from
-- 'binOpList' or 'monOpList', or a number.  Numbers are recognised by a
-- leading digit, optionally preceded by a minus sign; earlier versions
-- reported negative literals such as @-3@ (which 'mathExpr2lisp' can print)
-- as variables.
lisp2varsNoNub :: String -> [String]
lisp2varsNoNub = sort . filter isVariable . tokenize
    where
        isVariable :: String -> Bool
        isVariable tok
            =  tok `notElem` ["(",")"]
            && tok `notElem` binOpList
            && tok `notElem` monOpList
            && not (isNumber tok)

        isNumber :: String -> Bool
        isNumber (c:_)
            | isDigitChar c = True
        isNumber ('-':c:_)
            | isDigitChar c = True
        isNumber _ = False

        isDigitChar :: Char -> Bool
        isDigitChar c = c `elem` ("1234567890"::String)

        -- We just need to add spaces around the parens before calling "words"
        tokenize :: String -> [String]
        tokenize = words . concatMap pad
            where
                pad '(' = " ( "
                pad ')' = " ) "
                pad c   = [c]

-- | 'True' when at least one variable occurs more than once in the lisp
-- command, i.e. when removing duplicates shortens the variable list.
lispHasRepeatVars :: String -> Bool
lispHasRepeatVars lisp = distinctCount /= totalCount
    where
        distinctCount = length (lisp2vars lisp)
        totalCount    = length (lisp2varsNoNub lisp)

-------------------------------------------------------------------------------
-- utilities

-- | Parses a value, returning 'Nothing' when the input cannot be read.
--
-- The whole input has to be consumed: only whitespace may follow the value.
-- Earlier versions accepted the first parse even when it left unread input
-- behind, so @"12abc"@ was read as @12@.
readMaybe :: Read a => String -> Maybe a
readMaybe str = listToMaybe
    [ val
    | (val,rest) <- reads str
    , ("","")    <- lex rest   -- succeeds only when nothing but whitespace is left
    ]

-- | Given an expression, break it into tokens only outside parentheses.
--
-- Tokens are separated by spaces at nesting depth zero; a parenthesised
-- group, including its parentheses and inner spaces, is kept as one token.
-- Runs of spaces between tokens and trailing spaces produce no empty tokens
-- (earlier versions appended a spurious @""@ for trailing whitespace and for
-- empty input).  An unmatched @)@ at depth zero ends the current token.
groupByParens :: String -> [String]
groupByParens = go (0 :: Int) [] []
    where
        -- Arguments: nesting depth, current token (reversed), finished
        -- tokens (reversed), remaining input.  Both accumulators are built
        -- by consing and reversed once, which keeps the scan linear.
        go 0 []  done (' ':rest) = go 0     []        done                 rest
        go 0 tok done (' ':rest) = go 0     []        (reverse tok : done) rest
        go 0 tok done (')':rest) = go 0     []        (reverse tok : done) rest
        go d tok done (')':rest) = go (d-1) (')':tok) done                 rest
        go d tok done ('(':rest) = go (d+1) ('(':tok) done                 rest
        go d tok done (c  :rest) = go d     (c:tok)   done                 rest
        go _ []  done []         = reverse done
        go _ tok done []         = reverse (reverse tok : done)