{-# LANGUAGE FlexibleInstances,FlexibleContexts,MultiWayIf,CPP #-}
module Herbie.MathInfo
    where

import Class
import DsBinds
import DsMonad
import ErrUtils
import GhcPlugins hiding (trace)
import Unique
import MkId
import PrelNames
import UniqSupply
import TcRnMonad
import TcSimplify
import Type

import Control.Monad
import Control.Monad.Except
import Control.Monad.Trans
import Data.Char
import Data.List
import Data.Maybe
import Data.Ratio

import Herbie.CoreManip
import Herbie.MathExpr

import Prelude
import Show

-- Tracing is switched off: the two definitions below stand in for the
-- "Debug.Trace" functions of the same name and do nothing with the message.
-- To get real tracing back, remove them and restore this import:
-- import Debug.Trace hiding (traceM)

-- | Silent stand-in for @trace@: ignores the message and yields the value unchanged.
trace :: a -> b -> b
trace _ result = result

-- | Silent stand-in for @traceM@: ignores the message and performs no action.
traceM :: Monad m => a -> m ()
traceM _ = return ()

--------------------------------------------------------------------------------

-- | The fields of this type correspond to the sections of a function type.
--
-- Must satisfy the invariant that every class in "getCxt" has an associated dictionary in "getDicts".
--
-- 'mkMathInfo' fills these fields in from the type of the enclosing binding,
-- the dictionary variables it is handed, and the result of 'varTypeIfValidExpr'.
data ParamType = ParamType
    { getQuantifier :: [Var]
        -- ^ the quantified variables; 'mkMathInfo' takes these from @extractQuantifiers@
    , getCxt        :: [Type]
        -- ^ the context; 'mkMathInfo' takes these from @extractContext@
    , getDicts      :: [CoreExpr]
        -- ^ the dictionaries; 'mkMathInfo' stores each dictionary variable wrapped in @Var@
    , getParam      :: Type
        -- ^ the type the math expression operates on, as reported by 'varTypeIfValidExpr'
    }

-- | This type is a simplified version of the CoreExpr type.
-- It only supports math expressions.
-- We first convert a CoreExpr into a MathInfo,
-- perform all the manipulation on the MathExpr within the MathInfo,
-- then use the information in MathInfo to convert the MathExpr back into a CoreExpr.
data MathInfo = MathInfo
    { getMathExpr   :: MathExpr
        -- ^ the math expression itself; its @ELeaf@ nodes are looked up by name in 'getExprs'
    , getParamType  :: ParamType
        -- ^ the parameter type and dictionaries that 'mathInfo2expr' reads when rebuilding Core
    , getExprs      :: [(String,Expr Var)]
        -- ^ the fst value is the unique name assigned to non-mathematical expressions
        -- the snd value is the expression
    }

-- | Pretty print a math expression
--
-- The expression stored in the given 'MathInfo' is rendered in infix notation.
-- Leaves are shown by name when they refer to a plain variable in 'getExprs'
-- and as @???@ otherwise.
--
-- Parentheses around a nested binary operation are only dropped when the
-- nested operator is the same as the enclosing one AND sits on the side the
-- operator associates to.  So @(a - b) - c@ prints as @a - b - c@, but
-- @a - (b - c)@ keeps its parentheses.  This matters even for @+@ and @*@,
-- because floating point arithmetic is not associative and the grouping is
-- exactly what distinguishes two rewrites of the same formula.
pprMathInfo :: MathInfo -> String
pprMathInfo mathInfo = go 1 False $ getMathExpr mathInfo
    where
        -- literals and leaves are atomic and never need parentheses
        isLitOrLeaf :: MathExpr -> Bool
        isLitOrLeaf (ELit _ ) = True
        isLitOrLeaf (ELeaf _) = True
        isLitOrLeaf _         = False

        -- operators the Prelude declares @infixr@;
        -- every other operator is treated as associating to the left
        rightAssocOps :: [String]
        rightAssocOps = ["^","^^","**"]

        -- @go i b e@ renders @e@.
        -- @i@ is the nesting depth used to indent if/then/else branches;
        -- @b@ requests surrounding parentheses (ignored for literals and leaves).
        go :: Int -> Bool -> MathExpr -> String
        go i b e = if b && not (isLitOrLeaf e)
            then "("++str++")"
            else str
            where
                str = case e of
                    EMonOp op e1 -> op++" "++ go i True e1

                    EBinOp op e1 e2 -> go i parens1 e1++" "++op++" "++go i parens2 e2
                        where
                            -- the left operand may go bare only under the same,
                            -- left-associative operator
                            parens1 = case e1 of
                                (EBinOp op' _ _) -> op/=op' || op `elem` rightAssocOps
                                _ -> True

                            -- the right operand may go bare only under the same,
                            -- right-associative operator
                            parens2 = case e2 of
                                (EBinOp op' _ _) -> op/=op' || op `notElem` rightAssocOps
                                _ -> True

                    -- whole numbers with fewer than 10 characters print as integers,
                    -- everything else prints as a Double
                    ELit l -> if toRational (floor l) == l
                        then if length (show (floor l :: Integer)) < 10
                            then show (floor l :: Integer)
                            else show (fromRational l :: Double)
                        else show (fromRational l :: Double)

                    -- only leaves that stand for a plain variable are shown by name
                    ELeaf l -> case lookup l $ getExprs mathInfo of
                        Just (Var _) -> l
                        _            -> "???"

                    EIf cond e1 e2 -> "if "++go i False cond++"\n"
                        ++white++"then "++go (i+1) False e1++"\n"
                        ++white++"else "++go (i+1) False e2
                        where
                            white = replicate (4*i) ' '

-- | Decides whether a Core expression is headed by a known math operator.
--
-- An expression qualifies when its outermost function is a variable named in
-- 'binOpList' that is applied to a type and three further arguments,
-- or a variable named in 'monOpList' that is applied to a type and two further arguments.
-- The type argument is returned if it is also an acceptable numeric type
-- (see @isValidType@ below); every other expression yields 'Nothing'.
varTypeIfValidExpr :: CoreExpr -> Maybe Type
varTypeIfValidExpr e = case e of

    -- operator, type and three more arguments: candidate binary operation
    App (App (App (App (Var v) (Type t)) _) _) _ -> typeWhenKnownOp binOpList v t

    -- operator, type and two more arguments: candidate unary operation
    App (App (App (Var v) (Type t)) _) _ -> typeWhenKnownOp monOpList v t

    -- any other shape is not a math expression
    _ -> Nothing

    where
        -- hands back the type only if the operator is in the given list
        -- and the type is acceptable
        typeWhenKnownOp :: [String] -> Var -> Type -> Maybe Type
        typeWhenKnownOp ops v t
            | var2str v `elem` ops && isValidType t = Just t
            | otherwise                             = Nothing

        -- type variables and types that are not a type constructor application
        -- are accepted; a type constructor application is accepted only for
        -- Float and Double
        isValidType :: Type -> Bool
        isValidType t
            | isTyVarTy t = True
            | otherwise   = case splitTyConApp_maybe t of
                Nothing        -> True
                Just (tyCon,_) -> tyCon `elem` [floatTyCon,doubleTyCon]

-- | Converts a CoreExpr into a MathInfo
--
-- Arguments, in order: the session's 'DynFlags' (used to render
-- non-mathematical sub-expressions into leaf names), the dictionary variables
-- in scope, the type of the enclosing binding, and the expression to convert.
--
-- Returns 'Nothing' when 'varTypeIfValidExpr' rejects the expression,
-- or when the converted expression does not have both a @mathExprDepth@
-- greater than 1 and repeated variables according to @lispHasRepeatVars@.
mkMathInfo :: DynFlags -> [Var] -> Type -> Expr Var -> Maybe MathInfo
mkMathInfo dflags dicts bndType e = case varTypeIfValidExpr e of
        Nothing -> Nothing
        Just t -> if mathExprDepth mathExpr>1 && lispHasRepeatVars (mathExpr2lisp mathExpr)
            then Just $ MathInfo
                mathExpr
                ParamType
                    { getQuantifier = quantifier
                    , getCxt = cxt
                    , getDicts = map Var dicts
                    , getParam = t
                    }
                leaves
            else Nothing

    where
        (mathExpr,leaves) = go e []

        -- split the binding's type into its quantified variables and its context;
        -- whatever remains after the context is not needed here
        (quantifier,unquantified) = extractQuantifiers bndType
        (cxt,_) = extractContext unquantified

        -- turns a non-mathematical expression into a leaf named after its
        -- rendered text, and records the name/expression pair in front of
        -- the pairs collected so far
        mkLeaf :: Expr Var
               -> [(String,Expr Var)]
               -> (MathExpr
                  ,[(String,Expr Var)]
                  )
        mkLeaf x acc = (ELeaf name,(name,x):acc)
            where
                name = expr2str dflags x

        -- recursively converts the `Expr Var` into a MathExpr and a dictionary;
        -- the second argument holds the dictionary entries collected so far
        -- and is always carried through to the result
        go :: Expr Var
           -> [(String,Expr Var)]
           -> (MathExpr
              ,[(String,Expr Var)]
              )

        -- we need to special case the $ operator for when MathExpr is run before any rewrite rules
        go x@(App (App (App (App (Var v) (Type _)) (Type _)) a1) a2) acc
            = if var2str v == "$"
                then go (App a1 a2) acc
                else mkLeaf x acc

        -- polymorphic literals created via fromInteger
        go (App (App (App (Var _) (Type _)) _) (Lit l)) acc
            = (ELit $ lit2rational l, acc)

        -- polymorphic literals created via fromRational
        go (App (App (App (Var _) (Type _)) _)
             (App (App (App (Var _) (Type _)) (Lit l1)) (Lit l2))) acc
            = (ELit $ lit2rational l1 / lit2rational l2, acc)

        -- non-polymorphic literals
        go (App (Var _) (Lit l)) acc
            = (ELit $ lit2rational l, acc)

        -- binary operators
        go x@(App (App (App (App (Var v) (Type _)) _) a1) a2) acc
            = if var2str v `elem` binOpList
                then let (a1',leaves1) = go a1 []
                         (a2',leaves2) = go a2 []
                     in ( EBinOp (var2str v) a1' a2'
                        , acc++leaves1++leaves2
                        )
                else mkLeaf x acc

        -- unary operators
        go x@(App (App (App (Var v) (Type _)) _) a) acc
            = if var2str v `elem` monOpList
                then let (a',leaves') = go a []
                     in ( EMonOp (var2str v) a'
                        , acc++leaves'
                        )
                else mkLeaf x acc

        -- everything else
        go x acc = mkLeaf x acc

-- | Converts a MathInfo back into a CoreExpr
--
-- The 'MathExpr' inside the given 'MathInfo' is walked recursively and a Core
-- expression is built for every node, using the parameter type and
-- dictionaries stored in its 'ParamType'.
--
-- Fails with a message in the 'ExceptT' layer when a leaf name has no entry in
-- 'getExprs'.  The helpers @getDecoratedFunction@ and @castToType@ run in
-- the same monad, so any failure they raise propagates as well.
mathInfo2expr :: ModGuts -> MathInfo -> ExceptT String CoreM CoreExpr
mathInfo2expr guts herbie = go (getMathExpr herbie)
    where
        pt = getParamType herbie

        -- binary operators
        go (EBinOp opstr a1 a2) = do
            a1' <- go a1
            a2' <- go a2
            f <- getDecoratedFunction guts opstr (getParam pt) (getDicts pt)
            return $ App (App f a1') a2'

        -- unary operators
        -- NOTE(review): unlike the binary case the result goes through
        -- castToType; presumably because a unary operator's result type can
        -- differ from the parameter type -- confirm against castToType
        go (EMonOp opstr a) = do
            a' <- go a
            f <- getDecoratedFunction guts opstr (getParam pt) (getDicts pt)
            castToType
                (getDicts pt)
                (getParam pt)
                $ App f a'

        -- if statements
        go (EIf cond a1 a2) = do
            cond' <- go cond >>= castToType (getDicts pt) boolTy
            a1' <- go a1
            a2' <- go a2

            -- a Core case expression needs a binder for the scrutinee,
            -- so make a fresh one of type Bool
            wildUniq <- getUniqueM
            let wildName = mkSystemName wildUniq (mkVarOcc "wild")
                wildVar = mkLocalVar VanillaId wildName boolTy vanillaIdInfo

            -- Core expects data alternatives in constructor order,
            -- hence False (the else branch) is listed before True
            return $ Case
                cond'
                wildVar
                (getParam pt)
                [ (DataAlt falseDataCon, [], a2')
                , (DataAlt trueDataCon, [], a1')
                ]

        -- leaf is a numeric literal
        -- builds: fromRational (ratioCon @Integer numerator denominator)
        go (ELit r) = do
            fromRationalExpr <- getDecoratedFunction guts "fromRational" (getParam pt) (getDicts pt)

            integerTyCon <- lookupTyCon integerTyConName
            let integerTy = mkTyConTy integerTyCon

            -- hand-build a variable for the Ratio data constructor,
            -- with type: forall a. a -> a -> Ratio a
            ratioTyCon <- lookupTyCon ratioTyConName
            tmpUniq <- getUniqueM
            let tmpName = mkSystemName tmpUniq (mkVarOcc "a")
                tmpVar = mkTyVar tmpName liftedTypeKind
                tmpVarT = mkTyVarTy tmpVar
                ratioConTy = mkForAllTy tmpVar $ mkFunTys [tmpVarT,tmpVarT] $ mkAppTy (mkTyConTy ratioTyCon) tmpVarT
                ratioConVar = mkGlobalVar VanillaId ratioDataConName ratioConTy vanillaIdInfo

            return $ App
                fromRationalExpr
                (App
                    (App
                        (App
                            (Var ratioConVar )
                            (Type integerTy)
                        )
                        (Lit $ LitInteger (numerator r) integerTy)
                    )
                    (Lit $ LitInteger (denominator r) integerTy)
                )

        -- leaf is any other expression
        -- a missing name is reported through ExceptT straight away, rather than
        -- being hidden as a lazy `error` inside the returned expression
        go (ELeaf str) = case lookup str (getExprs herbie) of
            Just x -> return x
            Nothing -> throwError $ "mathInfo2expr: var " ++ str ++ " not in scope"
                ++"; in scope vars="++show (nub $ map fst $ getExprs herbie)