-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-----------------------------------------------------------------------------

module Domain.Math.Polynomial.Balance where

import Control.Monad
import Data.Function
import Data.Maybe
import Data.Ratio
import Data.Semigroup
import Domain.Math.Data.Relation
import Domain.Math.Data.WithBool
import Domain.Math.Equation.BalanceRules
import Domain.Math.Equation.Views
import Domain.Math.Expr
import Domain.Math.Numeric.Views
import Domain.Math.Polynomial.BalanceUtils
import Domain.Math.Polynomial.BuggyBalance
import Domain.Math.Polynomial.Examples
import Domain.Math.Polynomial.Generators
import Domain.Math.Polynomial.Rules        (conditionVarsRHS, flipEquation)
import Domain.Math.Polynomial.Views
import Ideas.Common.Examples
import Ideas.Common.Library                hiding ((.*.))
import Ideas.Common.Strategy.Legacy
import Ideas.Utils.Prelude                 (fixpoint)
import Ideas.Utils.Uniplate
import Prelude                             hiding ((<*>))
import Test.QuickCheck                     (sized)

------------------------------------------------------------
-- Exercise

balanceExercise :: Exercise (WithBool (Equation Expr))
balanceExercise = makeExercise
   { exerciseId    = describe "Solve a linear equation using only balance rules." $
                     newId "algebra.equations.linear.balance"
   , status        = Provisional
   , parser        = parseBoolEqExpr
   , similarity    = withoutContext ((==) `on` cleaner)
   , equivalence   = withoutContext (viewEquivalent eqView)
   , suitable      = predicateView (traverseView linearEquationView)
   , ready         = predicateView (traverseView (equationSolvedWith mixedFractionNF))
                     <||> predicateView (traverseView (equationSolvedWith rationalNF))
                     <||> predicateView (traverseView (equationSolvedWith doubleNF))
   , strategy      = balanceStrategy
   , extraRules    = map use buggyBalanceRules ++ map use buggyBalanceExprRules
   , ruleOrdering  = ruleOrderingWith (balanceOrder ++ buggyPriority)
   , navigation    = termNavigator
   , examples      = fmap singleton linearExamples
                     <> forTesting (random (liftM2 (\a b -> singleton (a :==: b)) (sized linearGen) (sized linearGen)))
   }

balanceOrder :: [Id]
balanceOrder =
   [ getId introTrue, getId introFalse
   , getId removeDivision, getId collect
   , getId varRightMinus, getId varRightPlus
   , getId conLeftMinus, getId conLeftPlus
   , getId varLeftMinus, getId varLeftPlus   -- prefer variable to left
   , getId conRightMinus, getId conRightPlus -- or constant to right
   , getId scaleToOne, getId flipEquation
   , getId divideCommonFactor, getId distribute
   , getId collect, getId divisionToFraction
   , getId negateBothSides
   ]

------------------------------------------------------------
-- Strategy

balanceStrategy :: LabeledStrategy (Context (WithBool (Equation Expr)))
balanceStrategy = cleanUpStrategyAfter (applyTop cleaner) $
   label "Balance equation" $
       label "Phase 1" (repeatS
           (  use collect
          <|> use distribute
          <|> use removeDivision
          <|> somewhere (use divisionToFraction)
          <|> use negateBothSides
          <|> use introTrue <|> use introFalse
           ))
   <*> label "Phase 2" (repeatS
           (  use varLeftMinus <|> use varLeftPlus
          <|> use conRightMinus <|> use conRightPlus
          <|> use introTrue <|> use introFalse
          <|> atomic (check p2 <*> (use varRightMinus <|> use varRightPlus))
          <|> atomic (check p1 <*> (use conLeftMinus <|> use conLeftPlus))
           )
       <*> try (use scaleToOne)
       <*> try (use calculate))
       -- flip sides of an equation (at most once)
   <%> try (atomic (use (check conditionVarsRHS) <*> use flipEquation))
       -- divide by a common factor (but not as final "scale-to-one" step)
   <%> many (atomic (notS (use scaleToOne) <*> use divideCommonFactor))
   -- <%> many (use introTrue <|> use introFalse)
 where
   -- move constants to left only if there are no variables on the left
   p1 = maybe False (either (const False) (hasNoVar . leftHandSide) . fromWithBool) . fromContext
   p2 ceq = fromMaybe False $ do
      wb           <- fromContext ceq
      lhs :==: rhs <- either (const Nothing) Just (fromWithBool wb)
      (x1, a, c)   <- matchLin lhs
      (x2, b, _)   <- matchLin rhs
      return (x1 == x2 && b > a && a /= 0 && c /= 0)

------------------------------------------------------------
-- Rules

calculate :: Rule (WithBool (Equation Expr))
calculate = makeRule (linbal, "calculate") $ checkForChange $
   Just . cleaner

-- factor is always positive due to lcm function
removeDivision :: Rule (Equation Expr)
removeDivision = doAfter (fmap distributeTimes) $
   describe "remove division" $
   ruleTrans (linbal, "remove-div") $
   inputWith removeDivisionArg timesRule
 where
   removeDivisionArg = makeTrans $ \(lhs :==: rhs) -> do
      xs <- match simpleSumView lhs
      ys <- match simpleSumView rhs
      -- also consider parts without variables
      -- (but at least one participant should have a variable)
      zs <- mapM getFactor (xs++ys)
      let (b, result) = foldr op (False, 1) zs
          op (b1, a1) (b2, a2) = (b1 || b2, a1 `lcm` a2)
      guard (b && result > 1)
      return (fromInteger result)

   getFactor (Negate a) = getFactor a
   getFactor expr = do
      (b, c) <- match (divView >>> second integerView) expr
      return (hasSomeVar b, c)
    `mplus` do
      r <- match rationalView expr
      return (False, denominator r)
    `mplus` do
      (r, c) <- match (timesView >>> first rationalView) expr
      return (hasSomeVar c, denominator r)
    `mplus` do
      (b, r) <- match (timesView >>> second rationalView) expr
      return (hasSomeVar b, denominator r)
    `mplus` do
      (_, ps) <- match simpleProductView expr
      guard (any (`belongsTo` rationalView) ps)
      return (False, 1)
    `mplus` do
      guard (isVariable expr)
      return (False, 1)

divisionToFraction :: Rule Expr
divisionToFraction =
   describe "turn a division into a multiplication with a fraction" $
   makeRule (linbal, "div-to-fraction") $ \expr -> do
      (a, r) <- match (divView >>> second rationalView) expr
      guard (hasSomeVar a && r /= 0)
      return (fromRational (1/r)*a)

divideCommonFactor :: Rule (Equation Expr)
divideCommonFactor = doAfter (fmap distributeDiv) $
   describe "divide by common factor" $
   ruleTrans (linbal, "smart-div") $
   inputWith (transMaybe getArg) divisionRule
 where
   getArg (lhs :==: rhs)
      | null xs                 = fail "no factor"
      | 0 `notElem` ns && n > 1 = return (fromInteger n)
      | otherwise               = fail "no factor"
    where
       xs = from simpleSumView lhs ++ from simpleSumView rhs
       ns = map getFactor xs
       n  = foldr1 gcd ns

   getFactor expr
      | hasNoVar expr = fromMaybe 1 $ match integerView expr
      | otherwise = fromMaybe 1 $ do
           (a, b) <- match timesView expr
           case (match integerView a, match integerView b) of
              (Just n, _) | hasSomeVar b -> return n
              (_, Just n) | hasSomeVar a -> return n
              _ -> Nothing

negateBothSides :: Rule (Equation Expr)
negateBothSides = describe "Remove negation on both sides of an equation" $
   rewriteRule (linbal, "negate") $ \a b ->
      (-a :==: -b) :~> (a :==: b)

varLeftMinus, varLeftPlus :: Rule (Equation Expr)
varLeftMinus = varLeft True  (linbal, "var-left-minus")
varLeftPlus  = varLeft False (linbal, "var-left-plus")

varLeft :: IsId a => Bool -> a -> Rule (Equation Expr)
varLeft useMinus rid = doAfter (fmap collectLocal) $
   ruleTrans rid $
   inputWith varLeftArg (if useMinus then minusRule else plusRule)
 where
    varLeftArg = transMaybe $ \(lhs :==: rhs) -> do
       guard (hasSomeVar lhs)
       (x, a, _) <- matchLin rhs
       guard (if useMinus then a > 0 else a < 0)
       return (fromRational (abs a) .*. x)

conRightMinus, conRightPlus :: Rule (Equation Expr)
conRightMinus = conRight True  (linbal, "con-right-minus")
conRightPlus  = conRight False (linbal, "con-right-plus")

conRight :: IsId a => Bool -> a -> Rule (Equation Expr)
conRight useMinus rid = doAfter (fmap collectLocal) $
   ruleTrans rid $
   inputWith conRightArg (if useMinus then minusRule else plusRule)
 where
    conRightArg = transMaybe $ \(lhs :==: _) -> do
       guard (hasSomeVar lhs)
       (_, _, b) <- matchLin lhs
       guard (if useMinus then b > 0 else b < 0)
       return (fromRational (abs b))

varRightMinus, varRightPlus :: Rule (Equation Expr)
varRightMinus = flipped (linbal, "var-right-minus") varLeftMinus
varRightPlus  = flipped (linbal, "var-right-plus")  varLeftPlus

conLeftMinus, conLeftPlus :: Rule (Equation Expr)
conLeftMinus = flipped (linbal, "con-left-minus") conRightMinus
conLeftPlus  = flipped (linbal, "con-left-plus")  conRightPlus

flipped :: IsId a => a -> Rule (Equation b) -> Rule (Equation b)
flipped rid = liftView flipView . changeId (const (newId rid))
 where
   flipView = makeView (Just . flipSides) flipSides

scaleToOne :: Rule (Equation Expr)
scaleToOne = doAfter (fmap distributeDiv) $
   ruleTrans (linbal, "scale-to-one") $
   inputWith scaleToOneArg divisionRule
 where
   scaleToOneArg = transMaybe $ \(lhs :==: rhs) -> f lhs rhs `mplus` f rhs lhs

   f :: Expr -> Expr -> Maybe Expr
   f expr c = do
      (_, a1, b1) <- matchLin expr
      guard (a1 /= 0 && a1 /= 1 && b1 == 0 && hasNoVar c)
      return (fromRational a1)

collect :: Rule (Equation Expr)
collect = makeRule (linbal, "collect") $
   -- don't use this rule just for cleaning up
   checkForChange (Just . fmap collectGlobal) . fmap cleanerExpr

distribute :: Rule (Equation Expr)
distribute = makeRule (linbal, "distribute") $ checkForChange $
   Just . fmap (fixpoint f)
 where
   f (a :*: (b :+: c))  = f (a*b + a*c)
   f (a :*: (b :-: c))  = f (a*b - a*c)
   f ((a :+: b) :*: c)  = f (a*c + b*c)
   f ((a :-: b) :*: c)  = f (a*c - b*c)
   f (Negate (a :+: b)) = f (-a-b)
   f (Negate (a :-: b)) = f (-a+b)
   f (Negate (Negate a)) = f a
   f (a :-: (b :+: c)) = f (a-b-c)
   f (a :-: (b :-: c)) = f (a-b+c)
   f (a :-: Negate b)  = f (a+b)
   f a = descend f a

introTrue :: Rule (WithBool (Equation Expr))
introTrue = makeRule (linbal, "intro-true") $ f . fromWithBool . cleaner
 where
   f (Right (lhs :==: rhs)) | lhs == rhs = Just true
   f _ = Nothing

introFalse :: Rule (WithBool (Equation Expr))
introFalse = makeRule (linbal, "intro-false") $ f . fromWithBool . cleaner
 where
   f (Right (lhs :==: rhs)) = do
      x <- match rationalView lhs
      y <- match rationalView rhs
      guard (x /= y)
      return false
   f _ = Nothing

-- for debugging
{-
go = printDerivation balanceExercise $ singleton $ let x=Var "x" in
   (x+2+7/2*x)/(3/2) :==: -3/2*x/4*0 -}