-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-- Counts occurrences of evidence in the database.
-- This module allows us to count how often nodes are in a particular state.
--
-----------------------------------------------------------------------------

module Database.Priors
   ( calculatePriors
   , Priors(Priors)
   , Prior(Prior)
   ) where

import Data.Function ( on )
import Data.List ( unionBy )
import Data.Semigroup ( Semigroup, (<>) )
import Text.PrettyPrint.Leijen ( pretty, (<+>) )
import qualified Data.Map as M
import qualified Data.Map.Strict as MS
import qualified Text.PrettyPrint.Leijen as PP

import Recognize.Data.MathStoryProblem ( Task(Task), singleNetwork )
import Main.Tasks ( findTaskFuzzy, taskNetwork )
import Util.Pretty ( prettyMap )
import Database.Data ( StudentID, TaskID, NodeID )
import Bayes.Evidence ( Evidence, hardEvidence )
import Bayes.Network ( nodeId, nodes, state2label )

-------------------------------------------------------------------------------

-- | Associates nodes with their priors. Every key pairs a task with a node of
-- that task, so the same node ID has a separate 'Prior' for each task it is
-- counted in.
newtype Priors = Priors (M.Map (TaskID, NodeID) Prior)

-- | Two priors are combined by adding the counters of matching states.
-- The strict 'MS.unionWith' evaluates every sum as soon as the maps are
-- merged. This way, folding the priors of many students together does not
-- build up long chains of unevaluated additions (a space leak with the lazy
-- 'M.unionWith').
instance Semigroup Prior where
   Prior x <> Prior y = Prior (MS.unionWith (+) x y)

-- | The empty prior has no counters at all.
instance Monoid Prior where
   mempty  = Prior M.empty
   mappend = (<>)

-- | Print one line per state, in ascending key order: the state label (or
-- @<No state>@ when no evidence was observed) followed by its counter.
instance PP.Pretty Prior where
   pretty (Prior counts) =
      PP.vsep [ entryDoc lbl n | (lbl, n) <- M.toList counts ]
      where
      entryDoc Nothing  n = pretty "<No state>:" <+> pretty n
      entryDoc (Just s) n = pretty s PP.<> pretty ':' <+> pretty n

-- | Associates node states with a counter. A state is identified by its
-- label; 'Nothing' stands for "no evidence was observed for the node".
newtype Prior = Prior (M.Map (Maybe String) Int)

-- | Two collections of priors are combined per (task, node) key by merging
-- the corresponding 'Prior' counters. The strict 'MS.unionWith' evaluates
-- each merged 'Prior' immediately instead of leaving a merge thunk in the
-- map, which keeps memory bounded when many collections are combined.
instance Semigroup Priors where
   Priors x <> Priors y = Priors (MS.unionWith (<>) x y)

-- | No priors for any node.
instance Monoid Priors where
   mempty  = Priors M.empty
   mappend = (<>)

-- | Pretty-print the priors per node. The task component of every key is
-- dropped, so the priors of nodes sharing a node ID across tasks are merged
-- (their counters added up) before printing.
instance PP.Pretty Priors where
   pretty (Priors byTaskNode) =
      prettyMap (M.fromListWith (<>) [ (node, p) | ((_, node), p) <- M.toList byTaskNode ])


-- | Calculate the priors, that is: for every node, count how often each state
-- occurs in the observed data.
--
-- Every (student, task) entry contributes a count of 1 for the observed state
-- of each node with hard evidence. It also contributes a count of 1 under
-- 'Nothing' for each node of the task's network that has no hard evidence.
-- The student ID only distinguishes entries; it does not appear in the result.
--
-- The per-entry priors are combined with 'M.foldMapWithKey'. Its balanced fold
-- merges maps of similar size, instead of a right-nested chain of unions whose
-- depth grows with the number of entries.
calculatePriors :: M.Map (StudentID, TaskID) Evidence -> Priors
calculatePriors = M.foldMapWithKey (\(_, tID) ev -> makePriors tID ev)

   where

   -- Obtain the IDs of the nodes in the network of a task. A task that
   -- 'findTaskFuzzy' cannot find contributes no nodes.
   taskNodes :: TaskID -> [NodeID]
   taskNodes = maybe [] (map nodeId . nodes . (\(Task t) -> singleNetwork t)) . findTaskFuzzy

   -- Make the priors of a single (student, task) entry. Every node with hard
   -- evidence gets a count of 1 for its observed state. Every node of the task
   -- without any evidence gets a count of 1 under 'Nothing'.
   makePriors :: TaskID -> Evidence -> Priors
   makePriors tID ev =
      let notFound   = map (\x -> (x, Nothing)) $ taskNodes tID
          -- NOTE(review): state labels are looked up in 'taskNetwork', not in
          -- the task's own network used by 'taskNodes'. Presumably
          -- 'taskNetwork' contains the nodes of every task -- confirm.
          found      = map (\(x,y) -> (x, Just . state2label taskNetwork x $ y)) (hardEvidence ev)
          -- 'unionBy' keeps every evidence entry and adds only those task
          -- nodes that do not already occur in the evidence.
          nodeStates = unionBy ((==) `on` fst) found notFound
      in foldr (addNodeState tID) mempty nodeStates

   -- A prior in which the given state has been counted exactly once.
   singleton :: Maybe String -> Prior
   singleton state = Prior $ M.singleton state 1

   -- Count one occurrence of a node state for the given task.
   addNodeState :: TaskID -> (NodeID, Maybe String) -> Priors -> Priors
   addNodeState tID (nodeID, state) (Priors m) = Priors $ M.insertWith (<>) (tID, nodeID) (singleton state) m