-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-----------------------------------------------------------------------------

module Bayes.NetworkReader (readNetwork) where

import Control.Monad
import Numeric
import Ideas.Text.XML hiding (name)
import Ideas.Utils.Prelude
import Bayes.Network
import Bayes.Probability
import Bayes.SVG

-- | Parse the XML file at the given path and convert the resulting document
-- with 'xmlToNetwork' into a layout paired with the network.
readNetwork :: FilePath -> IO (Layout, Network ())
readNetwork file = parseXMLFile file >>= xmlToNetwork

-- | Per-node information read from the @node@ elements inside the @genie@
-- children of the @extensions@ element.
data GenieNode = GenieNode
   { genieNodeId   :: String -- ^ @id@ attribute; compared with 'nodeId' to find the matching network node
   , genieName     :: String -- ^ text of the @name@ child; used as the label of the matching node
   , geniePosition :: Point -- ^ position read from the @position@ child
   }

-- | Convert the root element of a network document.
--
-- The result pairs a layout (one @(id, position)@ entry per GeNIe node found
-- under @extensions@) with the network built from the @nodes@ element. Every
-- network node whose id also occurs among the GeNIe nodes gets the GeNIe name
-- as its label (the first match wins); other nodes are left untouched.
xmlToNetwork :: Monad m => XML -> m (Layout, Network ())
xmlToNetwork xml = do
   networkId  <- findAttribute "id" xml
   nodes      <- xmlToNodes =<< findChild "nodes" xml
   genieNodes <- xmlToExtensions =<< findChild "extensions" xml
   let names     = [ (genieNodeId g, genieName g)     | g <- genieNodes ]
       layout    = [ (genieNodeId g, geniePosition g) | g <- genieNodes ]
       relabel n = maybe n (\nm -> n {label = nm}) (lookup (nodeId n) names)
   return (layout, mapNodes relabel (makeNetwork networkId nodes))

-- | Kind of node definition to parse. 'xmlToNodes' uses one constructor for
-- each of the @cpt@, @noisymax@ and @noisyadder@ elements, respectively.
data NodeType = DefNormal | DefNoisyMax | DefNoisyAdder

-- | Parse all node definitions below the given element. The result lists the
-- nodes from every @cpt@ child first, then those from every @noisymax@ child,
-- and finally those from every @noisyadder@ child.
xmlToNodes :: Monad m => XML -> m [Node ()]
xmlToNodes xml = do
   normals <- nodesOf DefNormal     "cpt"
   maxes   <- nodesOf DefNoisyMax   "noisymax"
   adders  <- nodesOf DefNoisyAdder "noisyadder"
   return (concat [normals, maxes, adders])
 where
   -- parse every child with the given tag as a node of the given kind
   nodesOf tp tag = mapM (xmlToNode tp) (findChildren tag xml)

-- | Parse a single node element of the given kind.
--
-- The node takes its identifier from the @id@ attribute, its states from the
-- @state@ children and its parents from the optional @parents@ child. The
-- label is left empty here. Which children make up the definition depends on
-- the 'NodeType':
--
-- * 'DefNormal': @probabilities@
-- * 'DefNoisyMax': @strengths@ and @parameters@
-- * 'DefNoisyAdder': @dstates@, @weights@ and @parameters@
xmlToNode :: Monad m => NodeType -> XML -> m (Node ())
xmlToNode nodeTp xml = do
   ident  <- findAttribute "id" xml
   states <- mapM xmlToState (findChildren "state" xml)
   Node ident "" states parents <$> definition
 where
   -- no "parents" child means the node has no parents
   parents = maybe [] xmlToParents (findChild "parents" xml)

   child tag = findChild tag xml

   definition = case nodeTp of
      DefNormal ->
         CPT <$> (child "probabilities" >>= xmlToProbabilities)
      DefNoisyMax ->
         NoisyMax
            <$> (child "strengths"  >>= xmlToInts)
            <*> (child "parameters" >>= xmlToProbabilities)
      DefNoisyAdder ->
         NoisyAdder
            <$> (child "dstates"    >>= xmlToInts)
            <*> (child "weights"    >>= xmlToDoubles)
            <*> (child "parameters" >>= xmlToProbabilities)

-- | Read a @state@ element: its @id@ attribute paired with a unit value.
xmlToState :: Monad m => XML -> m (String, ())
xmlToState xml = fmap (\stateId -> (stateId, ())) (findAttribute "id" xml)

-- | The whitespace-separated words in the text of the element, taken as the
-- identifiers of the parent nodes.
xmlToParents :: XML -> [String]
xmlToParents = words . getData

-- | Read the whitespace-separated numbers in the text of the element as
-- probabilities; each word is parsed with 'readRational'.
xmlToProbabilities :: Monad m => XML -> m [Probability]
xmlToProbabilities xml = mapM toProbability (words (getData xml))
 where
   toProbability str = fromRational <$> readRational str

-- | Read the whitespace-separated words in the text of the element as 'Int's
-- (each word is parsed with 'readM').
xmlToInts :: Monad m => XML -> m [Int]
xmlToInts xml = mapM readM (words (getData xml))

-- | Read the whitespace-separated words in the text of the element as
-- 'Double's (each word is parsed with 'readM').
xmlToDoubles :: Monad m => XML -> m [Double]
xmlToDoubles xml = mapM readM $ words $ getData xml

-- | Parse an unsigned number in the notation accepted by 'readFloat' into a
-- 'Rational'.
--
-- The whole string has to be consumed: input with unread trailing characters
-- (for example @"0.5abc"@) is rejected instead of being silently truncated to
-- its numeric prefix. On rejection the computation fails via 'fail' with the
-- message @"readRational "@ followed by the offending input.
readRational :: Monad m => String -> m Rational
readRational s =
   -- keep only the parses that leave no remainder behind
   case [ r | (r, "") <- readFloat s ] of
      r:_ -> return r
      _   -> fail $ "readRational " ++ s

-- | Gather the GeNIe nodes of every @genie@ child of the given element into a
-- single list, in document order of the @genie@ children.
xmlToExtensions :: Monad m => XML -> m [GenieNode]
xmlToExtensions xml = do
   perGenie <- mapM xmlToGenie (findChildren "genie" xml)
   return (concat perGenie)

-- | Read every @node@ child of the given element as a 'GenieNode'.
xmlToGenie :: Monad m => XML -> m [GenieNode]
xmlToGenie = mapM xmlToGenieNode . findChildren "node"

-- | Read one GeNIe @node@ element: the @id@ attribute, the text of the @name@
-- child, and the position parsed from the @position@ child. All three lookups
-- run in @m@, so whatever failure they signal is propagated through @m@.
xmlToGenieNode :: Monad m => XML -> m GenieNode
xmlToGenieNode xml = GenieNode
   <$> findAttribute "id" xml
   -- The parentheses matter: '<$>' and '<*>' have the same fixity (infixl 4),
   -- so without them 'getData' itself becomes the applicative argument and
   -- the id lookup is run in the function applicative instead of in @m@.
   <*> (getData <$> findChild "name" xml)
   <*> (findChild "position" xml >>= xmlToPosition)

-- | Read a position from the whitespace-separated numbers in the text of the
-- element. When there are exactly four numbers, the first two become the
-- point and the last two are dropped (a fixed width and height are used
-- instead); for any other count the point @pt 0 0@ is returned.
xmlToPosition :: Monad m => XML -> m Point
xmlToPosition xml = do
   coords <- mapM readM (words (getData xml))
   return (toPoint coords)
 where
   toPoint coords =
      case coords of
         [x, y, _, _] -> pt x y
         _            -> pt 0 0