module Bayes.FactorElimination(
moralGraph
, nodeComparisonForTriangulation
, numberOfAddedEdges
, triangulate
, minimumSpanningTree
, createClusterGraph
, Cluster
, createJunctionTree
, JunctionTree
, collect
, distribute
, posterior
, clearEvidence
, updateEvidence
, junctionTreeProperty_prop
, createVerticesJunctionTree
, VertexCluster
) where
import Bayes
import qualified Data.Foldable as F
import Data.Maybe(fromJust,mapMaybe,isJust)
import Control.Monad(mapM)
import Bayes.Factor hiding (isEmpty)
import Data.Function(on)
import Data.List(minimumBy,maximumBy,inits)
import qualified Data.Set as Set
import qualified Data.Map as Map
import qualified Data.Functor as Functor
import qualified Data.Tree as T
import Test.QuickCheck hiding ((.||.), collect)
import Test.QuickCheck.Arbitrary
-- | Number of links missing between the neighbours of a vertex.
--
-- Every ordered pair of distinct neighbours of @v@ that is not linked in
-- @g@ contributes one to the result, so a missing link between @x@ and @y@
-- is seen both as @(x,y)@ and as @(y,x)@.
-- Fails (through 'fromJust') when 'neighbors' returns 'Nothing' for @v@.
numberOfAddedEdges :: UndirectedGraph g
                   => g a b
                   -> Vertex
                   -> Int
numberOfAddedEdges g v = length missingLinks
  where
    adjacent = fromJust (neighbors g v)
    missingLinks =
      [ edge p q
      | p <- adjacent
      , q <- adjacent
      , p /= q
      , not (isLinkedWithAnEdge g p q)
      ]
-- | 'factorDimension' of the factor stored at vertex @v@ of @g@.
--
-- Fails (through 'fromJust') when the vertex carries no value.
weight :: (UndirectedGraph g, Factor f)
       => g a f
       -> Vertex
       -> Int
weight g v = factorDimension (fromJust (vertexValue g v))
-- | Lexicographic combination of two comparison functions: the second
-- comparison is consulted only when the first one answers 'EQ'.
(.||.) :: (a -> a -> Ordering)
       -> (a -> a -> Ordering)
       -> (a -> a -> Ordering)
(f .||. g) a b
  | primary == EQ = g a b
  | otherwise = primary
  where
    primary = f a b
-- | Ordering on vertices used to choose the next vertex during
-- triangulation: compare by 'numberOfAddedEdges' first and break ties
-- with 'weight'.
nodeComparisonForTriangulation :: (UndirectedGraph g, Factor f)
                               => g a f
                               -> Vertex
                               -> Vertex
                               -> Ordering
nodeComparisonForTriangulation g = byFillIn .||. byWeight
  where
    byFillIn = compare `on` numberOfAddedEdges g
    byWeight = compare `on` weight g
-- | A cluster described by the set of graph vertices it contains.
newtype VertexCluster = VertexCluster (Set.Set Vertex) deriving (Eq)

-- | Unwrap the underlying vertex set.
fromVertexCluster :: VertexCluster -> Set.Set Vertex
fromVertexCluster (VertexCluster vertices) = vertices

-- | A cluster is displayed as the ascending list of its vertices.
instance Show VertexCluster where
  show = show . Set.toList . fromVertexCluster
-- | Triangulate a graph by vertex elimination.
--
-- While the working graph still has vertices, the minimum vertex under
-- @cmp@ is selected; it and its neighbours are pairwise connected in both
-- the working graph and the result graph, that vertex set is recorded as a
-- cluster, and the selected vertex is removed from the working graph.
--
-- Returns the recorded clusters (in elimination order, filtered by
-- 'keepMaximalClusters') together with the graph holding the added edges.
triangulate :: Graph g
            => (Vertex -> Vertex -> Ordering)
            -> g () b
            -> ([VertexCluster], g () b)
triangulate cmp g = eliminate g g []
  where
    fillIn e gr = addEdge e () gr

    -- 'found' accumulates clusters most-recent-first.
    eliminate remaining filled found
      | hasNoVertices remaining = (keepMaximalClusters (reverse found), filled)
      | otherwise =
          let pivot = minimumBy cmp (allVertices remaining)
              clique = pivot : fromJust (neighbors remaining pivot)
              (remaining', filled') =
                connectAllNodesWith remaining filled fillIn fillIn clique
          in eliminate
               (removeVertex pivot remaining')
               filled'
               (VertexCluster (Set.fromList clique) : found)
-- | Look for the first cluster of the list that contains @cluster@.
--
-- When one is found it is returned together with the list from which it
-- has been removed (order of the other elements preserved); otherwise the
-- result is 'Nothing' and the untouched list.
findContainingCluster :: VertexCluster
                      -> [VertexCluster]
                      -> (Maybe VertexCluster, [VertexCluster])
findContainingCluster cluster l =
  case break containsCluster l of
    (_, []) -> (Nothing, l)
    (before, found : after) -> (Just found, before ++ after)
  where
    containsCluster candidate =
      fromVertexCluster cluster `Set.isSubsetOf` fromVertexCluster candidate
-- | Drop every cluster that is included in a cluster appearing earlier in
-- the list.
--
-- When a cluster is absorbed by an earlier one, the absorbing cluster is
-- moved to the most recent position among the kept clusters.
keepMaximalClusters :: [VertexCluster] -> [VertexCluster]
keepMaximalClusters = reverse . F.foldl' absorb []
  where
    -- 'keptRev' holds the clusters kept so far, most recent first.
    absorb keptRev current =
      case findContainingCluster current (reverse keptRev) of
        (Nothing, _) -> current : keptRev
        (Just container, others) -> container : reverse others
-- | Build the cluster graph of a list of clusters.
--
-- Clusters are numbered from 0 and become the vertex values; every ordered
-- pair of distinct clusters is linked by an edge whose value is the size of
-- the intersection of the two clusters.
createClusterGraph :: UndirectedGraph g
                   => [VertexCluster]
                   -> g Int VertexCluster
createClusterGraph c = foldr linkClusters isolatedClusters clusterPairs
  where
    labelled = zip c (map Vertex [0 ..])
    isolatedClusters =
      foldr (\(cl, v) acc -> addVertex v cl acc) emptyGraph labelled
    sharedCount ca cb =
      Set.size (fromVertexCluster ca `Set.intersection` fromVertexCluster cb)
    clusterPairs = [(p, q) | p <- labelled, q <- labelled, p /= q]
    linkClusters ((ca, va), (cb, vb)) acc =
      addEdge (edge va vb) (sharedCount ca cb) acc
-- | Tree with values of type @a@ on the nodes and labels of type @b@ on the
-- edges: each child is stored as an (edge label, subtree) pair.
data Tree b a = Node a [(b,Tree b a)] deriving(Eq)
-- | Convert a junction tree into a "Data.Tree" of strings where every node
-- is labelled with the 'show' of its cluster; edge labels are dropped.
standardHaskellTree :: (Show f, Show b) => Tree b (JTNodeValue f) -> T.Tree String
standardHaskellTree n@(Node _ l) =
  T.Node (show (nodeCluster n)) [standardHaskellTree sub | (_, sub) <- l]
-- | Convert a tree of vertex clusters into a "Data.Tree" of strings where
-- every node is labelled with the 'show' of its cluster.
standardVertexTree :: Tree () VertexCluster -> T.Tree String
standardVertexTree (Node a l) =
  T.Node (show a) [standardVertexTree sub | (_, sub) <- l]
-- | String transformer that appends the textual dump of a junction tree.
--
-- The factor of the node is appended first; then, going through the
-- children from the last to the first, the edge label is appended followed
-- by the dump of the corresponding subtree.
showFactorsAndEdges :: (Show f, Show b) => Tree b (JTNodeValue f) -> (String -> String)
showFactorsAndEdges (Node a l) =
  foldr (.) id (map renderChild l) . (++ show (nodeValueFactor a))
  where
    renderChild (s, t) = showFactorsAndEdges t . (++ show s)
-- | A junction tree is displayed as the drawing of its clusters followed
-- by the dump of its factors and edge labels.
instance (Show f, Show b) => Show (Tree b (JTNodeValue f)) where
  show t =
    concat
      [ "JUNCTION TREE\n"
      , T.drawTree (standardHaskellTree t)
      , "\n"
      , showFactorsAndEdges t ""
      , "\n------\n"
      ]
-- | A tree of vertex clusters is displayed as the drawing of its clusters.
instance Show (Tree () VertexCluster) where
  show t = concat ["JUNCTION TREE\n", T.drawTree (standardVertexTree t), "\n"]
-- | Mapping changes the node values and leaves the edge labels untouched.
instance Functor.Functor (Tree b) where
  fmap f (Node a l) = Node (f a) [(e, fmap f sub) | (e, sub) <- l]
-- | Grow a spanning tree one vertex at a time.
--
-- At each step, among the linked pairs made of a vertex already in the tree
-- and a remaining vertex, the pair whose edge has the greatest value is
-- selected; the corresponding edge is prepended to the accumulated edges
-- and the new vertex joins the tree. The accumulated edges are returned
-- once no vertex remains.
--
-- Fails when a remaining vertex cannot be linked to the tree ('maximumBy'
-- on an empty list) or when a selected edge has no value ('fromJust').
expand :: UndirectedGraph g
       => g Int f
       -> [Edge]
       -> [Vertex]
       -> [Vertex]
       -> [Edge]
expand g theEdges inTree remaining
  | null remaining = theEdges
  | otherwise =
      expand
        g
        (edge treeVertex outVertex : theEdges)
        (outVertex : inTree)
        (filter (/= outVertex) remaining)
  where
    linkValue (va, vb) = fromJust (edgeValue g (edge va vb))
    candidates =
      [ (vin, vout)
      | vin <- inTree
      , vout <- remaining
      , isLinkedWithAnEdge g vin vout
      ]
    (treeVertex, outVertex) = maximumBy (compare `on` linkValue) candidates
-- | A tree reduced to a single node without children.
leaf :: a -> Tree b a
leaf value = Node value []
-- | Pair an edge label with a subtree (plain tuple construction).
treeEdge :: a -> b -> (a, b)
treeEdge label subtree = (label, subtree)
-- | Build a tree of vertices from a map giving the children of each vertex.
--
-- The tree is rooted at @root@; a vertex that is not a key of the map is a
-- leaf. All edges carry @()@.
createTreeFromMap :: Vertex
                  -> Map.Map Vertex [Vertex]
                  -> Tree () Vertex
createTreeFromMap root m = grow root
  where
    grow v =
      Node v [treeEdge () (grow child) | child <- Map.findWithDefault [] v m]
-- | Spanning tree of a graph whose edges carry 'Int' values, returned as a
-- tree of the vertex values.
--
-- The edges are selected by 'expand', which always takes the candidate edge
-- with the greatest value, so in spite of its name this function favours
-- heavy edges. The root is the vertex returned by 'someVertex'.
--
-- Fails (through 'fromJust') on a graph without vertex.
minimumSpanningTree :: UndirectedGraph g
                    => g Int f
                    -> Tree () f
minimumSpanningTree g = Functor.fmap (fromJust . vertexValue g) skeleton
  where
    root = fromJust (someVertex g)
    others = filter (/= root) (allVertices g)
    chosenEdges = expand g [] [root] others
    asEntry (from, to) = (from, [to])
    childrenOf = Map.fromListWith (++) (map (asEntry . edgeEndPoints) chosenEdges)
    skeleton = createTreeFromMap root childrenOf
-- | True when the vertices of all the variables of the factor belong to
-- the cluster.
vertexClusterIsContainingFactor :: Factor f => VertexCluster -> f -> Bool
vertexClusterIsContainingFactor c f =
  Set.fromList (map variableVertex (factorVariables f))
    `Set.isSubsetOf` fromVertexCluster c
-- | True when the variable is a member of the cluster.
clusterIsContainingVariable :: DV -> Cluster -> Bool
clusterIsContainingVariable v (Cluster members) = Set.member v members
-- | Label of a junction tree edge. It always carries the separator cluster
-- and, depending on the constructor, the messages stored on the edge.
data Separator f
  = NoMessage !Cluster
    -- ^ No message stored yet.
  | Collect !Cluster !f
    -- ^ One message: the one read back by 'upMessage'.
  | Distribute !Cluster !f !f
    -- ^ Two messages: the one read by 'upMessage' then the one read by
    -- 'downMessage'.
  deriving (Eq)

instance Show f => Show (Separator f) where
  show sep = case sep of
    NoMessage c ->
      "NoMessage: " ++ show c
    Collect c u ->
      concat ["Collect: ", show c, "\n", "\n <----- \n", show u, "\n"]
    Distribute c u d ->
      concat
        [ "Distribute: ", show c, "\n <----- \n", show u, "\n"
        , " -----> \n", show d, "\n"
        ]
-- | Evidence is represented with the same type as the factors.
type Evidence f = f

-- | Value of a junction tree node: its cluster, the evidence attached to
-- the node and the factor of the node.
data JTNodeValue f = JTNodeValue !Cluster !(Evidence f) !f
  deriving (Eq, Show)

-- | A cluster described by the set of discrete variables it contains.
newtype Cluster = Cluster (Set.Set DV)
  deriving (Eq, Show)

-- | Variables of a cluster as an ascending list.
fromCluster :: Cluster -> [DV]
fromCluster (Cluster variables) = Set.toList variables
-- | Translate a cluster of vertices into a cluster of variables.
--
-- Each vertex is replaced by the 'factorMainVariable' of the factor stored
-- at that vertex in @g@; vertices without a value are ignored.
vertexClusterToCluster :: (Factor f, Graph g)
                       => g e f
                       -> VertexCluster
                       -> Cluster
vertexClusterToCluster g c =
  Cluster
    ( Set.fromList
        [ factorMainVariable f
        | Just f <- map (vertexValue g) (Set.toList (fromVertexCluster c))
        ]
    )
-- | Vertices of a cluster as an ascending list.
clusterVertices :: VertexCluster -> [Vertex]
clusterVertices c = Set.toList (fromVertexCluster c)
-- | Factors of the network that can be hosted by a cluster: the factors
-- stored at the vertices of the cluster whose variables are all covered by
-- the cluster.
findFactorsForCluster :: (Factor f, Graph g)
                      => BayesianNetwork g f
                      -> VertexCluster
                      -> [f]
findFactorsForCluster g c =
  [ f
  | v <- clusterVertices c
  , Just f <- [vertexValue g v]
  , vertexClusterIsContainingFactor c f
  ]
-- | A junction tree: the nodes hold a 'JTNodeValue' and the edges hold a
-- 'Separator'.
type JunctionTree f = Tree (Separator f) (JTNodeValue f)
-- | Build the value of a junction tree node for a cluster.
--
-- @set@ contains the vertices whose factor has not been assigned to a node
-- yet. Among the factors the cluster can host, those whose main variable
-- vertex is still in @set@ are multiplied together to form the node
-- factor. The evidence is initialised to the scalar factor 1.0.
--
-- Returns the node value and @set@ without the vertices just consumed.
mkNodePotential :: (Graph g, Factor f, Show f)
                => BayesianNetwork g f
                -> VertexCluster
                -> Set.Set Vertex
                -> (JTNodeValue f, Set.Set Vertex)
mkNodePotential g c set =
  (JTNodeValue (vertexClusterToCluster g c) (factorFromScalar 1.0) potential, remaining)
  where
    candidates =
      [ (variableVertex (factorMainVariable f), f)
      | f <- findFactorsForCluster g c
      ]
    stillFree = filter ((`Set.member` set) . fst) candidates
    remaining = Set.difference set (Set.fromList (map fst stillFree))
    potential = factorProduct (map snd stillFree)
-- | Evidence restricted to a cluster: the assignments whose variable
-- belongs to the cluster are kept and handed to 'evidenceFrom'.
evidenceForCluster :: Factor f => DVISet Int -> Cluster -> Maybe (Evidence f)
evidenceForCluster assignments (Cluster c) = evidenceFrom relevant
  where
    assigned = Set.fromList (map instantiationVariable assignments)
    shared = Set.intersection assigned c
    relevant =
      filter ((`Set.member` shared) . instantiationVariable) assignments
-- | Cluster stored in the root node of a (sub)tree.
nodeCluster :: Tree a (JTNodeValue f) -> Cluster
nodeCluster (Node value _) =
  case value of
    JTNodeValue cluster _ _ -> cluster
-- | Cluster without any variable.
emptyCluster :: Cluster
emptyCluster = Cluster (Set.fromList [])
-- | Factor of a node value (third field).
nodeValueFactor :: JTNodeValue f -> f
nodeValueFactor (JTNodeValue _ _ factor) = factor

-- | Evidence of a node value (second field).
nodeValueEvidence :: JTNodeValue f -> Evidence f
nodeValueEvidence (JTNodeValue _ evidence _) = evidence

-- | Replace the evidence of a node value, keeping cluster and factor.
nodeValueWithNewEvidence :: JTNodeValue f -> Evidence f -> JTNodeValue f
nodeValueWithNewEvidence (JTNodeValue cluster _ factor) evidence =
  JTNodeValue cluster evidence factor

-- | Reset the evidence of a node value to the scalar factor 1.0.
clearNodeValueEvidence :: Factor f => JTNodeValue f -> JTNodeValue f
clearNodeValueEvidence (JTNodeValue cluster _ factor) =
  JTNodeValue cluster (factorFromScalar 1.0) factor
-- | Cluster of a separator, whatever messages it currently holds.
separatorCluster :: Separator f -> Cluster
separatorCluster sep = case sep of
  NoMessage c -> c
  Collect c _ -> c
  Distribute c _ _ -> c
-- | First message stored on a separator; 'Nothing' for 'NoMessage'.
upMessage :: Separator f -> Maybe f
upMessage sep = case sep of
  Distribute _ u _ -> Just u
  Collect _ u -> Just u
  _ -> Nothing

-- | Second message stored on a separator; only 'Distribute' has one.
downMessage :: Separator f -> Maybe f
downMessage sep = case sep of
  Distribute _ _ d -> Just d
  _ -> Nothing
-- | Separator between two clusters: the vertices common to both clusters,
-- translated into variables with 'vertexClusterToCluster'.
computeSeparatorCluster :: (Factor f, Graph g)
                        => BayesianNetwork g f
                        -> VertexCluster
                        -> VertexCluster
                        -> Cluster
computeSeparatorCluster g parent child =
  vertexClusterToCluster g (VertexCluster shared)
  where
    shared = fromVertexCluster child `Set.intersection` fromVertexCluster parent
-- | Depth-first traversal rebuilding a tree while threading a state.
--
-- * @edgef parent child label@ computes the new label of an edge from the
--   original values of its two end nodes and its original label.
-- * @nodef value state@ computes the new value of a node and the next state.
--
-- A node is processed before its children and the children are processed
-- from left to right, each one receiving the state left by the previous
-- one. The final state is returned with the new tree.
dfs :: (n -> n -> e -> e')
    -> (n -> a -> (n', a))
    -> Tree e n
    -> a
    -> (Tree e' n', a)
dfs edgef nodef (Node nodevalue children) current =
  (Node newnodevalue (zip newEdges newSubTrees), finalState)
  where
    (newnodevalue, afterNode) = nodef nodevalue current
    newEdges = [edgef nodevalue childvalue e | (e, Node childvalue _) <- children]
    (finalState, newSubTrees) = visit afterNode (map snd children)

    visit st [] = (st, [])
    visit st (sub : rest) =
      let (sub', st') = dfs edgef nodef sub st
          (stEnd, rest') = visit st' rest
      in (stEnd, sub' : rest')
-- | Edge function used when initialising a junction tree: the unit label is
-- replaced by a message-free separator computed from the two clusters.
setFactorEdgeUpdate :: (Graph g, Factor f)
                    => BayesianNetwork g f
                    -> VertexCluster
                    -> VertexCluster
                    -> ()
                    -> Separator f
setFactorEdgeUpdate g parentvalue childvalue _ =
  NoMessage (computeSeparatorCluster g parentvalue childvalue)
-- | Node function used when initialising a junction tree; it is exactly
-- 'mkNodePotential'.
setFactorNodeUpdate :: (Graph g, Factor f, Show f)
                    => BayesianNetwork g f
                    -> VertexCluster
                    -> Set.Set Vertex
                    -> (JTNodeValue f, Set.Set Vertex)
setFactorNodeUpdate = mkNodePotential
-- | Turn a tree of vertex clusters into a junction tree by traversing it
-- with 'dfs'. The threaded set holds the vertices whose factor is not yet
-- assigned; what is left of it is returned with the tree.
setFactors :: (Graph g, Factor f, Show f)
           => BayesianNetwork g f
           -> Tree () VertexCluster
           -> Set.Set Vertex
           -> (JunctionTree f, Set.Set Vertex)
setFactors g clusterTree unassigned =
  dfs (setFactorEdgeUpdate g) (setFactorNodeUpdate g) clusterTree unassigned
-- | Build the tree of vertex clusters of a directed graph: moralisation,
-- triangulation driven by @cmp@ (which receives the moral graph), cluster
-- graph construction and finally spanning tree extraction.
createVerticesJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g)
                           => (UndirectedSG () b -> Vertex -> Vertex -> Ordering)
                           -> g () b
                           -> Tree () VertexCluster
createVerticesJunctionTree cmp g = minimumSpanningTree clusterGraph
  where
    moralized = moralGraph g
    clusters = fst (triangulate (cmp moralized) moralized)
    clusterGraph = createClusterGraph clusters :: UndirectedSG Int VertexCluster
-- | Build the junction tree of a Bayesian network.
--
-- The tree of vertex clusters is computed first, the factors of the network
-- are then assigned to the nodes, and a collect pass followed by a
-- distribute pass fills the separators with messages.
createJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g, Factor f, Show f)
                   => (UndirectedSG () f -> Vertex -> Vertex -> Ordering)
                   -> BayesianNetwork g f
                   -> JunctionTree f
createJunctionTree cmp g = distribute Nothing (collect initialised)
  where
    skeleton = createVerticesJunctionTree cmp g
    unassigned = Set.fromList (allVertices g)
    initialised = fst (setFactors g skeleton unassigned)
-- | Collect pass on a subtree reached through the given separator.
--
-- The children are processed first. The node factor, the node evidence and
-- the messages sent up by the children are multiplied, and the product is
-- projected onto the separator cluster; that projection is stored in a
-- 'Collect' separator returned with the updated subtree.
collectMessages :: Factor f => (Separator f, JunctionTree f) -> (Separator f, JunctionTree f)
collectMessages (separator, Node nc children) =
  (Collect sc upward, Node nc children')
  where
    sc = separatorCluster separator
    children' = map collectMessages children
    -- The order of the factors handed to 'factorProduct' differs between
    -- leaves and inner nodes; it is kept as it is on purpose.
    localPotential = case children of
      [] -> factorProduct [nodeValueFactor nc, nodeValueEvidence nc]
      _ ->
        factorProduct
          ( nodeValueEvidence nc
              : nodeValueFactor nc
              : mapMaybe (upMessage . fst) children'
          )
    upward = factorProjectTo (fromCluster sc) localPotential
-- | Run the collect pass on a whole junction tree. The root is entered
-- through a message-free separator over the empty cluster, which is
-- discarded afterwards.
collect :: Factor f => JunctionTree f -> JunctionTree f
collect t = snd (collectMessages (NoMessage emptyCluster, t))
-- | True when the root clusters of the two (sub)trees differ.
notSameCluster :: Tree a (JTNodeValue f) -> Tree b (JTNodeValue g) -> Bool
notSameCluster left right = not (nodeCluster left == nodeCluster right)
-- | Distribute pass on a (sub)tree.
--
-- @down@ is the separator through which the node was reached ('Nothing'
-- for the root). For each child, the node factor, the node evidence, the
-- message received from above (scalar 1.0 for the root) and the up
-- messages of the children whose cluster differs from that child's are
-- multiplied and projected onto the separator cluster. The result is
-- stored as the down message of a 'Distribute' separator, and the pass
-- continues in the child.
--
-- Every edge must hold a 'Collect' separator, otherwise 'error' is called.
distribute :: Factor f => Maybe (Separator f) -> JunctionTree f -> JunctionTree f
distribute down (Node nc l) = Node nc (map sendDown l)
  where
    fromParent = maybe (factorFromScalar 1.0) (fromJust . downMessage) down

    siblingMessages target =
      fromJust
        (mapM (upMessage . fst) [s | s@(_, sub) <- l, target `notSameCluster` sub])

    messageFor sc target =
      factorProjectTo
        (fromCluster sc)
        ( factorProduct
            ( nodeValueFactor nc
                : nodeValueEvidence nc
                : fromParent
                : siblingMessages target
            )
        )

    sendDown (Collect sc um, sub) =
      let newSeparator = Distribute sc um (messageFor sc sub)
      in (newSeparator, distribute (Just newSeparator) sub)
    sendDown _ = error "Distribute message can only update a collect phase message"
-- | First subtree, in pre-order, satisfying a predicate.
--
-- The subtree is returned with the label of the edge through which it was
-- reached; for the root that label is the @e@ given by the caller.
findInTree :: (Tree edge a -> Bool) -> Maybe edge -> Tree edge a -> Maybe (Maybe edge, Tree edge a)
findInTree cmp e n@(Node _ l)
  | cmp n = Just (e, n)
  | otherwise = foldr firstHit Nothing l
  where
    -- 'fallback' is only evaluated when the current child has no match.
    firstHit (e', sub) fallback =
      case findInTree cmp (Just e') sub of
        Nothing -> fallback
        found -> found
-- | Posterior of a variable read from a junction tree.
--
-- The first node whose cluster contains the variable is used: its factor,
-- its evidence, the message received from above (scalar 1.0 when there is
-- none) and the up messages of its children are multiplied, projected onto
-- the variable and normalised with 'normedFactor'.
--
-- 'Nothing' when no cluster contains the variable.
posterior :: Factor f => JunctionTree f -> DV -> Maybe f
posterior t v =
  fmap marginal (findInTree (clusterIsContainingVariable v . nodeCluster) Nothing t)
  where
    marginal (maybeEdge, Node n l) = normedFactor (factorProjectTo [v] belief)
      where
        fromParent = maybe (factorFromScalar 1.0) id (maybeEdge >>= downMessage)
        fromChildren = fromJust (mapM (upMessage . fst) l)
        belief =
          factorProduct
            (fromParent : nodeValueEvidence n : nodeValueFactor n : fromChildren)
-- | Apply a transformation to every node of a junction tree, from the root
-- down. For a node with children, the transformation is applied to the
-- node and the traversal continues in the children of the transformed
-- node; a node without children is simply transformed.
applyEvidenceWith :: (JunctionTree f -> JunctionTree f)
                  -> JunctionTree f
                  -> JunctionTree f
applyEvidenceWith nodeChange n@(Node _ l) =
  case (l, nodeChange n) of
    ([], changed) -> changed
    (_, Node n' l') ->
      Node n' [(e, applyEvidenceWith nodeChange c) | (e, c) <- l']
-- | Set the evidence of the root node of a (sub)tree from a list of
-- assignments. When 'evidenceForCluster' gives 'Nothing' the node value is
-- left as it is.
evidenceWith :: Factor f => DVISet Int -> JunctionTree f -> JunctionTree f
evidenceWith assignments t@(Node n l) = Node updated l
  where
    updated =
      maybe n (nodeValueWithNewEvidence n) (evidenceForCluster assignments (nodeCluster t))
-- | Reset the evidence of the root node of a (sub)tree.
clearNodeEvidence :: Factor f => Tree e (JTNodeValue f) -> Tree e (JTNodeValue f)
clearNodeEvidence (Node value subTrees) = Node (clearNodeValueEvidence value) subTrees
-- | Remove the evidence from every node and recompute the messages with a
-- collect pass followed by a distribute pass.
clearEvidence :: Factor f => JunctionTree f -> JunctionTree f
clearEvidence t = distribute Nothing (collect (applyEvidenceWith clearNodeEvidence t))
-- | Apply 'evidenceWith' with the given assignments to every node and
-- recompute the messages with a collect pass followed by a distribute pass.
updateEvidence :: Factor f => DVISet Int -> JunctionTree f -> JunctionTree f
updateEvidence assignments t =
  distribute Nothing (collect (applyEvidenceWith (evidenceWith assignments) t))
-- | Check the running intersection property along every root-to-node path.
--
-- @path@ lists the clusters of the ancestors of the current node, nearest
-- ancestor first. For each ancestor, what the node shares with it must be
-- included in every cluster lying between the node and that ancestor.
--
-- Every node is checked, leaves included: a leaf can violate the property
-- as well as an inner node. On a leaf the check on the children is
-- vacuously true.
junctionTreeProperty :: [VertexCluster] -> Tree () VertexCluster -> Bool
junctionTreeProperty path (Node c l) =
  -- The path is handed nearest-ancestor-first: 'checkPath' compares the
  -- intersection with the k-th cluster against the first k clusters of its
  -- list, and those have to be the clusters between the node and the
  -- ancestor, not the ones between the root and the ancestor.
  checkPath c path && all (junctionTreeProperty (c : path) . snd) l
-- | QuickCheck property: for a non-empty, connected directed graph with at
-- least one edge, the tree of vertex clusters built by ordering the
-- vertices on 'numberOfAddedEdges' satisfies 'junctionTreeProperty'.
junctionTreeProperty_prop :: DirectedSG () String -> Property
junctionTreeProperty_prop g =
  precondition ==> junctionTreeProperty [] (createVerticesJunctionTree byFillIn g)
  where
    precondition = not (isEmpty g) && not (hasNoEdges g) && connectedGraph g
    byFillIn ug = compare `on` numberOfAddedEdges ug
-- | For each position @k@ of the list, the intersection of @c@ with the
-- @k@-th cluster must be included in each of the first @k@ clusters of the
-- list. True for an empty list.
checkPath :: VertexCluster -> [VertexCluster] -> Bool
checkPath c l = and (zipWith sharedWithAll intersections prefixes)
  where
    own = fromVertexCluster c
    pathSets = map fromVertexCluster l
    intersections = map (Set.intersection own) pathSets
    -- 'inits' starts with the empty prefix, which is not needed here.
    prefixes = drop 1 (inits pathSets)
    sharedWithAll common = all (common `Set.isSubsetOf`)
-- | Start vertices of the ingoing edges of a vertex.
--
-- Fails (through 'fromJust') when 'ingoing' or 'startVertex' gives 'Nothing'.
parents :: DirectedGraph g => g a b -> Vertex -> [Vertex]
parents g v = fromJust (mapM (startVertex g) =<< ingoing g v)
-- | End vertices of the outgoing edges of a vertex.
--
-- Fails (through 'fromJust') when 'outgoing' or 'endVertex' gives 'Nothing'.
children :: DirectedGraph g => g a b -> Vertex -> [Vertex]
children g v = fromJust (mapM (endVertex g) =<< outgoing g v)
-- | Pairwise connect a list of vertices in two graphs at once.
--
-- The edges to add are the ordered pairs of distinct vertices of @nodes@
-- that are not linked in @originGraph@ (as it is on entry). Each of them
-- is added to the origin graph with @g@ and to the destination graph with
-- @f@. Both updated graphs are returned.
connectAllNodesWith :: (Graph g, Graph g')
                    => g a b
                    -> g' a b
                    -> (Edge -> g a b -> g a b)
                    -> (Edge -> g' a b -> g' a b)
                    -> [Vertex]
                    -> (g a b, g' a b)
connectAllNodesWith originGraph dstGraph g f nodes =
  foldr addToBoth (originGraph, dstGraph) missing
  where
    missing =
      [ edge x y
      | x <- nodes
      , y <- nodes
      , x /= y
      , not (isLinkedWithAnEdge originGraph x y)
      ]
    addToBoth e (o, d) = (g e o, f e d)
-- | Pairwise connect the parents of a vertex with unlabelled edges.
--
-- The parents are read from the graph given as third argument, which is
-- also the graph receiving the new edges. The second argument is ignored.
addMissingLinks :: DirectedGraph g => Vertex -> b -> g () b -> g () b
addMissingLinks v _ g = snd (connectAllNodesWith g g keep link (parents g v))
  where
    keep _ m = m
    link e m = addEdge e () m
-- | Copy a graph into an undirected graph: every vertex is copied with its
-- label and value, then every edge is added with the value @()@.
--
-- Fails (through 'fromJust') when a vertex has no label.
convertToUndirected :: (FoldableWithVertex g, Graph g, NamedGraph g, NamedGraph g', UndirectedGraph g')
                    => g () b
                    -> g' () b
convertToUndirected m =
  foldr (\e acc -> addEdge e () acc) verticesOnly (allEdges m)
  where
    copyVertex v dat acc = addLabeledVertex (fromJust (vertexLabel m v)) v dat acc
    verticesOnly = foldrWithVertex copyVertex emptyGraph m
-- | Moral graph of a directed graph: the parents of every vertex are
-- pairwise connected, then the result is converted to an undirected graph.
moralGraph :: (NamedGraph g, FoldableWithVertex g, DirectedGraph g)
           => g () b -> UndirectedSG () b
moralGraph g = convertToUndirected (foldrWithVertex marryParents g g)
  where
    -- The parents are always read from the original graph @g@ and never
    -- from the accumulator: the accumulator already holds the edges added
    -- between parents, and reading it would turn a married parent into a
    -- parent of its partner, leading to extra edges that depend on the
    -- order in which the vertices are visited.
    marryParents v _ acc =
      snd (connectAllNodesWith acc acc keep link (parents g v))
    keep _ m = m
    link e m = addEdge e () m