module Database.Priors
( calculatePriors
, Priors(Priors)
, Prior(Prior)
) where
import Data.Function ( on )
import Data.List ( unionBy )
import Data.Semigroup ( Semigroup, (<>) )
import qualified Data.Map as M
import qualified Data.Map.Strict as MS
import Text.PrettyPrint.Leijen ( pretty, (<+>) )
import qualified Text.PrettyPrint.Leijen as PP

import Bayes.Evidence ( Evidence, hardEvidence )
import Bayes.Network ( nodeId, nodes, state2label )
import Database.Data ( StudentID, TaskID, NodeID )
import Main.Tasks ( findTaskFuzzy, taskNetwork )
import Recognize.Data.MathStoryProblem ( Task(Task), singleNetwork )
import Util.Pretty ( prettyMap )
-- | State tallies ('Prior'), one per (task, node) pair.
newtype Priors
  = Priors (M.Map (TaskID, NodeID) Prior)
-- Merging two tallies sums the counts of states that occur in both and keeps
-- the remaining states unchanged.
-- The strict 'MS.unionWith' forces each summed count as soon as the merged map
-- is evaluated. The lazy 'M.unionWith' would instead store an unevaluated
-- @(+)@ thunk per state, and that thunk chain grows with every tally folded in.
instance Semigroup Prior where
  Prior x <> Prior y = Prior $ MS.unionWith (+) x y
-- The identity tally has no states. 'mappend' is kept in its canonical form,
-- delegating to the 'Semigroup' instance.
instance Monoid Prior where
  mappend = (<>)
  mempty = Prior M.empty
-- Renders one line per state, in ascending key order of the underlying map.
-- A 'Just' state is written as @state: count@; the 'Nothing' key is written
-- as @<No state>: count@.
instance PP.Pretty Prior where
  pretty (Prior counts) = PP.vsep (map renderEntry (M.toList counts))
    where
      renderEntry (Nothing, n) = pretty "<No state>:" <+> pretty n
      renderEntry (Just s, n) = pretty s PP.<> pretty ':' <+> pretty n
-- | Occurrence counts keyed by node state label. The 'Nothing' key counts
-- occurrences for which no state label was available.
newtype Prior
  = Prior (M.Map (Maybe String) Int)
-- Merging two collections combines the tallies of (task, node) pairs present
-- in both with the 'Prior' semigroup and keeps all other pairs unchanged.
-- The strict 'MS.unionWith' evaluates each combined 'Prior' (its inner map) as
-- soon as the merged map is evaluated. The lazy variant would instead store a
-- growing chain of unevaluated '<>' applications per key.
instance Semigroup Priors where
  Priors x <> Priors y = Priors $ MS.unionWith (<>) x y
-- The identity collection contains no (task, node) tallies. 'mappend' is
-- kept in its canonical form, delegating to the 'Semigroup' instance.
instance Monoid Priors where
  mappend = (<>)
  mempty = Priors M.empty
-- Pretty-prints the tallies keyed by node ID only. The task component of each
-- key is dropped, so tallies for the same node ID coming from different tasks
-- are merged with the 'Prior' semigroup before printing.
instance PP.Pretty Priors where
  pretty (Priors byTaskAndNode) = prettyMap byNode
    where
      byNode = M.mapKeysWith (<>) snd byTaskAndNode
-- | Tally, for every (task, node) pair, how often each node state occurs in
-- the given evidence.
--
-- Each entry contributes one count per node. The nodes come from the network
-- of the (fuzzily matched) task, plus any node reported by 'hardEvidence'.
-- A node with hard evidence is counted under its state label, and any other
-- node of the task is counted under 'Nothing'. The student component of the key
-- is ignored, so counts from all students are summed.
calculatePriors :: M.Map (StudentID, TaskID) Evidence -> Priors
calculatePriors = M.foldMapWithKey (\(_, tID) ev -> makePriors tID ev)
  where
    -- Node IDs of the task's network; empty when no task matches the ID.
    taskNodes :: TaskID -> [NodeID]
    taskNodes =
      maybe [] (map nodeId . nodes . (\(Task t) -> singleNetwork t)) . findTaskFuzzy

    -- Tallies for a single evidence entry: one count per node.
    makePriors :: TaskID -> Evidence -> Priors
    makePriors tID ev =
      let notFound = map (\x -> (x, Nothing)) $ taskNodes tID
          -- NOTE(review): labels are resolved against the shared 'taskNetwork',
          -- whereas 'taskNodes' uses the task's own 'singleNetwork'; confirm
          -- that this is intended.
          found =
            map (\(x, y) -> (x, Just . state2label taskNetwork x $ y)) (hardEvidence ev)
          -- 'unionBy' keeps all of 'found' and adds only those placeholders
          -- whose node does not already appear in it.
          nodeStates = unionBy ((==) `on` fst) found notFound
      in foldr (addNodeState tID) mempty nodeStates

    -- A tally holding a single occurrence of the given state.
    singleton :: Maybe String -> Prior
    singleton state = Prior $ M.singleton state 1

    -- Add one occurrence of a node's state. The strict insert forces the
    -- merged 'Prior' instead of storing a thunk.
    addNodeState :: TaskID -> (NodeID, Maybe String) -> Priors -> Priors
    addNodeState tID (nodeID, state) (Priors m) =
      Priors $ MS.insertWith (<>) (tID, nodeID) (singleton state) m