-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-----------------------------------------------------------------------------

module Bayes.Network
   ( Network, makeNetwork, name, nodes, nodeIds, mapNodes, mapNodesM, renameNodes, filterNodes, addNode
   , Node(..), parents, node, node_, size, sizeId, descendants, ancestors, findNode, findStates
   , Definition(..)
   , findNodeFuzzy
   , state2label
   , label2state
   ) where

import Bayes.Probability
import Control.Arrow
import Data.List
import Data.Maybe
import Data.Semigroup
import qualified Data.Map as M
import qualified Data.Set as S

import Util.String ( normalize )

-- | A Bayesian network: a name together with its nodes, stored in a map.
--
-- The derived 'Eq' and 'Ord' instances compare the name and the node map.
-- Because 'Node' equality only looks at 'nodeId', two networks with the same
-- name and the same node ids compare equal even if node contents differ.
data Network a = Network
   { name    :: String                 -- ^ name of the network ('findNodeFuzzy' also strips it as an id prefix)
   , nodeMap :: M.Map String (Node a)  -- ^ nodes keyed by id ('makeNetwork' and 'addNode' use 'nodeId' as key)
   }
 deriving (Eq, Ord)

-- | Shows a header line (the network name, padded with '=' and cut to 60
-- characters), followed by every node in ascending key order.
instance Show a => Show (Network a) where
   show nw = unlines (header : map show (nodes nw))
    where
      header = take 60 ("== " ++ name nw ++ repeat '=')

-- | Combines two networks. Names are joined with a slash (an empty name is
-- dropped), and node maps are unioned. When both networks contain a node
-- with the same key, the left node is kept. This is only allowed if the
-- right copy has no parents (i.e. it only carries priors); otherwise merging
-- fails with an error.
instance Semigroup (Network a) where
   nw1 <> nw2 =
      Network (makeName (name nw1) (name nw2))
              (M.unionWith overwriteNode (nodeMap nw1) (nodeMap nw2))
    where
      makeName x y =
         case (null x, null y) of
            (True, _) -> y
            (_, True) -> x
            _         -> x ++ "/" ++ y
      -- clashing keys: keep the left node, provided the right one is parentless
      overwriteNode n1 n2
         | not (null (parentIds n2)) =
              error $ "Node " ++ nodeId n1 ++ " cannot be merged (try changing the order)"
         | otherwise = n1

-- | 'mappend' is '<>'. The unit is a nameless network without nodes; it is
-- neutral because '<>' drops empty names when joining.
instance Monoid (Network a) where
   mappend = (<>)
   mempty  = Network "" M.empty

-- | Maps over the values attached to the states of every node; the name and
-- the keys of the node map are untouched.
instance Functor Network where
   fmap f nw =
      let onNode = fmap f
      in  nw { nodeMap = onNode <$> nodeMap nw }

-- | Builds a network from a name and a list of nodes, keyed by 'nodeId'.
-- If several nodes share an id, the last one in the list wins.
makeNetwork :: String -> [Node a] -> Network a
makeNetwork s ns = Network s (M.fromList (map withKey ns))
 where
   withKey n = (nodeId n, n)

-- | All nodes of the network, in ascending key order.
nodes :: Network a -> [Node a]
nodes nw = M.elems (nodeMap nw)

-- | The set of keys (node ids) of the network.
nodeIds :: Network a -> S.Set String
nodeIds nw = M.keysSet (nodeMap nw)

-- | Applies a function to every node. The keys of the node map are kept, so
-- a function that changes 'nodeId' leaves the map keyed by the old ids
-- ('renameNodes' rebuilds the map instead).
mapNodes :: (Node a -> Node b) -> Network a -> Network b
mapNodes f nw = nw { nodeMap = f <$> nodeMap nw }

-- | Monadic variant of 'mapNodes'. The action runs on the nodes in ascending
-- key order, and each result is stored under the key of its input node.
mapNodesM :: Monad m => (Node a -> m (Node b)) -> Network a -> m (Network b)
mapNodesM f nw = do
   updated <- mapM f (nodeMap nw)
   return nw { nodeMap = updated }

-- | Renames every node id, and every parent reference, with the given
-- function. The node map is rebuilt, so its keys follow the new ids.
renameNodes :: (String -> String) -> Network a -> Network a
renameNodes f nw = makeNetwork (name nw) [ renameNode f n | n <- nodes nw ]

-- | Keeps only the nodes that satisfy the predicate. Parent references of the
-- remaining nodes are not adjusted, so they may name removed nodes ('parents'
-- skips ids that cannot be found).
filterNodes :: (Node a -> Bool) -> Network a -> Network a
filterNodes keep nw =
   let kept = M.filter keep (nodeMap nw)
   in  nw { nodeMap = kept }

-- | Inserts a node under its 'nodeId', replacing any node already stored
-- under that key.
addNode :: Node a -> Network a -> Network a
addNode n nw = nw { nodeMap = M.insert key n (nodeMap nw) }
 where
   key = nodeId n

-- | How the (conditional) distribution of a node is specified.
data Definition
   = CPT [Probability]            -- ^ conditional probability table (standard)
   | NoisyMax [Int] [Probability] -- ^ noisy max, with strengths and network parameters
   | NoisyAdder [Int] [Double] [Probability] -- ^ noisy adder, with distinguished states, weights, and network parameters

-- | A node of a Bayesian network. Each state has a name and a value of type
-- @a@; the 'Functor' instance maps over these values.
data Node a = Node
   { nodeId     :: String         -- ^ identifier; the only field used by the 'Eq' and 'Ord' instances
   , label      :: String         -- ^ human-readable label ('node' initialises it to the id)
   , states     :: [(String, a)]  -- ^ state names with their values; a state's index is its position here
   , parentIds  :: [String]       -- ^ ids of the parent nodes
   , definition :: Definition     -- ^ how the node's distribution is specified
   }

-- | Shows the node id on the first line, followed by one indented line of the
-- form @name: value@ per state.
instance Show a => Show (Node a) where
   show n = unlines $ nodeId n : stateLines
    where
      stateLines = [ "  " ++ s ++ ": " ++ show v | (s, v) <- states n ]

-- | Two nodes are equal exactly when their ids are equal; label, states,
-- parents and definition are ignored.
instance Eq (Node a) where
   x == y = nodeId x == nodeId y

-- | Nodes are ordered by id, consistently with the 'Eq' instance.
instance Ord (Node a) where
   compare x y = nodeId x `compare` nodeId y

-- | Maps over the value attached to each state; state names and all other
-- fields are kept.
instance Functor Node where
   fmap f n = n { states = second f <$> states n }

-- | Number of states of a node.
size :: Node a -> Int
size n = length (states n)

-- | Number of states of the node with the given key, or 0 when the network
-- has no such node.
sizeId :: Network a -> String -> Int
sizeId nw s = case findNode nw s of
   Just n  -> size n
   Nothing -> 0

-- | A fresh node whose id and label are both the given string, with no
-- states, no parents and an empty conditional probability table.
node :: String -> Node a
node ident =
   let emptyTable = CPT []
   in  Node ident ident [] [] emptyTable

-- | Forgets the values attached to the states, keeping only their names.
node_ :: Node a -> Node ()
node_ n = () <$ n

-- | The parents of a node that exist in the network, in the order of
-- 'parentIds'; ids without a matching node are silently skipped.
parents :: Network a -> Node a -> [Node a]
parents nw n = [ p | Just p <- map (findNode nw) (parentIds n) ]

-- | Applies a renaming to the id of a node and to the ids of its parents.
-- The label is left unchanged.
renameNode :: (String -> String) -> Node a -> Node a
renameNode f n = n { nodeId = newId, parentIds = newParents }
 where
   newId      = f (nodeId n)
   newParents = map f (parentIds n)

-- | Looks up a node by its exact key (no normalisation; see 'findNodeFuzzy').
findNode :: Network a -> String -> Maybe (Node a)
findNode nw s = s `M.lookup` nodeMap nw

-- | Find a node based on its id, ignoring capitalisation, whitespace and
-- punctuation (as implemented by 'normalize'). A prefix of the node id
-- consisting of the network's own name plus one separator character is also
-- ignored. A node whose key matches the target exactly is always preferred
-- over a fuzzy match. Fails via 'fail' when no node resembles the target.
--
-- NOTE(review): since GHC 8.8, 'fail' is a method of 'MonadFail', not
-- 'Monad'; on such compilers this constraint has to become @MonadFail m@.
findNodeFuzzy :: Monad m => Network a -> String -> m (Node a)
findNodeFuzzy nw targetID =
   case M.lookup targetID (nodeMap nw) of
      Just exact -> return exact
      Nothing    -> maybe (fail err) return . find predicate . M.elems . nodeMap $ nw
 where
   err = "Could find no node resembling " ++ targetID ++ " in the network " ++ name nw

   netName = name nw
   target  = normalize targetID

   predicate :: Node a -> Bool
   predicate n = normalize nodeID == target || (hasOwnPrefix && normalize unprefixed == target)
    where
      nodeID       = nodeId n
      -- An empty network name is a prefix of every id; stripping it would
      -- drop a real character (so "bc" would match "abc").
      hasOwnPrefix = not (null netName) && netName `isPrefixOf` nodeID
      unprefixed   = drop (length netName + 1) nodeID



-- | The state names of a node, each paired with its 0-based index in
-- 'states'. An unknown node id yields the empty list.
findStates :: Network a -> String -> [(String, Int)]
findStates nw nodeID =
   case findNode nw nodeID of
      Nothing -> []
      Just n  -> zip (map fst (states n)) [0 ..]


-- | Transform a state index of a node to its string label. An index outside
-- the range of the node's states (including a negative index) or an unknown
-- node id yields the placeholder @"unknown state #i"@ instead of an error.
state2label :: Network a -> String -> Int -> String
state2label nw nodeID i =
   case drop i labels of
      -- drop on a negative count returns the whole list, hence the guard
      l : _ | i >= 0 -> l
      _              -> "unknown state #" ++ show i
 where
   -- state labels in index order; empty for an unknown node
   labels = maybe [] (map fst . states) (findNode nw nodeID)

-- | Transform a state label of a node to its state index (0-based, as given
-- by 'findStates'). Calls 'error' when the node is unknown or has no state
-- with the given label; the message names both the state and the node.
label2state :: Network a -> String -> String -> Int
label2state nw nodeLabel stateLabel =
   fromMaybe (error $ "no such state " ++ stateLabel ++ " in node " ++ nodeLabel)
             (lookup stateLabel (findStates nw nodeLabel))


-- | All (transitive) ancestors of a node, without duplicates. The node's own
-- parents come first, followed by the ancestors of each parent in turn.
--
-- The result is exactly @nub (ps ++ concatMap (ancestors nw) ps)@ with
-- @ps = parents nw n@, but every node is expanded at most once. A naive
-- recursion re-explores shared ancestors, which takes exponential time on
-- networks with many converging paths (e.g. a chain of diamonds). Skipping a
-- second expansion is safe: in an acyclic network the first expansion has
-- already emitted all of its nodes earlier in the output. Because each node
-- is expanded once, the traversal also terminates on cyclic networks.
ancestors :: Network a -> Node a -> [Node a]
ancestors nw n = reverse emitted
 where
   (emitted, _) = expand n ([], [n])

   -- expand x (out, done): push the not-yet-emitted parents of x onto the
   -- reversed output, then expand each parent that was not expanded before
   expand x (out, done) =
      let ps   = parents nw x
          out' = foldl' (\acc p -> if p `elem` acc then acc else p : acc) out ps
      in  foldl' visit (out', done) ps

   visit st@(out, done) p
      | p `elem` done = st
      | otherwise     = expand p (out, p : done)

-- | All (transitive) descendants of a node, i.e. every node reachable by
-- following child links, without duplicates. The order of the result is not
-- significant; most recently discovered nodes come first.
--
-- Only nodes that have not been found yet are queued, so each node is
-- expanded at most once. Re-queuing already-found nodes would redo whole
-- sub-traversals (exponential work on chains of diamonds) and would loop
-- forever on a cyclic network.
descendants :: Network a -> Node a -> [Node a]
descendants nw start = go [] [start]
 where
   -- the children of x: all nodes that list x among their parents
   childrenOf x = filter (elem x . parents nw) (nodes nw)

   go found []         = found
   go found (x : todo) =
      let new = filter (`notElem` found) (childrenOf x)
      in  go (new ++ found) (new ++ todo)