-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-----------------------------------------------------------------------------

module Bayes.Evidence
   ( Evidence, getIndex, set, setN, setId, clear
   , virtual, addVirtualNodes, fromEvidenceTp
   , EvidenceType(..), virtualId, filterEvidence
   , evidenceMap
   , evStates
   , hardEvidence
        -- * Evidence
   , (.==)
   , (.=~)
   , newEvidence, evidenceIndex, fromEvidence, allProbabilities
   , evidenceOfAbsence
   , nodeNotSet
   , nodeIsSet
   , nodeSetTo
   , noVirtuals
   , isEmpty, getNames
     -- * Probabilities
   , Probabilities, fromProbabilities, makeProbabilities, findProbability
   , probabilitiesFor
   , renameEvidence
   ) where

import Bayes.Probability
import Bayes.Network
import qualified Data.Map as M
import Control.Monad (forM)
import Control.Arrow
import Data.List
import Data.Maybe
import Data.Semigroup
import Ideas.Utils.Parsing
import Ideas.Utils.Prelude (readM)

import qualified Ideas.Text.XML as XML

-- | Evidence is associated to some node in a network by a string (the node
--   identifier). For each identifier it carries either an index of the state
--   that we say we have evidence for (hard evidence) or a probability for
--   each state of the node (virtual evidence).
newtype Evidence = E
   { evidenceMap :: M.Map String EvidenceType -- ^ evidence per node identifier
   }
 deriving (Eq, Ord)

-- | Renders evidence as comma-separated @name=value@ items, in ascending
--   order of node identifier; values are rendered by the 'Show' instance of
--   'EvidenceType'.
instance Show Evidence where
   show ev = intercalate ", " [ name ++ "=" ++ show et | (name, et) <- fromEvidenceTp ev ]

-- | Reads evidence in the textual format accepted by 'pEvidence'. Input that
--   is rejected by the parser yields no result.
instance Read Evidence where
   readsPrec _ input =
      case parseSimple pEvidence input of
         Left _   -> []
         Right ev -> [(ev, "")]

-- | Parses zero or more comma-separated items (see 'pItem'), followed by
--   optional whitespace. Items are combined with 'mconcat', so when a node
--   occurs more than once its first occurrence wins.
pEvidence :: Parser Evidence
pEvidence = do
   items <- sepBy pItem (char ',' <* spaces)
   spaces
   return (mconcat items)

-- | Parses a single @name=value@ item; the value part is handled by 'pType'.
pItem :: Parser Evidence
pItem = do
   name  <- identifier
   _     <- char '='
   build <- pType
   return (build name)

-- | Parses the value of an item and returns a function that builds the
--   evidence once the node name is known: either hard evidence @#n@ or
--   virtual evidence @[state:p%, ...]@.
pType :: Parser (String -> Evidence)
pType = hardValue <|> softValue
 where
   hardValue = do
      _ <- char '#'
      i <- pNat
      return (\name -> evidenceIndex name i)
   softValue = do
      _     <- char '['
      pairs <- sepBy pPair (char ',' <* spaces)
      _     <- char ']'
      return (\name -> virtualId name pairs)

-- | Parses a @state:probability@ pair of virtual evidence.
pPair :: Parser (String, Probability)
pPair = do
   stateName <- identifier
   _         <- char ':'
   p         <- pProbability
   return (stateName, p)

-- | Parses a percentage such as @12.5%@ into a 'Probability' (here 0.125).
--   Both the integral and the fractional digits are required (see 'pDouble').
--   NOTE(review): the value passes through 'Double', so decimal inputs are
--   not represented exactly.
pProbability :: Parser Probability
pProbability = do
   d <- pDouble
   _ <- char '%'
   return (fromRational (toRational d))

-- | Parses a decimal number with a mandatory fractional part (e.g. @12.5@)
--   and divides it by 100, i.e. it reads the number as a percentage.
pDouble :: Parser Double
pDouble = do
   whole <- many1 digit
   _     <- char '.'
   frac  <- many1 digit
   -- 'read' cannot fail here: both parts consist of digits only.
   return (read (whole ++ '.' : frac) / 100)

-- | Parses a non-empty sequence of decimal digits as an 'Int'.
pNat :: Parser Int
pNat = do
   ds <- many1 digit
   return (read ds)

-- | Parses a (possibly empty) name made of alphanumeric characters,
--   @-@ and @_@.
identifier :: Parser String
identifier = many nameChar
 where
   nameChar = alphaNum <|> oneOf "-_"


-- | Combines two pieces of evidence. For a node present in both, the
--   evidence of the left operand is kept (left-biased map union).
instance Semigroup Evidence where
   (<>) (E lhs) (E rhs) = E (M.union lhs rhs)

-- | The evidence recorded for a single node.
data EvidenceType
   = Index Int                        -- ^ hard evidence: index of the observed state
   | Virtual [(String, Probability)]  -- ^ virtual evidence: a probability per named state
 deriving (Eq, Ord)

-- | Hard evidence is rendered as @#i@; virtual evidence as
--   @[state:p, ...]@, using the 'Show' instance of 'Probability' for @p@.
instance Show EvidenceType where
   show evidence =
      case evidence of
         Index i    -> "#" ++ show i
         Virtual xs -> "[" ++ intercalate ", " [ s ++ ":" ++ show p | (s, p) <- xs ] ++ "]"

-- | The empty evidence is the identity of '<>'. 'mappend' is defined as
--   '<>', which is the canonical definition since 'Semigroup' became a
--   superclass of 'Monoid' (avoids @-Wnoncanonical-monoid-instances@).
instance Monoid Evidence where
   mempty  = E M.empty
   mappend = (<>)

-- | Lists the evidence per node, in ascending order of node identifier.
fromEvidenceTp :: Evidence -> [(String, EvidenceType)]
fromEvidenceTp (E m) = M.assocs m

-- | Looks up the hard evidence (state index) for a node. Returns 'Nothing'
--   when the node has no evidence or only virtual evidence.
getIndex :: Node a -> Evidence -> Maybe Int
getIndex n ev = M.lookup (nodeId n) (evidenceMap ev) >>= hardIndex
 where
   hardIndex (Index i)   = Just i
   hardIndex (Virtual _) = Nothing

-- | Virtual evidence for a node: pairs the node's state names (in order)
--   with the given probabilities. Surplus states or probabilities are
--   silently dropped ('zip' semantics).
virtual :: Node a -> [Probability] -> Evidence
virtual n ps = virtualId (nodeId n) (zip stateNames ps)
 where
   stateNames = [ name | (name, _) <- states n ]

-- | Virtual evidence (one probability per named state) for the node with
--   the given identifier.
virtualId :: String -> [(String, Probability)] -> Evidence
virtualId s = E . M.singleton s . Virtual

-- | Records hard evidence (a state index) for a node, replacing any
--   existing evidence for that node.
set :: Node a -> Int -> Evidence -> Evidence
set n = setId (nodeId n)

-- | Records hard evidence (a state index) for the node with the given
--   identifier, replacing any existing evidence for that node.
setId :: String -> Int -> Evidence -> Evidence
setId s i = E . M.insert s (Index i) . evidenceMap

-- | Sets hard evidence for several nodes at once, pairing nodes and indices
--   positionally (surplus elements of either list are ignored). If a node
--   occurs more than once, its first occurrence wins.
setN :: [Node a] -> [Int] -> Evidence -> Evidence
setN ns ixs ev = foldr ($) ev (zipWith set ns ixs)

-- | Removes all evidence (hard or virtual) recorded for a node.
clear :: Node a -> Evidence -> Evidence
clear n = E . M.delete (nodeId n) . evidenceMap

-- | Keeps only the evidence whose node identifier satisfies the predicate.
filterEvidence :: (String -> Bool) -> Evidence -> Evidence
filterEvidence p (E m) = E (M.filterWithKey keep m)
 where
   keep k _ = p k

-- | For every node with virtual evidence, adds an auxiliary observation node
--   (see 'observedNode') to the network. Hard evidence is ignored. The
--   auxiliary nodes are added in ascending order of node identifier.
addVirtualNodes :: Evidence -> Network () -> Network ()
addVirtualNodes (E m) = go (M.toList m)
 where
   go []                     = id
   go ((s, Virtual ps) : xs) = go xs . addNode (observedNode s ps)
   go (_ : xs)               = go xs

-- | Auxiliary node @#s@ with states @yes@/@no@ and the single parent @s@.
--   For every virtual-evidence entry with probability @p@, the CPT holds
--   the pair @[p, 1 - p]@.
--   NOTE(review): the state names in @ps@ are ignored, so this presumably
--   relies on @ps@ being listed in the parent's state order; confirm for
--   evidence built via 'newEvidence' or parsing, where the order may differ.
observedNode :: String -> [(String, Probability)] -> Node ()
observedNode s ps = base
   { states     = yesNo
   , parentIds  = [s]
   , definition = CPT table
   }
 where
   base  = node ('#' : s)
   yesNo = [("yes", ()), ("no", ())]
   table = concat [ [p, 1 - p] | (_, p) <- ps ]

-- | Drops all virtual evidence, keeping only hard evidence.
noVirtuals :: Evidence -> Evidence
noVirtuals = E . M.filter isHard . evidenceMap
 where
   isHard et = case et of
      Index _   -> True
      Virtual _ -> False

--------------------------------------------------------------------------------
-- EXTRA

---------------------------------------------------------------------------
-- Evidence

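-- Evidence is associated to some node in a network by a string.
-- It carries either an index of the state that we say we have evidence of
-- or a probability for each state of the node.
-- (Plain comment rather than Haddock: it describes the legacy definition
-- below and must not be attached to the next exported declaration.)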
{-
newtype Evidence = E { evidenceMap :: M.Map String (Either Int Probabilities) }

instance Show Evidence where
   show = intercalate ", " . map f . fromEvidence
    where
      f (s, et) = s ++ "=" ++ either (('#':) . show) show et

instance Semigroup Evidence where
   E m1 <> E m2 = E (m1 <> m2)

instance Monoid Evidence where
   mempty  = E M.empty
   mappend = (<>)

instance ToHTML Evidence where
   toHTML e
      | isEmpty e = mempty
      | otherwise = w3table True $ header : map f (fromEvidence e)
    where
      header   = map string ["id", "value"]
      f (s, v) = [string s, either (text . toBool) toHTML v]
      toBool = (== 0)
-}
-- | Virtual evidence for the node with the given identifier, built from a
--   table of per-state probabilities (listed in ascending state-name order).
newEvidence :: String -> Probabilities -> Evidence
newEvidence s xs = virtualId s [ (k, fromRational r) | (k, r) <- fromProbabilities xs ]

-- | Hard evidence that node @n@ is in the state whose value is @a@.
--   NOTE(review): if @a@ is not among the node's states, index 0 is used
--   silently instead of reporting an error.
(.==) :: Eq a => Node a -> a -> Evidence
n .== a = evidenceIndex (nodeId n) stateIndex
 where
   stateIndex = fromMaybe 0 (findIndex (\(_, v) -> v == a) (states n))

-- | Soft (virtual) evidence counterpart to '.==': assigns the given
--   probabilities to the node's states in order (see 'virtual').
--   No 'Eq' constraint is needed because state values are never compared.
(.=~) :: Node a -> [Probability] -> Evidence
(.=~) = virtual

-- | Set a particular node to a default value if it has no hard evidence yet;
--   otherwise the evidence is returned unchanged.
--   NOTE(review): a node carrying only virtual evidence counts as "not set"
--   (see 'nodeNotSet'), and the left-biased '<>' then replaces that virtual
--   evidence with the default. Confirm this is intended.
evidenceOfAbsence :: Eq a => Node a -> a -> Evidence -> Evidence
evidenceOfAbsence n def ev
   | nodeNotSet n ev = (n .== def) <> ev
   | otherwise       = ev

-- | Returns 'True' if the node has no hard evidence (state index) in the
--   given evidence. A node with only virtual evidence also counts as not set.
nodeNotSet :: Node a -> Evidence -> Bool
nodeNotSet n = isNothing . getIndex n

-- | Returns 'True' if the node has hard evidence (a state index); the
--   negation of 'nodeNotSet'.
nodeIsSet :: Node a -> Evidence -> Bool
nodeIsSet n ev = isJust (getIndex n ev)

-- | Query whether a node is in a certain hard state in the evidence.
--   Returns 'False' when the node has no hard evidence, and also when the
--   requested state is not one of the node's states. Both lookups must
--   succeed, because comparing two 'Nothing's would report an unset node as
--   being in a non-existent state.
nodeSetTo :: Eq a => Node a -> a -> Evidence -> Bool
nodeSetTo n state ev =
   case (expectedIndex, getIndex n ev) of
      (Just expected, Just actual) -> expected == actual
      _                            -> False
 where
   expectedIndex = findIndex ((== state) . snd) (states n)


-- | Evidence consisting of a single hard observation: the node with the
--   given identifier is in the state with the given index.
evidenceIndex :: String -> Int -> Evidence
evidenceIndex s i = E (M.singleton s (Index i))

-- | Lists the evidence per node (ascending identifier order): 'Left' holds a
--   hard state index, 'Right' the virtual probabilities as a 'Probabilities'
--   table with values converted to 'Rational'.
fromEvidence :: Evidence -> [(String, Either Int Probabilities)]
fromEvidence ev = [ (name, convert et) | (name, et) <- fromEvidenceTp ev ]
 where
   convert (Index i)    = Left i
   convert (Virtual xs) = Right (makeProbabilities [ (s, toRational p) | (s, p) <- xs ])

-- | Flattens all virtual evidence into (node, state, probability) triples.
--   Hard evidence contributes nothing.
allProbabilities :: Evidence -> [(String, String, Rational)]
allProbabilities = concatMap expand . fromEvidence
 where
   expand (_, Left _)   = []
   expand (n, Right ps) = [ (n, s, r) | (s, r) <- fromProbabilities ps ]

--filterEvidence :: (String -> Bool) -> Evidence -> Evidence
--filterEvidence = undefined --  p (E m) = E (M.filterWithKey (const . p) m)

-- | Returns 'True' if there is no evidence at all, neither hard nor virtual.
isEmpty :: Evidence -> Bool
isEmpty (E m) = M.null m

-- | Identifiers of all nodes that have evidence, in ascending order.
getNames :: Evidence -> [String]
getNames = M.keys . evidenceMap


-- | Virtual evidence is recorded as a mapping from node names to the
--   probabilities of being in a certain state. This function pairs strings of
--   the form @"<NODE>#<STATE>"@ with those probabilities, converted to
--   'Double'. Hard evidence is not included.
evStates :: Evidence -> [(String, Double)]
evStates ev = [ (n ++ "#" ++ s, fromRational r) | (n, s, r) <- allProbabilities ev ]


-- | Obtain the states of the hard evidence as (node identifier, state index)
--   pairs, in descending order of node identifier. Virtual evidence is
--   skipped. The list comprehension avoids shadowing the imported 'node' and
--   the unused probability binding of the previous fold.
hardEvidence :: Evidence -> [(String, Int)]
hardEvidence ev = reverse [ (name, i) | (name, Left i) <- fromEvidence ev ]

---------------------------------------------------------------------------
-- Probabilities

-- | A table of probabilities keyed by state name.
newtype Probabilities
   = P (M.Map String Rational)
 deriving (Show)

{-
instance ToHTML Probabilities where
  toHTML (P m) = table $ tr $ map (left . td . f) $ M.toList m
   where
     f (n, b) = string (n ++ ": " ++ show (fromRational b :: Double)) -}

-- | Builds a probability table from (state, probability) pairs; for
--   duplicate state names the last pair wins.
makeProbabilities :: [(String, Rational)] -> Probabilities
makeProbabilities xs = P (M.fromList xs)

-- | Lists the table's entries in ascending order of state name.
fromProbabilities :: Probabilities -> [(String, Rational)]
fromProbabilities (P m) = M.assocs m

-- | Looks up the probability recorded for a single state, if any.
findProbability :: Probabilities -> String -> Maybe Rational
findProbability (P m) = (`M.lookup` m)

-- | Obtain the probabilities for soft evidence at a particular node.
--   Returns 'Nothing' if the node has no evidence or only hard evidence.
--   (Local names chosen so as not to shadow the top-level 'identifier'
--   parser or the imported 'states' field.)
probabilitiesFor :: Evidence -> String -> Maybe [(String, Probability)]
probabilitiesFor ev nodeName =
   case M.lookup nodeName (evidenceMap ev) of
      Just (Virtual xs) -> Just xs
      _                 -> Nothing

-- | Applies a renaming to all node identifiers. If two identifiers are
--   mapped to the same name, the evidence of the greater original identifier
--   is kept ('M.mapKeys' semantics).
renameEvidence :: (String -> String) -> Evidence -> Evidence
renameEvidence f = E . M.mapKeys f . evidenceMap

-------------------------------------------------------------------------------


-- | Serialises evidence as an @<evidence>@ element with one child per node,
--   in ascending order of node identifier: @<hard>@ holds the state index,
--   @<soft>@ one @<state>@ per (value, probability) pair.
instance XML.ToXML Evidence where
   toXML ev = XML.makeXML "evidence" (mconcat (map entry (fromEvidenceTp ev)))
    where
      entry :: (String, EvidenceType) -> XML.XMLBuilder
      entry (lbl, Index i) =
         XML.element "hard"
            [ "label" XML..=. lbl
            , XML.tag "state" ("value" XML..=. show i)
            ]
      entry (lbl, Virtual pairs) =
         XML.element "soft"
            [ "label" XML..=. lbl
            , mconcat (map stateElem pairs)
            ]

      stateElem (k, v) =
         XML.element "state"
            [ "value" XML..=. k
            , "probability" XML..=. show v
            ]

-- | Reads evidence from an @<evidence>@ element. Each @<hard>@ child yields a
--   state index (the @value@ attribute of its @<state>@ child); each @<soft>@
--   child yields @value@/@probability@ pairs, with the probability parsed by
--   'pProbability'. If a label occurs more than once, the last occurrence
--   wins, and soft entries come after hard ones.
--   (Helper parameters are named so they do not shadow the method argument.)
instance XML.InXML Evidence where
   fromXML xml
      | XML.name xml /= "evidence" = fail "expecting <evidence> tag"
      | otherwise = do
           hards <- mapM getHard (XML.findChildren "hard" xml)
           softs <- mapM getSoft (XML.findChildren "soft" xml)
           return (E (M.fromList (hards ++ softs)))
    where
      getHard hardElem = do
         lbl   <- XML.findAttribute "label" hardElem
         value <- XML.findChild "state" hardElem >>= XML.findAttribute "value" >>= readM
         return (lbl, Index value)

      getSoft softElem = do
         lbl     <- XML.findAttribute "label" softElem
         states' <- forM (XML.findChildren "state" softElem) $ \child' -> do
               probability <- XML.findAttribute "probability" child'
                                 >>= (either fail return . parseSimple pProbability)
               value       <- XML.findAttribute "value" child'
               return (value, probability)
         return (lbl, Virtual states')