-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-- Defines a set of Ideas rules that may be used for rewriting expressions.
--
-----------------------------------------------------------------------------

module Recognize.Strategy.Rules
( liftRule
, removeTimes
, distributeDivision
, mergeNums
, mergeVars) where

import Control.Monad
import Data.List
import Data.Maybe
import Domain.Math.Expr             hiding (sumView)
import Domain.Math.Data.Relation
import Domain.Math.Numeric.Views
import Domain.Math.Polynomial.Rules
import Domain.Math.Simplification   hiding (mergeAlikeSum)
import Ideas.Common.Library         hiding ((.*.))
import Recognize.Expr.Functions
import Recognize.Expr.Normalform
import Recognize.Strategy.Views

-- | Turns a rule on equations into a rule on relations.
-- The two sides of the relation are treated as an equation, the rule is
-- applied to that equation, and the result is rebuilt with the relation
-- type of the input. The lifted rule keeps the identifier of @r@.
liftRule :: Rule (Equation Expr) -> Rule (Relation Expr)
liftRule r = makeRule (getId r) onRelation
 where
   onRelation rel =
      apply r (leftHandSide rel :==: rightHandSide rel) >>= \(lhs :==: rhs) ->
         return (makeType (relationType rel) lhs rhs)

-- | Allows a rewrite that attempts to remove multiplications by dividing
-- both sides of a relation by a common factor.
--
-- For example, @2 * 2 + 4 * a = 8@ can be rewritten to @2 + 2 * a = 4@
-- (this is one of the alternatives: division by 2).
--
-- Candidate divisors are collected from the terms on both sides: every
-- term that is a product whose left operand is an integer contributes that
-- integer. Terms with and without variables are treated alike. Every
-- non-empty subset of these integers is tried as a separate alternative, and
-- the least common multiple of the subset is used as the divisor. Both sides
-- are multiplied by @1 / divisor@ ('timesDivisionRule', which fails for a
-- zero divisor). Afterwards, products are distributed and like terms are
-- collected.
--
-- NOTE(review): no check is made that a chosen term contains a variable.
-- The example above depends on this, because it divides by the factor of
-- the variable-free term @2 * 2@.
removeTimes :: Rule (Relation Expr)
removeTimes = doAfter (fmap (collectLikeTerms . distributeAll)) $
   describe "remove times" $
   ruleTrans ("linear", "remove-times") $
   inputWith arg timesDivisionRule
 where
   -- transList: each candidate divisor is a separate alternative, so a
   -- factor can be removed based on different (combinations of) terms.
   arg = transList $ \eq -> do
      xs <- matchM sumView (leftHandSide eq)
      ys <- matchM sumView (rightHandSide eq)
      let -- integer left factor of a product term, if any
          factorOf e = fst <$> match (timesView >>> first integerView) e
          ns         = mapMaybe factorOf (xs ++ ys)
      -- subsequences is exponential in the number of factors; fine for
      -- the small sums this rule is applied to
      as <- filter (not . null) (subsequences ns)
      return (fromInteger $ foldr1 lcm as)

-- | Exposes 'distributeDivisionT' as a rule on expressions, registered under
-- the identifier @distr-division@.
distributeDivision :: Rule Expr
distributeDivision = makeRule ruleId distributeDivisionT
 where
   ruleId = "distr-division"

-- | Collects the numeric terms (naturals and other numbers) of a sum.
-- The terms of the sum are merged with 'mergeAlikeSum'. The rule fails when
-- the merged sum has the same 'nfComAssoc' normal form as the input.
mergeNums :: Rule Expr
mergeNums =
   describe "merge numbers (including naturals)" $
   ruleMaybe ("linear", "merge.num") step
 where
   numeric e = isNat e || isNumber e
   step old
      | nfComAssoc old == nfComAssoc merged = Nothing
      | otherwise                           = Just merged
    where
      merged = build sumView (mergeAlikeSum numeric (from sumView old))

-- | Collects like terms of a sum whose non-numeric part is a variable.
-- The rule only succeeds when merging strictly reduces the number of terms.
-- Terms are counted via 'sumView'; the count is 0 when the view does not
-- match.
mergeVars :: Rule Expr
mergeVars =
   describe "merge variables" $
   ruleMaybe ("linear", "merge.var") step
 where
   termCount = maybe 0 length . match sumView
   step old
      | termCount merged < termCount old = Just merged
      | otherwise                        = Nothing
    where
      merged = build sumView (mergeAlikeSum isVar (from sumView old))

-- | Merges the terms of a sum that share the same non-numeric part.
-- Each term is split by 'pm' into a coefficient and a remaining part.
-- Later terms whose remaining part equals that of an earlier term, and for
-- which that part satisfies @p@, are folded into the earlier term. The new
-- coefficient is the sum of all coefficients in the group. A term without
-- partners is kept exactly as it was.
mergeAlikeSum :: (Expr -> Bool) -> [Expr] -> [Expr]
mergeAlikeSum p = go . map (\x -> (pm 1 x, x))
 where
   go [] = []
   go (((coef, base), term) : others) = merged : go remaining
    where
      sameBase ((_, base2), _) = base2 == base && p base
      (same, remaining)        = partition sameBase others
      total                    = toRational (sum (coef : map (fst . fst) same))
      merged | null same = term
             | otherwise = build rationalView total .*. base

-- | Splits an expression into a numeric coefficient and a remaining factor.
-- Numeric factors are multiplied into the accumulator @acc@.
-- In a product, a numeric operand is peeled off, trying the left operand
-- first. 'Negate' flips the sign of the accumulator. A purely numeric
-- expression leaves the remainder @Nat 1@.
pm :: Double -> Expr -> (Double, Expr)
pm acc expr =
   case expr of
      e1 :*: e2
         | Just c <- match doubleView e1 -> pm (acc * c) e2
         | Just c <- match doubleView e2 -> pm (acc * c) e1
         | otherwise                     -> (acc, e1 .*. e2)
      Negate e -> pm (negate acc) e
      _        -> maybe (acc, expr) (\c -> (acc * c, Nat 1)) (match doubleView expr)

-- | Parameterised transformation that multiplies every element of the
-- functor by @1 / a@, where @a@ is the parameter bound to 'factorRef'.
-- It fails (via 'unlessZero') unless @a@ is a non-zero rational number.
timesDivisionRule :: Functor f => ParamTrans Expr (f Expr)
timesDivisionRule = parameter1 factorRef divideBy
 where
   divideBy a fe = unlessZero a (fmap (:*: (1 :/: a)) fe)

-- | Returns @a@ only when @e@ can be viewed as a rational number different
-- from zero; fails otherwise.
unlessZero :: Expr -> a -> Maybe a
unlessZero e a = a <$ mfilter (/= 0) (matchM rationalView e)

-- | Reference under which the divisor is passed to 'timesDivisionRule'.
factorRef :: Ref Expr
factorRef = makeRef "factor"

-- | Reference named @term@.
-- NOTE(review): not used or exported by the visible module; confirm whether
-- it is still needed.
termRef :: Ref Expr
termRef = makeRef "term"

-- | Distributes a top-level product over the terms of its two operands.
-- Each operand is split into terms with 'sumView'; an operand the view does
-- not match counts as a single term. Only the outermost multiplication is
-- distributed, and any other expression is returned unchanged.
-- Only used for cleaning up.
distributeAll :: Expr -> Expr
distributeAll (e1 :*: e2) =
   build sumView [ a .*. b | a <- terms e1, b <- terms e2 ]
 where
   terms e = fromMaybe [e] (match sumView e)
distributeAll expr = expr