-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-----------------------------------------------------------------------------

module Task.Matryoshka.Recognizer (matryoshka) where

import Control.Applicative                      (many)
import Control.Arrow                            (second)
import Control.Monad                            (guard, msum, join)
import qualified Data.List                      as L
import Data.Maybe
import Recognize.Data.Approach
import Recognize.Data.Attribute hiding (Other)
import Recognize.Data.MathParserOutput
import Recognize.Data.MathParserOptions
import Recognize.Data.Diagnosis
import Recognize.Data.Math
import Recognize.Data.Step
import Recognize.Data.MathStoryProblem
import Recognize.Data.StringLexer
import Recognize.Data.StringLexerOptions
import Recognize.Expr.Functions
import Recognize.Expr.Normalform
import Recognize.Parsing.Derived
import Recognize.Parsing.Parse
import Recognize.SubExpr.SEParser
import Recognize.Recognizer
import Recognize.SubExpr.Functions hiding (isVar)
import Recognize.SubExpr.Recognizer
import Recognize.SubExpr.Symbols
import Task.Matryoshka.Assess
import Domain.Math.Data.Relation
import Domain.Math.Expr
import Ideas.Common.Id (newId)
import Ideas.Common.Rewriting (getFunction, function)
import Ideas.Utils.Uniplate (transform, para)
import Task.Network.Matryoshka
import Bayes.Evidence ( evidenceOfAbsence )

-- | The Matryoshka story problem.
--
-- A single analyzer (id @"02"@) is registered. Its lexer whitelists the
-- variable @cm@ and its parser disables multiplication by concatenation.
-- Before recognition, the parsed input is cleaned up by 'removeUnit',
-- 'fixPercentages' and 'changeInequalities', and then handed to
-- 'pDiagnosis'. Evidence is collected from 'assess'' via
-- @evidenceOfAbsence ans1 False@.
matryoshka :: MathStoryProblem
matryoshka = mathStoryProblem
  { problemId     = newId "matryoshka"
  , processInputs = id
  , analyzers     = [(newId "02", matryoshkaAnalyzer)]
  , inputFile     = Just "input/matryoshka.csv"
  , networkFile   = Just "networks/Matryoshka.xdsl"
  , singleNetwork = network
  }
  where
    matryoshkaAnalyzer = analyzer
      { lexer      = stringLexer lexOptions
      , parser     = mathParser parseOptions . stringLexerOutput
      , recognizer = recognizeInput
      , collector  = evidenceOfAbsence ans1 False . assess'
      }

    lexOptions   = stringLexerOptions { variableWhitelist = ["cm"] }
    parseOptions = mathParserOptions { multByConcatenation = False }

    -- Clean up the parser output, then run the diagnosis parser on it.
    recognizeInput raw =
      let cleaned = changeInequalities (mathParserOutput (fixPercentages (removeUnit raw)))
      in seRecognizer pDiagnosis cleaned

-- | Strip a trailing unit factor @cm@ from every product @a * cm@, leaving
-- just @a@. Inputs that did not parse are passed through untouched.
--
-- This is a workaround: ideally the lexer/parser would handle whitelisted
-- variables such as @cm@ correctly. It should be removed once the pilots
-- are finished.
removeUnit :: MathParserOutput -> MathParserOutput
removeUnit (MathParserOutput maths che) = MathParserOutput (map stripMath maths) che
  where
    stripMath (M str parsed) = M str (transform dropCm <$> parsed)

    dropCm (factor :*: Var "cm") = factor
    dropCm other                 = other

-- | Restore forgotten percentage signs.
--
-- If a student writes @a*b%@ at least once (parsed as @a*b/100@), then any
-- other product that has @b@ as one of its two factors but lacks the
-- percentage sign very likely had it forgotten, so @/100@ is added back.
--
-- Only the first such @b@ is used. The search goes through the inputs in
-- order and visits each expression parent-before-children. Any
-- (sub)expression that already mentions @100@ is left untouched. If no
-- @a*b/100@ is found, nothing changes.
fixPercentages :: MathParserOutput -> MathParserOutput
fixPercentages (MathParserOutput maths che) =
  case msum (map percentageIn maths) of
    Nothing  -> MathParserOutput maths che
    Just pct -> MathParserOutput (map (repairMath pct) maths) che
  where
    -- The factor b of the first subterm of shape _ * b / 100, if any.
    percentageIn (M _ parsed) = either (const Nothing) (para firstMatch) parsed

    firstMatch ex fromChildren = msum (ownMatch ex : fromChildren)

    ownMatch (_ :*: b :/: 100) = Just b
    ownMatch _                 = Nothing

    repairMath pct (M str parsed) = M str (repairExpr pct <$> parsed)

    repairExpr pct ex
      | hasExpr 100 ex = ex
      | otherwise = case ex of
          a :*: b
            | b == pct || a == pct -> a :*: b :/: 100
            | otherwise            -> repairExpr pct a :*: repairExpr pct b
          _ -> maybe ex (\(sym, args) -> function sym (map (repairExpr pct) args)) (getFunction ex)

-- | Top-level parser. It recognises the solution steps ('pSteps'), keeps only
-- the steps of the dominant numerical strategy ('mostCommonStrategy'), and
-- builds a 'Diagnosis' from the resulting approach and steps.
pDiagnosis :: SEParser Diagnosis
pDiagnosis = uncurry newDiagnosis . mostCommonStrategy <$> pSteps

-- | For the 'Numerical' approach, decide between the two numerical
-- strategies and drop the steps that belong to the other one:
--
-- * Numerical1: steps labelled @N1@;
-- * Numerical2: steps labelled @N2a@ or @N2b@, counted at half weight
--   (presumably because an @N2a@ step is normally followed by an @N2b@ step).
--
-- When the halved Numerical2 count is at least the Numerical1 count, the
-- result is @Other "Numerical2"@ with all @N1@ steps removed. This includes
-- the case where neither kind of step is present. Otherwise the result is
-- 'Numerical' with all @N2a@/@N2b@ steps removed. Other approaches are
-- returned unchanged.
--
-- Steps are removed with 'filter' on their labels rather than with list
-- difference: this is linear instead of quadratic and does not depend on how
-- 'Step' implements 'Eq'.
mostCommonStrategy :: (Approach, [Step]) -> (Approach, [Step])
mostCommonStrategy (Numerical, xs)
  | length n2Steps `div` 2 >= length n1Steps = (Other "Numerical2", filter (not . hasN1) xs)
  | otherwise                                = (Numerical, filter (not . hasN2) xs)
  where
    labelsOf x = snd (getValue x)
    hasN2 x = any (`elem` [Label "N2a", Label "N2b"]) (labelsOf x)
    hasN1 x = Label "N1" `elem` labelsOf x
    n2Steps = filter hasN2 xs
    n1Steps = filter hasN1 xs
mostCommonStrategy s = s

-- | Try to predetermine which approach the student took, by looking at all
-- input expressions at once:
--
-- * 'Algebraic' if any expression contains a function call or a power;
-- * otherwise 'Numerical' if each of 32, 24 and 18 occurs as a
--   subexpression of some input expression;
-- * otherwise 'Nothing' (undecided).
stratHeur :: SEParser (Maybe Approach)
stratHeur = withInput classify
  where
    classify inputs
      | any looksAlgebraic exprs                        = Just Algebraic
      | all (\n -> any (isSubExprOf n) exprs) [32, 24, 18] = Just Numerical
      | otherwise                                       = Nothing
      where
        exprs = mapMaybe getExpr inputs

    looksAlgebraic ex = isFunctionCall ex || usesPower ex

    -- True if a power symbol occurs anywhere in the expression tree.
    usesPower ex =
      maybe False (\(sym, args) -> isPowerSymbol sym || any usesPower args) (getFunction ex)

-- | Recognise the student's solution steps and the approach they followed.
--
-- 'stratHeur' first tries to predetermine the approach. If it does, only the
-- corresponding strategy parser runs. Otherwise the numerical and algebraic
-- parsers (each of which must produce at least one step) and a 'NoApproach'
-- fallback are offered to 'choice''.
--
-- Afterwards 'pFinalAnswer' searches the input skipped by the strategy
-- parser, plus the remaining input, for a final answer. The parser fails
-- unless at least one step (strategy step or final answer) was found.
pSteps :: SEParser (Approach, [Step])
pSteps = do
  mapp <- stratHeur
  -- If we know the approach, call the corresponding strategy parser;
  -- otherwise try both. The catch-all keeps this case total.
  (ap, st, sk) <- case mapp of
    Just Algebraic -> (\(st', sk') -> (Algebraic, st', sk')) <$> pStepsAlgebraic []
    Just Numerical -> (\(_, st', sk') -> (Numerical, st', sk')) <$> pStepsArith [] 32
    _ -> choice'
      [ do
          (_, st', sk') <- pStepsArith [] 32
          guard (not (null st')) -- must have done at least one step
          return (Numerical, st', sk')
      , do
          (st', sk') <- pStepsAlgebraic []
          guard (not (null st'))
          return (Algebraic, st', sk')
      , return (NoApproach, [], [])
      ]
  -- Attempt to parse the final answer separately
  fa_st <- pFinalAnswer sk
  let ss = st ++ maybeToList fa_st
  guard (not (null ss))
  return (ap, ss)

-- | Algebraic strategy: repeatedly recognise a formula definition
-- ('pFormula') or a formula with a variable exponent ('pFormulaLin').
--
-- Returns the second and third components of 'pRepeat''s result: the
-- recognised steps and the list that 'pSteps' passes on to 'pFinalAnswer'
-- as skipped input. The @[Math]@ argument is currently ignored.
pStepsAlgebraic :: [Math] -> SEParser ([Step],[Math])
pStepsAlgebraic _ = dropFirst <$> pRepeat (pFormula |> pFormulaLin)
  where
    dropFirst (_, steps, skipped) = (steps, skipped)

-- | Formula definition, e.g. @h(x) = 32*0.75^x@.
--
-- The parser may first see an equation @var = n@ with @n@ a natural number.
-- In that case @n@ is used as the exponent bound, and a @D@ label carrying
-- @n@ is attached to the recognised steps via 'addAttributeToFront'.
-- Otherwise a fresh magic natural is used.
--
-- It then matches @32 * 0.75 ^ k@ or @32 * 0.75 * k@, labelled @F@, where
-- @k@ is @n-1@ or @n@ (each labelled @N@).
pFormula :: SEParser (Expr, [Step])
pFormula = do
  pLog "pFormula"
  mdef <- option (satisfyEq isVar isNat)
  let (bound, extraLabel) = case mdef of
        Just (_ :==: rhs) -> (rhs, Just (LabelE "D" rhs))
        Nothing           -> (newMagicNat, Nothing)
  (result, steps) <- pMatchSubSteps (formulaFor bound)
  return (result, maybe steps (`addAttributeToFront` steps) extraLabel)
  where
    formulaFor k = lt "n" k $ \m ->
      let expTerm = stop (lblE "N" (m-1) (m-1) <?> lblE "N" m m)
      in lbl "F" (32 * ((0.75 ** expTerm) <!> (0.75 * expTerm)))

-- | Formula with a variable exponent (e.g. @32*0.75^x@), followed by an
-- equation that assigns that variable a number (e.g. @x = 3@).
--
-- NOTE(review): this was previously summarised as "Filled in formula:
-- 32*0.75^5", but the code matches the variable form 'stratBVarStep' and then
-- a @var = number@ equation. Confirm which reading is intended.
pFormulaLin :: SEParser (Expr, [Step])
pFormulaLin = do
  pLog "pFormulaLin"
  -- The next input item must be an expression that mentions a variable
  -- (fails via maybeToParse otherwise).
  e <- peek >>= getExpr
  _ <- maybeToParse $ getVarS e
  _ <- few skip
  -- Match 32 * 0.75 ^ (v-1 | v) or 32 * 0.75 * (v-1 | v), labelled "Def".
  (e',ss) <- pMatchSubSteps stratBVarStep
  (_,ss',_) <- pInOrder
    [ \_ -> do
      -- Equation that fills in the variable, e.g. "x = 3" for 32*0.75^x.
      -- Its right-hand side is attached to the formula steps as an "L" label.
      (_ :==: y) <- satisfyEq isVar (\ex -> isNat ex || isNumber ex)
      let ss' = addAttributeToFront (LabelE "L" y) ss
      -- NOTE(review): log tag is spelled "pFormulaLIn" (capital I).
      pLog ("pFormulaLIn: " ++ show ss')
      return (y, ss')
    ]
  -- NOTE(review): this ss' is the second component returned by pInOrder,
  -- presumably the steps produced by the equation parser above — confirm.
  return (e', ss')
  where
    -- 32 * 0.75 ^ (v-1 | v), or 32 * 0.75 * (v-1 | v), for a fresh magic
    -- variable v, labelled "Def". The meaning of "<&> newMagicNat" is defined
    -- by the strategy DSL elsewhere.
    stratBVarStep = lt "v" newMagicVar $ \v ->
      lbl "Def" (32 * (
      (0.75 ** stop ((v-1) <!> v))
      <!> (0.75 * stop ((v-1) <!> v))
      )) <&> newMagicNat

-- | Numerical strategy.
--
-- Starting from the current value (32 at the top level), the parser
-- repeatedly recognises one of two kinds of step:
--
-- * a subtraction step ('pStratS'): take 25% or 75% of the value and
--   subtract it (@N2a@/@N2b@ labels);
-- * a multiplication step ('pStratA'): multiply by 0.75 or 0.25 (@N1@ label).
--
-- After each step it continues from the new value. When no step matches, it
-- either skips one input item (remembering it) and tries again, or stops.
--
-- Returns the last value reached, the recognised steps, and the skipped
-- items, most recently skipped first.
pStepsArith :: [Math] -> Expr -> SEParser (Expr, [Step],[Math])
pStepsArith skipped current = do
  found <- option (pStratS current |> pStratA multiplyStep current)
  upcoming <- safePeek
  pLog ("pStepsArith: " ++ show found ++ " | " ++ show upcoming)
  maybe skipOrStop continueFrom found
  where
    multiplyStep ex = lbl "N1" $ ex * (0.75 <!> 0.25)

    continueFrom (nextValue, steps) =
      prependSteps steps <$> pStepsArith skipped nextValue

    prependSteps steps (val, later, sk) = (val, steps ++ later, sk)

    skipOrStop = choice'
      [ skip >>= \s -> pStepsArith (s : skipped) current
      , return (current, [], skipped)
      ]

-- | Match the strategy @strat@ applied to @e@. Before matching, the parser
-- state is switched to growing mode (@optGrow@, with @growF = strat@) and
-- chained equations are enabled. The match itself runs under 'resetAfter',
-- which presumably restores the previous settings afterwards — confirm.
pStratA :: (Expr -> Expr) -> Expr -> SEParser (Expr, [Step])
pStratA grow start = do
  modify $ \st -> st { optGrow = True, growF = grow, chainedEquations = True }
  resetAfter (pMatchSubSteps (grow start))


-- | Subtraction-based step of the Numerical2 strategy, starting from @e@.
-- The alternatives, in the order they are given to 'choice':
--
-- 1. Two separate steps: first a part @e * 0.25@ or @e * 0.75@ (labelled
--    @N2a@), then, after some skipped input ('few'), @e@ minus that part
--    (labelled @N2b@).
-- 2. A single subtraction @e - nf (e * 0.25)@, @e - nf (e * 0.75)@ or
--    @e - 8@ (labelled @N2b@).
-- 3. A single subtraction @e - 8@ (labelled @N2b@).
pStratS :: Expr -> SEParser (Expr, [Step])
pStratS e = choice [twoSteps, pMatchSubSteps oneStep, pMatchSubSteps minusEight]
  where
    twoSteps = do
      (part, partSteps) <- pMatchSubSteps partOf
      _ <- few skip
      second (partSteps ++) <$> pMatchSubSteps (subtractPart part)

    partOf            = lbl "N2a" (e * (0.25 <!> 0.75))
    subtractPart part = lbl "N2b" $ sub (e - part)
    oneStep           = lbl "N2b" $ sub (e - nf (e * 0.25)) <!> sub (e - nf (e * 0.75)) <!> sub (e - 8)
    minusEight        = lbl "N2b" $ sub (e - 8)



-- | Look for the final answer among the previously skipped input and all
-- remaining input (which is consumed).
--
-- Candidates are natural numbers between 2 and 10 inclusive. The one chosen
-- is whatever 'closestInList' returns for the reference value 6, and it is
-- turned into a step with 'makeFAStep'. Logs "empty" and returns 'Nothing'
-- when there is no candidate.
pFinalAnswer :: [Math] -> SEParser (Maybe Step)
pFinalAnswer skipped = do
  pLog "pFinalAnswer"
  rest <- many skip
  let candidates = [ n | n <- mapMaybe getExpr (skipped ++ rest), isNat n, n >= 2, n <= 10 ]
  maybe (pLog "empty" >> return Nothing)
        (return . Just . makeFAStep)
        (closestInList candidates 6)


-- | Attach an attribute to the last step of the list. An empty list is
-- returned unchanged.
--
-- NOTE(review): despite the name, it is the *last* element of the list that
-- is modified. "Front" may refer to where 'addAttribute' puts the attribute
-- within that step — confirm.
addAttributeToFront :: Attribute -> [Step] -> [Step]
addAttributeToFront at = go
  where
    go []       = []
    go [x]      = [addAttribute at x]
    go (x : xs) = x : go xs

-- | Replace every relation other than @=@ and @/=@ (i.e. the inequalities) by
-- its left-hand side expression, keeping the original input string. All
-- other items are returned unchanged.
changeInequalities :: [Math] -> [Math]
changeInequalities = map dropInequality
  where
    dropInequality m = case getRelation m of
      Just r | relationType r `notElem` [EqualTo, NotEqualTo]
        -> M (getString m) (Right (leftHandSide r))
      _ -> m