{-# LANGUAGE PatternGuards #-}
-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-- This module defines several functions that produce some kind of normal form for an expression.
-- Not all of these functions are used, and they often do not normalize as far as we would like.
-- Note that this module does not contain every normalization function in the project:
-- some tools define their own, more specific normalizing functions.
--
-----------------------------------------------------------------------------

module Recognize.Expr.Normalform
   ( nf, nf4, nfComAssoc, nfCom
   , rewriteSqrt, distributeExponent
   , ceilingExpr, floorExpr
   , roundDouble, doubleRoundedView
   , (===)
   ) where

import Util.Cache
import Data.Function
import Data.List
import Domain.Algebra.SmartGroup
import Domain.Math.Expr
import Domain.Math.Numeric.Views
import Domain.Math.Polynomial.Views
import Ideas.Common.Id
import Ideas.Common.Rewriting
import Ideas.Common.View as IV
import Ideas.Utils.Prelude
import Ideas.Utils.Uniplate

-- | Collapse one level of nested powers: an expression @(y ^ j) ^ i@, where
-- both symbols satisfy 'isPowerSymbol', becomes @y ^ (i .*. j)@, reusing the
-- outer power symbol. Only the outermost node is inspected (the rewrite is not
-- applied recursively); any other expression is returned unchanged.
--
-- NOTE(review): @(y^j)^i = y^(i*j)@ does not hold for every real @y@ with
-- non-integer exponents (e.g. @((-1)^2)^(1/2) = 1@ but @(-1)^1 = -1@);
-- presumably callers accept this.
distributeExponent :: Expr -> Expr
distributeExponent (Sym s1 [Sym s2 [y, j], i])
   | isPowerSymbol s1, isPowerSymbol s2 = Sym s1 [y, i .*. j]
distributeExponent e = e

-- | Express a square root as a power with exponent one half: @Sqrt e@ becomes
-- a 'powerSymbol' application to @e@ and @1/2@. Only the outermost node is
-- inspected; any other expression passes through unchanged.
rewriteSqrt :: Expr -> Expr
rewriteSqrt expr = case expr of
   Sqrt radicand -> Sym powerSymbol [radicand, 1 / 2]
   _             -> expr

-- | Equality under normalisation: both operands are normalised with 'nf2'
-- and the results are compared with '=='.
(===) :: Expr -> Expr -> Bool
(===) = (==) `on` nf2

-- | Normal form for expressions. A symbol application keeps its symbol and has
-- every argument normalised recursively with 'nf'. Any other expression is
-- rewritten bottom-up ('transform') by simplifying through
-- @'polyViewWith' 'rationalApproxView'@. The function is wrapped in 'cached'
-- (from "Util.Cache") under the key @"nf"@.
--
-- Note that the Ord instance for Expr is derived, so it orders expressions by
-- their structure rather than by their value: for instance @1:+:1 < 1:-:1@,
-- even though numerically 2 > 0. Comparing expressions with Ord after 'nf' is
-- therefore not reliable. For example, @3/2@ cannot be normalized any further,
-- yet it may compare as greater than @2@ under the derived instance, although
-- 3/2 < 2.
nf :: Expr -> Expr
nf = cached "nf" nfNode
  where
    nfNode (Sym s args) = Sym s (map nf args)
    -- Square roots deliberately get no clause of their own: a separate
    -- @Sqrt (nf e)@ case would prevent actual simplification of square roots.
    nfNode other        = transform (simplify (polyViewWith rationalApproxView)) other

-- | Lightweight normalisation: symbol applications are traversed argument by
-- argument; any other expression is simplified through 'rationalApproxView'.
nf2 :: Expr -> Expr
nf2 expr = case expr of
   Sym s args -> Sym s (map nf2 args)
   _          -> simplify rationalApproxView expr

-- | Simplifies with a certain precision. Symbol applications are traversed
-- argument by argument. Any other expression is simplified through a view that
-- rounds its numeric value to @n@ decimal places
-- (@'doubleRoundedView' ('roundDouble' n)@).
nf3 :: Int -> Expr -> Expr
nf3 n = go
  where
    go (Sym s args) = Sym s (map go args)
    go e            = simplify (doubleRoundedView (roundDouble n)) e

-- | Normalise an expression. When 'hasSomeVar' reports a variable, the
-- expression is normalised with 'nf' and then 'nfComAssoc'. Otherwise it is
-- simplified numerically with @n@ decimal places of precision via 'nf3'.
nf4 :: Int -> Expr -> Expr
nf4 n e =
   if hasSomeVar e
      then nfComAssoc (nf e)
      else nf3 n e

-- | A view from expressions to doubles named @"num.double.rounded"@. It first
-- matches through 'doubleView', then applies the given rounding function to
-- the extracted number. The second stage's build function is 'id', so
-- building goes through 'doubleView' only.
--
-- The parameter is called @rounder@ rather than @round@ so that it does not
-- shadow Prelude's 'round'.
doubleRoundedView :: (Double -> Double) -> View Expr Double
doubleRoundedView rounder =
   "num.double.rounded" @> doubleView >>> makeView (Just . rounder) id

-- | Replace a floating-point 'Number' by its ceiling as an integer expression.
-- Any other expression is returned unchanged (only the outermost node is
-- inspected).
--
-- A negative ceiling is built as @Negate (Nat |n|)@ instead of putting a
-- negative Integer inside 'Nat'; non-negative results are still @Nat n@.
-- NOTE(review): this assumes 'Nat' is meant for non-negative integers only;
-- confirm against Domain.Math.Expr.
ceilingExpr :: Expr -> Expr
ceilingExpr (Number d) = integerExpr (ceiling d)
  where
    integerExpr n
       | n < 0     = Negate (Nat (negate n))
       | otherwise = Nat n
ceilingExpr e = e

-- | Replace a floating-point 'Number' by its floor as an integer expression.
-- Any other expression is returned unchanged (only the outermost node is
-- inspected).
--
-- A negative floor is built as @Negate (Nat |n|)@ instead of putting a
-- negative Integer inside 'Nat'; non-negative results are still @Nat n@.
-- NOTE(review): this assumes 'Nat' is meant for non-negative integers only;
-- confirm against Domain.Math.Expr.
floorExpr :: Expr -> Expr
floorExpr (Number d) = integerExpr (floor d)
  where
    integerExpr n
       | n < 0     = Negate (Nat (negate n))
       | otherwise = Nat n
floorExpr e = e

-- | Normal form for associativity (and ordering) of sums and products.
--
-- * A 'Number' is normalised with 'nf'.
-- * If 'sumView' yields at least two terms, each term is normalised
--   recursively, the terms are stably sorted by their 'nf', and the sum is
--   rebuilt.
-- * Otherwise, if 'productView' yields at least two factors, the factors are
--   treated the same way. The first component of the view's result is kept
--   as-is (presumably a sign flag; verify against
--   Domain.Math.Polynomial.Views).
-- * Any other expression is handled by recursing into its immediate children.
--
-- The function is wrapped in 'cached' (from "Util.Cache") under the key
-- @"nfComAssoc"@.
nfComAssoc :: Expr -> Expr
nfComAssoc = cached "nfComAssoc" $ \expr ->
   case expr of
      -- Rewrites a Number to a division
      Number _ -> nf expr
      _ ->
         case (from sumView expr, from productView expr) of
            -- the pattern (_:_:_) checks for >= 2 elements without
            -- traversing the whole list, unlike `length xs > 1`
            (terms@(_:_:_), _) ->
               to sumView (normaliseAndSort terms)
            (_, (b, factors@(_:_:_))) ->
               to productView (b, normaliseAndSort factors)
            _ ->
               descend nfComAssoc expr
  where
    -- Normalise every operand, then stably order them by their 'nf'.
    -- 'sortOn' computes each key once instead of once per comparison.
    normaliseAndSort :: [Expr] -> [Expr]
    normaliseAndSort = sortOn nf . map nfComAssoc

-- | Normal form for commutativity of @+@ and @*@. The function is wrapped in
-- 'cached' (from "Util.Cache") under the key @"nfCom"@.
--
-- * A sum is flattened into its terms; subtractions and negations become
--   'neg' applied to the affected terms. The terms are normalised, sorted with
--   the derived 'Ord' instance, and added up again with 'sum'.
-- * For a binary product both factors are normalised and placed in ascending
--   order. Nested products are not flattened, e.g. @3 * 2 * 1@ becomes
--   @2 * 3 * 1@. Whether factors should be collected like terms is an open
--   question.
-- * A 'Number' is simplified through 'rationalApproxView'.
-- * Everything else is handled by recursing into its immediate children.
nfCom :: Expr -> Expr
nfCom = cached "nfCom" nfComNode
  where
    nfComNode expr@(_ :+: _) = sum (sort (map nfCom (summands expr)))
    nfComNode (x :*: y) =
       let x' = nfCom x
           y' = nfCom y
       in if x' <= y' then x' :*: y' else y' :*: x'
    nfComNode expr@(Number _) = simplify rationalApproxView expr
    nfComNode expr = descend nfCom expr

    -- Flatten +, - and unary negation into a list of (possibly negated) terms.
    summands (a :+: b)  = summands a ++ summands b
    summands (a :-: b)  = summands a ++ map neg (summands b)
    summands (Negate a) = map neg (summands a)
    summands a          = [a]

-- | Round a double to @n@ decimal places.
--
-- The value is scaled by @10^n@, rounded to a whole number with
-- 'roundNearest', and scaled back. Prelude's 'round' sends ties to the even
-- neighbour, whereas 'roundNearest' rounds a fractional part of one half
-- upwards, so @roundDouble 0 2.5 == 3.0@.
--
-- @n@ must be non-negative, because 'Prelude.^' rejects negative exponents.
-- The scaling happens in binary floating point, so values that look like exact
-- decimal ties (e.g. @1.005@ with @n = 2@) may not round up.
roundDouble :: Int -> Double -> Double
roundDouble n d = fromInteger (roundNearest (d * scale)) / scale
  where
    scale = 10 Prelude.^ n

-- | Round to the nearest integer. Ties (a fractional part of exactly one half)
-- are rounded upwards, i.e. towards positive infinity: @2.5 -> 3@,
-- @-2.5 -> -2@.
--
-- Prelude's 'round' instead uses banker's rounding: ties go to the even
-- neighbour, so @round 2.5 == 2@.
--
-- 'properFraction' truncates towards zero, so for a negative argument the
-- fractional part @r@ is negative. The second guard handles that case; without
-- it, @-2.7@ would be rounded to @-2@.
roundNearest :: (RealFrac a, Integral b) => a -> b
roundNearest a
   | r >= 0.5  = n + 1
   | r < -0.5  = n - 1
   | otherwise = n
  where
    (n, r) = properFraction a