-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-- A network: a named collection of nodes, indexed by their identifier.
--
-----------------------------------------------------------------------------

module Bayes.Network
   ( Network, makeNetwork, name, nodes, nodeIds, mapNodes, mapNodesM
   , renameNodes, filterNodes, addNode
   , Node(..), parents, node, node_, size, sizeId, descendants, ancestors
   , findNode, findStates
   , Definition(..)
   , findNodeFuzzy
   , state2label
   , label2state
   ) where

import Bayes.Probability
import Control.Arrow
import Control.Monad.Fail (MonadFail)
import Data.List
import Data.Maybe
import Data.Semigroup
import qualified Data.Map as M
import qualified Data.Set as S
import Util.String ( normalize )

-- | A named network of nodes, stored in a map keyed by node id. The type
-- parameter is the value attached to every state of a node.
data Network a = Network
   { name    :: String
   , nodeMap :: M.Map String (Node a)
   }
 deriving (Eq, Ord)

-- | A header line with the network's name (filled up with @=@ and cut to
-- exactly 60 characters), followed by every node.
instance Show a => Show (Network a) where
   show nw = unlines $ take 60 ("== " ++ name nw ++ repeat '=') : map show (nodes nw)

-- | Combine two networks. The names are joined with a @/@ (an empty name is
-- dropped). When both networks contain the same key, the node of the left
-- network is kept; this is only allowed when the right node has no parents,
-- otherwise an error is raised.
instance Semigroup (Network a) where
   nw1 <> nw2 = Network (name nw1 `makeName` name nw2)
      (M.unionWith overwriteNode (nodeMap nw1) (nodeMap nw2))
    where
      makeName x y
         | null x    = y
         | null y    = x
         | otherwise = x ++ "/" ++ y
      -- nodes can only be merged when the second node has no parents (but only priors)
      overwriteNode n1 n2
         | null (parentIds n2) = n1
         | otherwise = error $ "Node " ++ nodeId n1 ++ " cannot be merged (try changing the order)"

instance Monoid (Network a) where
   mempty  = Network "" M.empty
   mappend = (<>)

instance Functor Network where
   fmap f nw = nw { nodeMap = M.map (fmap f) (nodeMap nw) }

-- | Build a network from a name and a list of nodes, keyed by 'nodeId'. When
-- several nodes share an id, the last one in the list is kept.
makeNetwork :: String -> [Node a] -> Network a
makeNetwork s ns = Network s (M.fromList [ (nodeId n, n) | n <- ns ])

-- | All nodes of the network, in ascending key order.
nodes :: Network a -> [Node a]
nodes = M.elems . nodeMap

-- | The set of keys (node ids) of the network.
nodeIds :: Network a -> S.Set String
nodeIds = M.keysSet . nodeMap

-- | Apply a function to every node. The keys of the network are not changed,
-- even if the function changes a node's id.
mapNodes :: (Node a -> Node b) -> Network a -> Network b
mapNodes f nw = nw {nodeMap = M.map f (nodeMap nw)}

-- | Monadic variant of 'mapNodes': the function is run on the nodes in
-- ascending key order, and the keys are kept unchanged.
mapNodesM :: Monad m => (Node a -> m (Node b)) -> Network a -> m (Network b)
mapNodesM f nw = do
   let m = nodeMap nw
   xs <- mapM f (M.elems m)
   return nw {nodeMap = M.fromList $ zip (M.keys m) xs }

-- | Rename every node id and every parent reference with the given function.
-- The map is rebuilt so that the keys match the new ids.
renameNodes :: (String -> String) -> Network a -> Network a
renameNodes f nw = makeNetwork (name nw) (map (renameNode f) (nodes nw))

-- | Keep only the nodes satisfying the predicate. Parent references to
-- removed nodes are left untouched.
filterNodes :: (Node a -> Bool) -> Network a -> Network a
filterNodes p nw = nw {nodeMap = M.filter p (nodeMap nw)}

-- | Insert a node under its 'nodeId', replacing any node with the same key.
addNode :: Node a -> Network a -> Network a
addNode n nw = nw { nodeMap = M.insert (nodeId n) n (nodeMap nw) }

-- | How the (conditional) probabilities of a node are specified.
data Definition
   = CPT [Probability]                       -- conditional probability table (standard)
   | NoisyMax [Int] [Probability]            -- noisy max, with strengths and network parameters
   | NoisyAdder [Int] [Double] [Probability] -- noisy adder, with distinguished states, weights, and network parameters

-- | A node of the network. Nodes are compared ('Eq', 'Ord') by 'nodeId' only.
data Node a = Node
   { nodeId     :: String
   , label      :: String
   , states     :: [(String, a)]  -- state labels, each with an associated value
   , parentIds  :: [String]       -- ids of the parent nodes
   , definition :: Definition
   }

-- | The node id, followed by one indented line per state.
instance Show a => Show (Node a) where
   show n = unlines $ nodeId n : map f (states n)
    where
      f (s, a) = " " ++ s ++ ": " ++ show a

instance Eq (Node a) where
   n1 == n2 = nodeId n1 == nodeId n2

instance Ord (Node a) where
   compare n1 n2 = compare (nodeId n1) (nodeId n2)

instance Functor Node where
   fmap f n = n { states = map (second f) (states n) }

-- | Number of states of a node.
size :: Node a -> Int
size = length . states

-- | Number of states of the node with the given key, or 0 if there is no
-- such node.
sizeId :: Network a -> String -> Int
sizeId nw s = maybe 0 size (findNode nw s)

-- | A node without states or parents, whose id and label are both the given
-- string, with an empty 'CPT' definition.
node :: String -> Node a
node s = Node s s [] [] (CPT [])

-- | Forget the values attached to the states of a node.
node_ :: Node a -> Node ()
node_ = fmap (const ())

-- | The parents of a node that are present in the network. Parent ids without
-- a matching node are skipped.
parents :: Network a -> Node a -> [Node a]
parents nw n = mapMaybe (findNode nw) (parentIds n)

-- Rename a node's id and its parent references (map keys are not touched here).
renameNode :: (String -> String) -> Node a -> Node a
renameNode f n = n
   { nodeId    = f (nodeId n)
   , parentIds = map f (parentIds n)
   }

-- | Look up a node by its key.
findNode :: Network a -> String -> Maybe (Node a)
findNode nw s = M.lookup s (nodeMap nw)

-- | Find a node based on its name, ignoring capitalisation, whitespace and
-- interpunction (as done by 'normalize'). A node id that starts with the
-- network's own name plus one separator character also matches on the
-- remainder. When several nodes match, the first in ascending key order is
-- returned; when none match, the computation fails in @m@.
findNodeFuzzy :: MonadFail m => Network a -> String -> m (Node a)
findNodeFuzzy nw targetID =
   maybe (fail err) return . find predicate . M.elems . nodeMap $ nw
 where
   err = "Could find no node resembling " ++ targetID ++ " in the network " ++ name nw

   predicate :: Node a -> Bool
   predicate n =
      let nodeID = nodeId n
      in normalize targetID == normalize nodeID
         || ((name nw `isPrefixOf` nodeID)
             && normalize (drop (length (name nw) + 1) nodeID) == normalize targetID)

-- | Find all states associated with a certain node, paired with their
-- zero-based index. Returns an empty list when the node does not exist.
findStates :: Network a -> String -> [(String, Int)]
findStates nw = flip zip [0..] . map fst . maybe [] states . findNode nw

-- | Transform a state index of a node to its string label. Any index that
-- does not denote a state (negative, too large, or an unknown node) yields
-- the text @"unknown state #i"@ instead of failing.
state2label :: Network a -> String -> Int -> String
state2label nw nodeID i
   -- guard on i >= 0: a negative index used to reach (!!) and crash
   | i >= 0, l:_ <- drop i labels = l
   | otherwise                     = "unknown state #" ++ show i
 where
   labels = map fst (maybe [] states (findNode nw nodeID))

-- | Transform the label of a node's state to its state index.
--
-- @nodeLabel@ is looked up as a node key. Raises an error when the node has
-- no state with the given label, which includes the case where the node
-- does not exist.
label2state :: Network a -> String -> String -> Int
label2state nw nodeLabel stateLabel =
   fromMaybe (error $ "no such state " ++ stateLabel) $
      lookup stateLabel (findStates nw nodeLabel)

-- | All ancestors of a node (parents, their parents, and so on), without
-- duplicates. The direct parents come first, followed by the ancestors of
-- each parent in turn. Parent ids that do not occur in the network are
-- ignored.
--
-- The ancestor list of each network node is computed at most once per call,
-- shared through a lazily built table. Nodes reachable along many paths
-- therefore do not cause exponential recomputation. The parent relation is
-- expected to be acyclic; a cycle makes this function diverge.
ancestors :: Network a -> Node a -> [Node a]
ancestors nw = collect
 where
   -- every node of the network, paired with its (lazily computed) ancestors
   table = M.map (\m -> (m, collect m)) (nodeMap nw)

   collect n =
      let found = mapMaybe (`M.lookup` table) (parentIds n)
      in nub (map fst found ++ concatMap snd found)

-- | All descendants of a node (children, their children, and so on). A child
-- of @x@ is any node of the network that has @x@ among its parents.
--
-- NOTE(review): there is no visited set. A node reachable along several paths
-- is expanded repeatedly, and a cycle in the parent relation that is
-- reachable from the start node makes this loop forever.
descendants :: Network a -> Node a -> [Node a]
descendants nw = rec [] . return
 where
   rec acc [] = acc
   rec acc (x:xs) =
      let ys = filter (elem x . parents nw) (nodes nw)
      in rec (ys `union` acc) (ys `union` xs)