-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-- Assessment for the magictrick exercise
--
-----------------------------------------------------------------------------

module Task.MagicTrick.Assess where

import Data.Monoid
import qualified Data.List.NonEmpty as N
import qualified Data.Semigroup as SG
import Domain.Math.Expr
import Ideas.Common.Id hiding ((<>))
import Recognize.Data.Approach
import Recognize.Data.Attribute
import Bayes.Evidence
import Recognize.Data.Op
import Recognize.Data.RuleId
import Recognize.Data.Diagnosis as S
import Recognize.Data.Step
import Recognize.Model.Assess
import Recognize.Model.Connectives
import Recognize.Model.Constraint
import Recognize.Model.EvidenceBuilder
import Recognize.Expr.Symbols
import Task.Network.MagicTrick
import Bayes.Network

-- | Turn a diagnosis of the magic trick exercise into evidence.
--
-- The attributes of the diagnosed steps are passed to 'generateEvidence'
-- together with the diagnosed approach and a function that picks the
-- evidence builder for an approach (see 'buildStepsEvidence').
assess' :: Diagnosis -> Evidence
assess' diagnosis =
  -- Disabled: stringNode (apprtostring appr) ans1Strat <>
  -- (would check which approach has been used)
  generateEvidence builderFor (approach diagnosis) stepAttributes
  where
    -- The payload holds the magic expression, if any.
    magicExpr = payload diagnosis
    -- Builder for a given approach, specialised to this diagnosis' payload.
    builderFor usedApproach = buildStepsEvidence usedApproach magicExpr
    -- Only the second component of every step value is used.
    stepAttributes = [ snd (getValue step) | step <- steps diagnosis ]

-- | Select the evidence builders that belong to the diagnosed approach.
--
-- It is unnecessary to verify the arithmetic strategy constraints when the
-- category is algebraic and vice versa, so only the two builders of the
-- matching approach are combined.  Every other combination (no magic
-- expression, or an approach without dedicated builders) yields no evidence.
buildStepsEvidence :: Approach -> Maybe Expr -> EvBuilder ()
buildStepsEvidence Algebraic (Just me) = stepsA1Builder me SG.<> stepsA2Builder me
buildStepsEvidence Numerical (Just me) = stepsN1Builder me SG.<> stepsN2Builder me
-- Catch-all instead of @_ Nothing@: if 'Approach' has constructors besides
-- 'Algebraic' and 'Numerical', pairing one of them with @Just@ would otherwise
-- be an unmatched case and fail at runtime.
buildStepsEvidence _ _                 = return ()

-- | 'exists1' constraint for the attribute @Expand (Add 8)@.
add8C :: Constraint EvBuilder [Attribute]
add8C = exists1 (Expand (Add 8))

-- | 'exists1' constraint for the attribute @Expand (Mul 3)@.
mul3C :: Constraint EvBuilder [Attribute]
mul3C = exists1 (Expand (Mul 3))

-- | 'exists1' constraint for the attribute @Expand (Sub 4)@.
sub4C :: Constraint EvBuilder [Attribute]
sub4C = exists1 (Expand (Sub 4))

-- | 'exists1' constraint for the attribute @Expand (Add x)@, for a given @x@.
addXC :: Expr -> Constraint EvBuilder [Attribute]
addXC = exists1 . Expand . Add

-- | 'exists1' constraint for the attribute @Expand (Div 4)@.
div4C :: Constraint EvBuilder [Attribute]
div4C = exists1 (Expand (Div 4))

-- | 'exists1' constraint for the attribute @Expand (Add 2)@.
add2C :: Constraint EvBuilder [Attribute]
add2C = exists1 (Expand (Add 2))

-- | 'exists1' constraint for the attribute @Expand (Sub x)@, for a given @x@.
subXC :: Expr -> Constraint EvBuilder [Attribute]
subXC = exists1 . Expand . Sub

-- | All seven expansion constraints combined with 'mconcat', in the order
-- of the trick: add 8, multiply by 3, subtract 4, add the magic expression,
-- divide by 4, add 2 and subtract the magic expression.
fullyExpanded :: Expr -> Constraint EvBuilder [Attribute]
fullyExpanded me =
  mconcat
    [ add8C
    , mul3C
    , sub4C
    , addXC me
    , div4C
    , add2C
    , subXC me
    ]

-- | Evidence builder for the nodes @ans1Strat1Step2@ .. @ans1Strat1Step6@
-- and @ans1@, given the magic expression @me@.
--
-- Every constraint of this (general) strategy is guarded by
-- 'fullyExpanded': the formula must first be fully expanded.
stepsA1Builder :: Expr -> EvBuilder ()
stepsA1Builder me = do
  let expanded = fullyExpanded me
  giveNodeAndCollect ans1Strat1Step2 (expanded <> (expanded ==> failOnAnyMistake))

  -- @a1_2@ is a variable that will match anything.  Hypothesis and
  -- conclusion differ only in @a1_2@ versus @24@: if @a1_2@ matches something
  -- that is not @24@ a failure is generated, otherwise the result is a
  -- success or unknown.
  a1_2 <- newVar "a1_2"
  let distribute outcome =
        ARule Distr_Times (3 * (me + 8) N.:| []) (3*me + outcome) -- outcome must be nat
  giveNodeAndCollect ans1Strat1Step3 $
    expanded ==> implication [distribute a1_2] [distribute 24]

  -- By now there should be an actual value for @a1_2@; the same
  -- hypothesis/conclusion scheme is repeated for @a1_3@.
  a1_3 <- newVar "a1_3"
  let collectNums outcome = ARule Collect_Num (a1_2 N.:| [-4]) outcome
      collectVars = ARule Collect_Var (3*me N.:| [me]) (4*me)
  giveNodeAndCollect ans1Strat1Step4 $
    expanded ==> implication
      [ collectNums a1_3
      , collectVars
      ]
      -- normalform will simplify @a1_2 - 4@ within the comparison function
      [ collectNums (normalform (a1_2 - 4))
      , collectVars
      ]

  a1_4 <- newVar "a1_4"
  let splitFraction =
        ARule Distr_Division (((4*me + a1_3) / 4) N.:| []) ((4*me)/4 + a1_3/4)
      cancelVar = ARule Division (4*me/4 N.:| []) me
      divideNum outcome = ARule Division (a1_3/4 N.:| []) outcome
  giveNodeAndCollect ans1Strat1Step5 $
    expanded ==> implication
      [ splitFraction
      , cancelVar
      , divideNum a1_4
      ]
      [ splitFraction
      , cancelVar
      , divideNum (normalform (a1_3 / 4))
      ]

  -- The last step and the final answer node receive the same constraint.
  let finalC = expanded ==> implication
        [ ARule Collect_Var (me N.:| [- me]) 0
        , ARule Collect_Num (a1_4 N.:| [2]) (normalform (a1_4 + 2))
        ]
        []
  giveNodeAndCollect ans1Strat1Step6 finalC
  giveNodeAndCollect ans1 finalC

-- | Evidence builder for the nodes @ans1Strat2Step1@ .. @ans1Strat2Step7@
-- and @ans1@, given the magic expression @me@.
--
-- For stepwise strategies the formula only has to be expanded up to a
-- certain point, so each node is guarded by the expansion constraint(s) of
-- its own step instead of 'fullyExpanded'.
stepsA2Builder :: Expr -> EvBuilder ()
stepsA2Builder me = do
  giveNodeAndCollect ans1Strat2Step1 add8C

  -- If this constraint could not be verified, @a2_2@ still refers to some
  -- wildcard.  With collectDefault @a2_2@ will be assigned @24@ if the
  -- result of the constraint is unknown.
  a2_2 <- newVar "a2_2"
  let timesThree outcome =
        ARule Distr_Times ((me + 8)*3 N.:| []) (3*me + outcome)
  giveNodeAndCollectDefault ans1Strat2Step2
    (add8C ==> mul3C ==>
      (implication [timesThree a2_2] [timesThree 24] <?>> failOnAnyMistake))
    a2_2 24

  -- With collectDefault @a2_3@ will be assigned @a2_2 - 4@ if the result of
  -- the constraint is unknown.
  a2_3 <- newVar "a2_3"
  let fallback3 = normalform (a2_2 - 4)
      minusFour outcome = ARule Collect_Num (a2_2 N.:| [-4]) outcome
  giveNodeAndCollectDefault ans1Strat2Step3
    (sub4C ==>
      (implication [minusFour a2_3] [minusFour fallback3] <?>> failOnAnyMistake))
    a2_3 fallback3

  let plusVar = ARule Collect_Var (3*me N.:| [me]) (4*me)
  giveNodeAndCollect ans1Strat2Step4 $
    addXC me ==> (implication [plusVar] [plusVar] <?>> failOnAnyMistake)

  -- With collectDefault @a2_5@ will be assigned @a2_3 / 4@ if the result of
  -- the constraint is unknown.
  a2_5 <- newVar "a2_5"
  let fallback5 = normalform (a2_3 / 4)
      splitFraction =
        ARule Distr_Division (((4*me + a2_3) / 4) N.:| []) ((4*me)/4 + a2_3/4)
      cancelVar = ARule Division (4*me/4 N.:| []) me
      divideNum outcome = ARule Division (a2_3/4 N.:| []) outcome
  giveNodeAndCollectDefault ans1Strat2Step5
    (div4C ==>
      (implication
        [ splitFraction
        , cancelVar
        , divideNum a2_5
        ]
        [ splitFraction
        , cancelVar
        , divideNum fallback5
        ]
        <?>> failOnAnyMistake))
    a2_5 fallback5

  a2_6 <- newVar "a2_6"
  let fallback6 = normalform (a2_5 + 2)
      plusTwo outcome = ARule Collect_Num (a2_5 N.:| [2]) outcome
  giveNodeAndCollectDefault ans1Strat2Step6
    (add2C ==>
      (implication [plusTwo a2_6] [plusTwo fallback6] <?>> failOnAnyMistake))
    a2_6 fallback6

  -- The variable itself is not referenced below; it is still created so the
  -- builder performs the same actions as before.
  _ <- newVar "a2_7"

  -- The last step and the final answer node receive the same constraint.
  let finalC = subXC me ==>
        (implication
          [ ARule Collect_Var (me N.:| [-me]) 0 ]
          []
          <?>> failOnAnyMistake)
  giveNodeAndCollect ans1Strat2Step7 finalC
  giveNodeAndCollect ans1 finalC

-- | Evidence builder for the nodes @ans1Strat3Step1@, @ans1Strat3Step2@ and
-- @ans1@, given the magic expression @me@.
--
-- Every constraint is guarded by 'fullyExpanded'.  The intermediate
-- calculations (add 8, times 3, subtract 4 and add @me@, divide by 4) are
-- collected on a locally created placeholder node.
stepsN1Builder :: Expr -> EvBuilder ()
stepsN1Builder me = do
  let fullC = fullyExpanded me
  giveNodeAndCollect ans1Strat3Step1 $ fullC ==> failOnAnyMistake

  -- Create an empty node to hold the intermediate steps
  let empty = Node "empty" "" [("Correct", True),("Incorrect", False)] [] (CPT [0.5, 0.5])

  n1_2 <- newVar "n1_2"
  giveNodeAndCollect empty $ fullC ==> (implication
              [ ARule Collect_Num (me N.:| [8]) n1_2 ]
              [ ARule Collect_Num (me N.:| [8]) (normalform $ me + 8) ]
              <?>> failOnAnyMistake)

  n1_3 <- newVar "n1_3"
  giveNodeAndCollect empty $ fullC ==> (implication
              [ ARule Times (n1_2 N.:| [3]) n1_3 ]
              [ ARule Times (n1_2 N.:| [3]) (normalform $ n1_2 * 3) ]
              <?>> failOnAnyMistake)

  n1_4 <- newVar "n1_4"
  giveNodeAndCollect empty $ fullC ==> (implication
              [ ARule Collect_Num (n1_3 N.:| [-4, me]) n1_4 ]
              [ ARule Collect_Num (n1_3 N.:| [-4, me]) (normalform $ n1_3 - 4 + me) ]
              <?>> failOnAnyMistake)

  n1_5 <- newVar "n1_5"
  giveNodeAndCollect empty $ fullC ==> (implication
              [ ARule Division (n1_4 N.:| [4]) n1_5 ]
              [ ARule Division (n1_4 N.:| [4]) (normalform $ n1_4 / 4) ]
              <?>> failOnAnyMistake)

  -- The last step subtracts @me@, so the collected operand is @-me@: the
  -- operands then add up to the expected result @n1_5 + 2 - me@, as in the
  -- other Collect_Num rules of this module (previously the operand was @me@).
  -- NOTE(review): confirm against the recogniser that it reports the signed
  -- operand here.
  n1_6 <- newVar "n1_6"
  let finalC = fullC ==> (implication
        [ ARule Collect_Num (n1_5 N.:| [2, -me]) n1_6 ]
        [ ARule Collect_Num (n1_5 N.:| [2, -me]) (normalform $ n1_5 + 2 - me) ]
        <?>> failOnAnyMistake)

  -- The last step and the final answer node receive the same constraint.
  giveNodeAndCollect ans1Strat3Step2 finalC
  giveNodeAndCollect ans1 finalC

-- | Evidence builder for the nodes @ans1Strat4Step1@ .. @ans1Strat4Step7@
-- and @ans1@, given the magic expression @me@.
--
-- Each node is guarded by the expansion constraint of its own step.  For
-- the steps collected with 'giveNodeAndCollectDefault', the step variable
-- and the normalised expected result are passed along as the default.
stepsN2Builder :: Expr -> EvBuilder ()
stepsN2Builder me = do
  n1 <- newVar "n2_1"
  let expected1 = normalform (me + 8)
      plusEight outcome = ARule Collect_Num (me N.:| [8]) outcome
      plusEightC =
        implication [plusEight n1] [plusEight expected1] <?>> failOnAnyMistake
  -- When the magic expression equals 0 the variable is set to 8 directly and
  -- only the expansion constraint is collected.
  if me == 0
    then do
      setValueOf n1 8
      giveNodeAndCollect ans1Strat4Step1 add8C
    else
      giveNodeAndCollectDefault ans1Strat4Step1 (add8C ==> plusEightC) n1 expected1

  n2 <- newVar "n2_2"
  let expected2 = normalform (n1 * 3)
      timesThree outcome = ARule Times (n1 * 3 N.:| []) outcome
  giveNodeAndCollectDefault ans1Strat4Step2
    (mul3C ==>
      (implication [timesThree n2] [timesThree expected2] <?>> failOnAnyMistake))
    n2 expected2

  n3 <- newVar "n2_3"
  let expected3 = normalform (n2 - 4)
      minusFour outcome = ARule Collect_Num (n2 N.:| [-4]) outcome
  giveNodeAndCollectDefault ans1Strat4Step3
    (sub4C ==>
      (implication [minusFour n3] [minusFour expected3] <?>> failOnAnyMistake))
    n3 expected3

  n4 <- newVar "n2_4"
  let expected4 = normalform (n3 + me)
      plusMagic outcome = ARule Collect_Num (n3 N.:| [me]) outcome
  giveNodeAndCollectDefault ans1Strat4Step4
    (addXC me ==>
      (implication [plusMagic n4] [plusMagic expected4] <?>> failOnAnyMistake))
    n4 expected4

  n5 <- newVar "n2_5"
  let expected5 = normalform (n4 / 4)
      byFour outcome = ARule Division (n4 / 4 N.:| []) outcome
  giveNodeAndCollectDefault ans1Strat4Step5
    (div4C ==>
      (implication [byFour n5] [byFour expected5] <?>> failOnAnyMistake))
    n5 expected5

  n6 <- newVar "n2_6"
  let expected6 = normalform (n5 + 2)
      plusTwo outcome = ARule Collect_Num (n5 N.:| [2]) outcome
  giveNodeAndCollectDefault ans1Strat4Step6
    (add2C ==>
      (implication [plusTwo n6] [plusTwo expected6] <?>> failOnAnyMistake))
    n6 expected6

  n7 <- newVar "n2_7"
  -- The last step and the final answer node receive the same constraint.
  let minusMagic outcome = ARule Collect_Num (n6 N.:| [-me]) outcome
      finalC = subXC me ==>
        (implication
          [ minusMagic n7 ]
          [ minusMagic (normalform (n6 - me)) ]
          <?>> failOnAnyMistake)
  giveNodeAndCollect ans1Strat4Step7 finalC
  giveNodeAndCollect ans1 finalC