{-# LANGUAGE DeriveDataTypeable #-}
-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-----------------------------------------------------------------------------

module Domain.Math.Simplification
   ( Simplify(..), SimplifyConfig(..)
   , simplifyConfig
   , Simplified, simplified, liftS, liftS2
   , simplifyRule
   , collectLikeTerms, mergeAlike, distribution, constantFolding
   , mergeAlikeSum, mergeAlikeProduct
   ) where

import Control.Monad
import Data.List
import Data.Maybe
import Data.Typeable
import Domain.Math.CleanUp (smart)
import Domain.Math.Data.Relation
import Domain.Math.Expr
import Domain.Math.Numeric.Views
import Domain.Math.SquareRoot.Views
import Ideas.Common.Library hiding (simplify, simplifyWith, (.*.), (./.))
import Ideas.Utils.Uniplate
import qualified Ideas.Common.View as View

-- | Switches for the passes that 'simplifyWith' runs on an 'Expr'.
-- Each field enables exactly one pass when 'True'.
data SimplifyConfig = SimplifyConfig
   { -- | Rebuild the expression with the smart constructors ('smart').
     withSmartConstructors :: Bool
     -- | Merge alike terms in sums and factors in products ('mergeAlike').
   , withMergeAlike :: Bool
     -- | Distribute rational constants over sums ('distribution').
   , withDistribution :: Bool
     -- | Simplify through the square-root view over rationals.
   , withSimplifySquareRoot :: Bool
     -- | Replace subterms that denote a rational number by that number.
   , withConstantFolding :: Bool
   }

-- | Values that can be brought into a simpler form.  Instances only need
-- to provide 'simplifyWith'; 'simplify' then uses 'simplifyConfig'.
class Simplify a where
   {-# MINIMAL simplifyWith #-}
   -- | Simplify using an explicit configuration.
   simplifyWith :: SimplifyConfig -> a -> a
   -- | Simplify using the default configuration 'simplifyConfig'.
   simplify :: a -> a
   simplify x = simplifyWith simplifyConfig x

-- | The default configuration: every simplification pass is enabled.
simplifyConfig :: SimplifyConfig
simplifyConfig = SimplifyConfig
   { withSmartConstructors  = True
   , withMergeAlike         = True
   , withDistribution       = True
   , withSimplifySquareRoot = True
   , withConstantFolding    = True
   }

-- | Simplify the value held in a context, by way of 'changeInContext'.
instance Simplify a => Simplify (Context a) where
   simplifyWith = changeInContext . simplifyWith

-- | Simplify every value inside an equation, via its 'Functor' instance.
instance Simplify a => Simplify (Equation a) where
   simplifyWith = fmap . simplifyWith

-- | Simplify every value inside a relation, via its 'Functor' instance.
instance Simplify a => Simplify (Relation a) where
   simplifyWith = fmap . simplifyWith

-- | Simplify each element of a list independently.
instance Simplify a => Simplify [a] where
   simplifyWith = map . simplifyWith

-- | Expressions are simplified by a pipeline of optional passes, each
-- guarded by one 'SimplifyConfig' flag.  The table is listed outermost
-- first, so the LAST enabled entry runs first: constant folding, then
-- square roots, distribution, merging alike terms, and finally the smart
-- constructors.
instance Simplify Expr where
   simplifyWith cfg =
      foldr (.) id [ pass | (enabled, pass) <- passes, enabled cfg ]
    where
      passes =
         [ (withSmartConstructors , transform smart)
         , (withMergeAlike        , mergeAlike)
         , (withDistribution      , distribution)
         , (withSimplifySquareRoot, View.simplify (squareRootViewWith rationalView))
         , (withConstantFolding   , constantFolding)
         ]

-- | A rule is simplified by post-processing: 'doAfter' attaches the
-- simplification so that it is applied afterwards (the default behavior).
instance Simplify a => Simplify (Rule a) where
   simplifyWith = doAfter . simplifyWith

-- | A value kept in simplified form.  In this module every construction
-- goes through 'simplified' (directly or via 'liftS' / 'liftS2' and the
-- numeric, 'Read' and 'IsTerm' instances), which runs 'simplify' on the
-- payload.  The constructor is not exported.
--
-- Declared as a @newtype@: the wrapper is erased at runtime, so it adds no
-- extra box or thunk around the payload.
newtype Simplified a = S a deriving (Eq, Ord, Typeable)

-- | A simplified value is shown exactly like the value it wraps.
--
-- 'showsPrec' is defined (not just 'show') so that the surrounding
-- precedence is forwarded to the wrapped value.  With only 'show', the
-- default 'showsPrec' ignores the context, and a wrapped negative number
-- inside 'Just' would print as @Just -1@ instead of @Just (-1)@.
-- Top-level 'show' output is the same as before.
instance Show a => Show (Simplified a) where
   showsPrec d (S x) = showsPrec d x

-- | Parse the wrapped value and simplify every parse result.
instance (Read a, Simplify a) => Read (Simplified a) where
   readsPrec d input =
      [ (simplified value, rest) | (value, rest) <- readsPrec d input ]

-- | Arithmetic on the wrapped values; each result is re-simplified.
instance (Num a, Simplify a) => Num (Simplified a) where
   fromInteger n = simplified (fromInteger n)
   x + y         = liftS2 (+) x y
   x - y         = liftS2 (-) x y
   x * y         = liftS2 (*) x y
   negate x      = liftS negate x
   abs x         = liftS abs x
   signum x      = liftS signum x

-- | Division on the wrapped values; each result is re-simplified.
instance (Fractional a, Simplify a) => Fractional (Simplified a) where
   fromRational q = simplified (fromRational q)
   recip x        = liftS recip x
   x / y          = liftS2 (/) x y

-- | Transcendental functions on the wrapped values; each result is
-- re-simplified.
instance (Floating a, Simplify a) => Floating (Simplified a) where
   pi = simplified pi

   -- exponentials, logarithms, powers and roots
   exp     = liftS exp
   log     = liftS log
   sqrt    = liftS sqrt
   (**)    = liftS2 (**)
   logBase = liftS2 logBase

   -- circular functions and their inverses
   sin  = liftS sin
   cos  = liftS cos
   tan  = liftS tan
   asin = liftS asin
   acos = liftS acos
   atan = liftS atan

   -- hyperbolic functions and their inverses
   sinh  = liftS sinh
   cosh  = liftS cosh
   tanh  = liftS tanh
   asinh = liftS asinh
   acosh = liftS acosh
   atanh = liftS atanh

-- | Terms are converted through the wrapped type; values read back from a
-- term are simplified.
instance (Simplify a, IsTerm a) => IsTerm (Simplified a) where
   toTerm s      = case s of S inner -> toTerm inner
   fromTerm term = fmap simplified (fromTerm term)

-- | Uses the default 'Reference' methods; only the constraints are needed.
instance (Simplify a, Reference a) => Reference (Simplified a)

-- | Simplify a value (default configuration) and wrap the result.
simplified :: Simplify a => a -> Simplified a
simplified x = S (simplify x)

-- | Apply a unary function to the wrapped value and re-simplify the result.
liftS :: Simplify a => (a -> a) -> Simplified a -> Simplified a
liftS f s = case s of
   S x -> simplified (f x)

-- | Apply a binary function to two wrapped values and re-simplify the
-- result.
liftS2 :: Simplify a => (a -> a -> a) -> Simplified a -> Simplified a -> Simplified a
liftS2 op s t = case (s, t) of
   (S x, S y) -> simplified (op x y)

-- | A rule named @simplify@: the identity rule with simplification attached
-- through the 'Simplify' instance for rules.
simplifyRule :: Simplify a => Rule a
simplifyRule = simplify baseRule
 where
   baseRule = idRule "simplify"

-------------------------------------------------------------
-- Distribution of constants

-- | Distribute rational constants over sums, top-down through the whole
-- expression:
--
-- * @(x + y) * r@ and @r * (x + y)@ become @r*x + r*y@;
-- * a quotient whose numerator is a sum of at least two terms is split
--   into a sum of quotients.
--
-- The current node is rewritten first, then the children of the result
-- are processed recursively.
distribution :: Expr -> Expr
distribution = descend distribution . distributeHere
 where
   -- Rewrite the top node if a rule applies, otherwise keep it.
   distributeHere e = fromMaybe e (step e)

   step (a :*: b) =
      (do (x, y) <- match plusView a
          r      <- match rationalView b
          return (scaleBoth r x y))
      `mplus`
      (do r      <- match rationalView a
          (x, y) <- match plusView b
          return (scaleBoth r x y))
   step (a :/: b) =
      case match sumView a of
         Just xs@(_:_:_) -> Just (build sumView (map (./. b) xs))
         _               -> Nothing
   step _ = Nothing

   -- r * (x + y) expanded as r*x + r*y
   scaleBoth :: Rational -> Expr -> Expr -> Expr
   scaleBoth r x y = (fromRational r .*. x) .+. (fromRational r .*. y)

-------------------------------------------------------------
-- Constant folding

-- Not an efficient implementation: could be improved if necessary
-- | Replace every maximal subterm that denotes a rational number by that
-- number; other nodes are kept and their children are folded recursively.
constantFolding :: Expr -> Expr
constantFolding expr =
   maybe (descend constantFolding expr) fromRational (match rationalView expr)

----------------------------------------------------------------------
-- merge alike for sums and products

-- Todo: combine with mergeAlike (subtle differences)
collectLikeTerms :: Expr -> Expr
collectLikeTerms = View.simplifyWith f sumView
 where
   f = mergeAlikeSum . map (View.simplifyWith (second mergeAlikeProduct) productView)

-- | Recursively merge alike terms.
--
-- * A sum of two or more terms: merge alike summands, then sort them.
-- * Otherwise, a product with two or more factors different from 1: merge
--   the constant factors, then sort the factors.
-- * Anything else is returned unchanged.
mergeAlike :: Expr -> Expr
mergeAlike a
   | Just xs <- match sumView a
   , length xs > 1
   = build sumView (sort (mergeAlikeSum (map mergeAlike xs)))
   | Just (b, ys) <- match productView a
   , length (filter (/= 1) ys) > 1
   = build productView (b, sort (mergeAlikeProduct (map mergeAlike ys)))
   | otherwise
   = a

-- | Merge all rational factors of a product into one constant.
--
-- Non-constant factors keep their relative order.  The single combined
-- constant is placed where the first rational factor appeared.  A list
-- without rational factors is returned unchanged.
mergeAlikeProduct :: [Expr] -> [Expr]
mergeAlikeProduct ys =
   case break (isJust . fst) tagged of
      (plain, []) ->
         map snd plain
      (plain, fromFirstConst) ->
         map snd plain
         ++ ( build rationalView (product (mapMaybe fst fromFirstConst))
            : [ e | (Nothing, e) <- fromFirstConst ] )
 where
   -- each factor paired with its rational value, if it has one
   tagged = [ (match rationalView y, y) | y <- ys ]

-- | Collect like terms in a list of summands.
--
-- Each summand is split into a rational coefficient and a remaining factor
-- (see @coefficient@).  Summands that share the same remaining factor are
-- combined into @(sum of coefficients) .*. factor@, placed where the first
-- of them occurred.  A summand without a partner is returned unchanged, so
-- its original shape is preserved.
mergeAlikeSum :: [Expr] -> [Expr]
mergeAlikeSum xs = combine [ (coefficient 1 x, x) | x <- xs ]
 where
   -- Split an expression into (rational coefficient, remaining factor).
   -- Coefficients found are multiplied into the accumulator @r@.  A pure
   -- rational constant becomes (value, 1).
   coefficient :: Rational -> Expr -> (Rational, Expr)
   coefficient r (e1 :*: e2) =
      case (match rationalView e1, match rationalView e2) of
         (Just r1, _) -> coefficient (r * r1) e2
         (_, Just r1) -> coefficient (r * r1) e1
         _            -> (r, e1 .*. e2)
   coefficient r (Negate e) = coefficient (negate r) e
   coefficient r e =
      case match rationalView e of
         Just r1 -> (r * r1, Nat 1)
         Nothing -> (r, e)

   -- Walk the tagged summands from left to right.  Every later summand
   -- with the same remaining factor is pulled into the current group.
   combine [] = []
   combine (((r, factor), e) : ys) = merged : combine rest
    where
      (likes, rest) = partition ((== factor) . snd . fst) ys
      merged
         | null likes = e
         | otherwise  =
              build rationalView (sum (r : map (fst . fst) likes)) .*. factor