module Bayes.FactorElimination(
moralGraph
, nodeComparisonForTriangulation
, numberOfAddedEdges
, weight
, weightedEdges
, triangulate
, createClusterGraph
, Cluster
, createJunctionTree
, createUninitializedJunctionTree
, JunctionTree
, collect
, distribute
, posterior
, changeEvidence
, junctionTreeProperty_prop
, VertexCluster
, junctionTreeProperty
, maximumSpanningTree
, fromVertexCluster
) where
import Bayes
import qualified Data.Foldable as F
import Data.Maybe(fromJust,mapMaybe,isJust)
import Control.Monad(mapM,guard)
import Bayes.Factor hiding (isEmpty)
import Data.Function(on)
import Data.List(minimumBy,maximumBy,inits,foldl')
import qualified Data.Set as Set
import qualified Data.Map as Map
import qualified Data.Functor as Functor
import qualified Data.Tree as T
import Bayes.FactorElimination.JTree
import Control.Applicative((<$>))
import Test.QuickCheck hiding ((.||.), collect)
import Test.QuickCheck.Arbitrary
-- | Count the links that are missing between the neighbours of a vertex.
--
-- The count ranges over /ordered/ pairs @(x, y)@ of distinct neighbours for
-- which 'isLinkedWithAnEdge' is 'False', so an unordered pair is visited both
-- as @(x, y)@ and as @(y, x)@.
--
-- Fails (through 'fromJust') when 'neighbors' returns 'Nothing' for the vertex.
numberOfAddedEdges :: UndirectedGraph g
                   => g a b   -- ^ graph to inspect
                   -> Vertex  -- ^ vertex whose neighbourhood is examined
                   -> Int
numberOfAddedEdges g v = length missingLinks
  where
    adjacent = fromJust (neighbors g v)
    missingLinks = do
      x <- adjacent
      y <- adjacent
      guard (x /= y && not (isLinkedWithAnEdge g x y))
      return (x, y)
-- | Sum, over every ordered pair of distinct and unlinked neighbours of a
-- vertex, of the product of the two neighbours' 'weight's.
--
-- Fails (through 'fromJust') when 'neighbors' returns 'Nothing' for the vertex.
weightedEdges :: (UndirectedGraph g, Factor f)
              => g a f   -- ^ graph whose vertices carry factors
              -> Vertex  -- ^ vertex whose neighbourhood is examined
              -> Int
weightedEdges g v = sum (map pairWeight missingLinks)
  where
    adjacent = fromJust (neighbors g v)
    pairWeight (x, y) = weight g x * weight g y
    missingLinks = do
      x <- adjacent
      y <- adjacent
      guard (x /= y && not (isLinkedWithAnEdge g x y))
      return (x, y)
-- | Weight of a vertex: the 'factorDimension' of the factor stored on it.
--
-- Fails (through 'fromJust') when 'vertexValue' returns 'Nothing'.
weight :: (UndirectedGraph g, Factor f)
       => g a f   -- ^ graph whose vertices carry factors
       -> Vertex  -- ^ vertex to weigh
       -> Int
weight g v =
  let storedFactor = fromJust (vertexValue g v)
  in factorDimension storedFactor
-- | Combine two comparison functions lexicographically: the second one is
-- only consulted when the first one answers 'EQ'.
(.||.) :: (a -> a -> Ordering)
       -> (a -> a -> Ordering)
       -> (a -> a -> Ordering)
(f .||. g) a b
  | primary == EQ = g a b
  | otherwise     = primary
  where
    primary = f a b
-- | Vertex ordering used to pick the next node during triangulation:
-- compare by 'numberOfAddedEdges' first and break ties with 'weight'.
nodeComparisonForTriangulation :: (UndirectedGraph g, Factor f)
                               => g a f
                               -> Vertex
                               -> Vertex
                               -> Ordering
nodeComparisonForTriangulation g = byMissingLinks .||. byWeight
  where
    byMissingLinks = compare `on` numberOfAddedEdges g
    byWeight       = compare `on` weight g
-- | A cluster described by the set of graph vertices it contains.
newtype VertexCluster = VertexCluster (Set.Set Vertex) deriving(Eq)

-- | Unwrap a 'VertexCluster' into its underlying set of vertices.
fromVertexCluster :: VertexCluster -> Set.Set Vertex
fromVertexCluster cluster =
  case cluster of
    VertexCluster vertices -> vertices
-- | A 'VertexCluster' is displayed as the ascending list of its vertices.
instance Show VertexCluster where
  show = show . Set.toList . fromVertexCluster
-- | 'Cluster' is treated as a set of variables.
instance IsCluster Cluster where
  -- Keep only the evidence items whose variable belongs to the cluster.
  overlappingEvidence (Cluster vars) evidence =
    [item | item <- evidence, instantiationVariable item `Set.member` vars]

  -- Variables of the cluster, in ascending set order.
  clusterVariables (Cluster vars) = Set.toList vars

  -- The separator between two clusters is the intersection of their variables.
  mkSeparator (Cluster left) (Cluster right) =
    Cluster (left `Set.intersection` right)
-- | Triangulate a graph by node elimination.
--
-- While the working graph still has vertices, the minimum vertex according
-- to the supplied comparison is selected, it and its neighbours are pairwise
-- linked (in both the working graph and the result graph), the resulting
-- vertex set is recorded as a cluster, and the selected vertex is removed
-- from the working graph.
--
-- Returns the recorded clusters, in elimination order and filtered by
-- 'keepMaximalClusters', together with the graph that received the added
-- edges.
triangulate :: Graph g
            => (Vertex -> Vertex -> Ordering)  -- ^ criterion for picking the next vertex
            -> g () b                          -- ^ graph to triangulate
            -> ([VertexCluster],g () b)
triangulate cmp g = eliminate g g []
  where
    linkWithUnit e graph = addEdge e () graph

    -- working: graph that shrinks by one vertex per step
    -- filled:  copy of the input graph that only gains edges
    -- acc:     clusters found so far, most recent first
    eliminate working filled acc
      | hasNoVertices working = (keepMaximalClusters (reverse acc), filled)
      | otherwise =
          eliminate (removeVertex pivot working') filled' (cluster : acc)
      where
        pivot   = minimumBy cmp (allVertices working)
        family  = pivot : fromJust (neighbors working pivot)
        cluster = VertexCluster (Set.fromList family)
        (working', filled') =
          connectAllNodesWith working filled linkWithUnit linkWithUnit family
-- | Look for the first cluster of the list that contains the given cluster
-- (subset test on the vertex sets).
--
-- Returns @(Just found, listWithoutFound)@ when such a cluster exists, and
-- @(Nothing, originalList)@ otherwise.
findContainingCluster :: VertexCluster    -- ^ cluster that may be contained in another one
                      -> [VertexCluster]  -- ^ candidate containers
                      -> (Maybe VertexCluster,[VertexCluster])
findContainingCluster cluster l =
  case break containsCluster l of
    (_, [])                 -> (Nothing, l)
    (before, found : after) -> (Just found, before ++ after)
  where
    containsCluster candidate =
      fromVertexCluster cluster `Set.isSubsetOf` fromVertexCluster candidate
-- | Drop every cluster that is contained in a cluster appearing earlier in
-- the list.
--
-- Clusters are processed left to right. When the current cluster is a subset
-- of an already kept one, the current cluster is discarded and the containing
-- cluster is moved to the most recent position of the kept list; otherwise
-- the current cluster is appended to the kept list.
keepMaximalClusters :: [VertexCluster] -> [VertexCluster]
keepMaximalClusters = reverse . foldl' absorb []
  where
    -- keptRev holds the clusters kept so far, most recent first.
    absorb keptRev current =
      case findContainingCluster current (reverse keptRev) of
        (Nothing, _)             -> current : keptRev
        (Just container, others) -> container : reverse others
-- | Turn a cluster of vertices into a cluster of variables by taking the
-- 'factorMainVariable' of the factor stored on each vertex. Vertices for
-- which 'vertexValue' returns 'Nothing' are silently skipped.
vertexClusterToCluster :: (Factor f , Graph g)
                       => g e f          -- ^ graph holding one factor per vertex
                       -> VertexCluster  -- ^ cluster of vertices of that graph
                       -> Cluster
vertexClusterToCluster g c =
  Cluster (Set.fromList [factorMainVariable f | Just f <- storedFactors])
  where
    storedFactors = map (vertexValue g) (Set.toList (fromVertexCluster c))
-- | Build the cluster graph.
--
-- Each cluster becomes a vertex (numbered from 0 in list order) carrying the
-- corresponding variable 'Cluster'. Every ordered pair of distinct numbered
-- clusters is linked by an edge whose value is the number of vertices the two
-- clusters have in common.
createClusterGraph :: (UndirectedGraph g, Factor f, Graph g')
                   => g' e f           -- ^ graph used to map vertices to variables
                   -> [VertexCluster]  -- ^ clusters to turn into graph vertices
                   -> g Int Cluster
createClusterGraph bn c = foldl' linkClusters isolated orderedPairs
  where
    numbered = zip c (map Vertex [0..])

    -- Graph containing all the cluster vertices and no edge yet.
    isolated = foldl' insertCluster emptyGraph numbered
    insertCluster acc (vc, v) = addVertex v (vertexClusterToCluster bn vc) acc

    sharedCount ca cb =
      Set.size (fromVertexCluster ca `Set.intersection` fromVertexCluster cb)

    orderedPairs = [(p, q) | p <- numbered, q <- numbered, p /= q]
    linkClusters acc ((ca, va), (cb, vb)) =
      addEdge (edge va vb) (sharedCount ca cb) acc
-- | Enumerate every way of attaching a remaining vertex to the tree.
--
-- For each remaining vertex and each given tree cluster, the pair is kept
-- when the remaining vertex is linked, in the cluster graph, to the vertex
-- stored in the tree for that cluster. The third component of each result is
-- the value of that linking edge.
possibilities :: (Ord c , UndirectedGraph g)
              => g Int c                -- ^ cluster graph
              -> JTree c (Vertex,f)     -- ^ tree built so far
              -> [Vertex]               -- ^ vertices not yet in the tree
              -> [c]                    -- ^ tree clusters usable as attachment points
              -> [(Vertex,c,Int)]
possibilities g currentT remaining leavesClusters =
  [ (rv, lv, fromJust (edgeValue g (edge rv attachVertex)))
  | rv <- remaining
  , lv <- leavesClusters
  , let NodeValue (attachVertex, _) _ = nodeValue currentT lv
  , isLinkedWithAnEdge g rv attachVertex
  ]
-- | Pick the attachment with the largest edge value among 'possibilities'.
--
-- Returns the remaining vertices without the chosen one, the chosen vertex
-- paired with its cluster, and the tree cluster it must be attached to.
--
-- Fails when there is no candidate at all ('maximumBy' on an empty list) or
-- when the chosen vertex has no value in the graph ('fromJust').
findMax :: (UndirectedGraph g, Ord c, Factor f)
        => g Int c             -- ^ cluster graph
        -> [Vertex]            -- ^ vertices not yet in the tree
        -> JTree c (Vertex,f)  -- ^ tree built so far
        -> ([Vertex],(Vertex,c),c)
findMax g remaining currentT =
  (filter (/= chosen) remaining, (chosen, fromJust (vertexValue g chosen)), attachPoint)
  where
    candidates = possibilities g currentT remaining (treeNodes currentT)
    linkValue (_, _, w) = w
    (chosen, attachPoint, _) = maximumBy (compare `on` linkValue) candidates
-- | Strip the 'Vertex' tags that were paired with every value of the tree
-- during construction, in both the node map and the separator map.
removeVertices :: JTree c (Vertex,f) -> JTree c f
removeVertices t =
  t { nodeValueMap      = dropNodeVertex <$> nodeValueMap t
    , separatorValueMap = dropSeparatorVertex <$> separatorValueMap t
    }
  where
    dropNodeVertex (NodeValue (_, f) (_, e)) = NodeValue f e

    dropSeparatorVertex sepValue =
      case sepValue of
        SeparatorValue (_, u) (Just (_, d)) -> SeparatorValue u (Just d)
        SeparatorValue (_, u) Nothing       -> SeparatorValue u Nothing
        EmptySeparator                      -> EmptySeparator
-- | Grow a tree over the cluster graph, starting from the vertex returned by
-- 'someVertex' and repeatedly attaching the remaining vertex with the largest
-- edge value (see 'buildTree'). Every node starts with the scalar factor 1.
--
-- Fails (through 'fromJust') when 'someVertex' returns 'Nothing'.
maximumSpanningTree :: (UndirectedGraph g, IsCluster c, Factor f, Ord c)
                    => g Int c  -- ^ cluster graph with edge values
                    -> JTree c f
maximumSpanningTree g = removeVertices (buildTree g others seed)
  where
    rootVertex  = fromJust (someVertex g)
    rootCluster = fromJust (vertexValue g rootVertex)
    -- During construction each value is tagged with the graph vertex it
    -- comes from; the tags are removed by 'removeVertices' at the end.
    rootPayload = (rootVertex, factorFromScalar 1.0)
    seed        = singletonTree rootCluster rootPayload rootPayload
    others      = filter (/= rootVertex) (allVertices g)
-- | Attach the remaining vertices to the tree one at a time.
--
-- At each step 'findMax' selects the vertex to add and the tree cluster it is
-- attached to; a node holding the scalar factor 1 is added for the new
-- cluster, followed by a separator built with 'mkSeparator'.
buildTree :: (UndirectedGraph g , IsCluster c, Factor f, Ord c)
          => g Int c             -- ^ cluster graph
          -> [Vertex]            -- ^ vertices still to be added
          -> JTree c (Vertex,f)  -- ^ tree built so far
          -> JTree c (Vertex,f)
buildTree _ [] currentT = currentT
buildTree g pending currentT = buildTree g pending' grown
  where
    (pending', (newVertex, newCluster), attachPoint) = findMax g pending currentT
    payload   = (newVertex, factorFromScalar 1.0)
    withNode  = addNode newCluster payload payload currentT
    separator = mkSeparator newCluster attachPoint
    grown     = addSeparator attachPoint separator newCluster withNode
-- | Build the structure of a junction tree for a directed graph, without
-- loading the graph's factors into it.
--
-- Steps: moralize the graph, triangulate the moral graph with the supplied
-- comparison (applied to the moral graph), build the cluster graph from the
-- resulting clusters, then extract a tree with 'maximumSpanningTree'.
createUninitializedJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g, Factor f)
                                => (UndirectedSG () f -> Vertex -> Vertex -> Ordering)  -- ^ vertex ordering for triangulation
                                -> g () f                                               -- ^ input directed graph
                                -> JunctionTree f
createUninitializedJunctionTree cmp g = maximumSpanningTree clusterGraph
  where
    moralized    = moralGraph g
    clusters     = fst (triangulate (cmp moralized) moralized)
    clusterGraph = createClusterGraph g clusters :: UndirectedSG Int Cluster
-- | Build a junction tree for a Bayesian network: create the tree structure,
-- load the network's factors with 'setFactors', then run 'collect' followed
-- by 'distribute'.
createJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g, Factor f, Show f)
                   => (UndirectedSG () f -> Vertex -> Vertex -> Ordering)  -- ^ vertex ordering for triangulation
                   -> BayesianNetwork g f                                  -- ^ input network
                   -> JunctionTree f
createJunctionTree cmp g =
  distribute . collect . setFactors g $ createUninitializedJunctionTree cmp g
-- | Normalized factor over a single variable, computed from a junction tree.
--
-- The tree is traversed with 'findClusterFor' to obtain a cluster containing
-- the variable; 'Nothing' is returned when the traversal yields no cluster.
-- Otherwise the two factors stored on that node, the down message obtained
-- through the node's parent (the scalar factor 1 when there is none) and the
-- up messages of the node's children are multiplied, projected onto the
-- variable and divided by their norm.
posterior :: Factor f => JunctionTree f -> DV -> Maybe f
posterior t v = marginalIn <$> snd (traverseTree (findClusterFor v) Nothing t)
  where
    marginalIn c = factorDivide unNormalized (factorNorm unNormalized)
      where
        NodeValue f e = nodeValue t c
        fromParent    = maybe (factorFromScalar 1.0) id (nodeParent t c >>= downMessage t)
        fromChildren  = map (upMessage t) (nodeChildren t c)
        unNormalized  =
          factorProjectTo [v] (factorProduct (f : e : fromParent : fromChildren))
-- | Traversal step used by 'posterior': stop with the current cluster as
-- soon as it contains the variable, otherwise skip it and keep the state
-- unchanged.
findClusterFor :: DV             -- ^ variable being looked for
               -> Maybe Cluster  -- ^ current traversal state
               -> Cluster        -- ^ cluster being visited
               -> NodeValue f    -- ^ value of the visited node (unused)
               -> Action (Maybe Cluster) (NodeValue f)
findClusterFor dv s c@(Cluster sc) _
  | dv `Set.member` sc = Stop (Just c)
  | otherwise          = Skip s
-- | QuickCheck property: for every non-empty, connected directed graph that
-- has at least one edge, the uninitialized junction tree built with the
-- 'numberOfAddedEdges' ordering satisfies 'junctionTreeProperty' from its
-- root.
junctionTreeProperty_prop :: DirectedSG () CPT -> Property
junctionTreeProperty_prop g =
  usable ==> junctionTreeProperty tree [] (root tree)
  where
    usable = not (isEmpty g) && not (hasNoEdges g) && connectedGraph g
    byMissingLinks ug = compare `on` numberOfAddedEdges ug
    tree = createUninitializedJunctionTree byMissingLinks g
-- | Check 'checkPath' for a cluster against the path of clusters above it,
-- then recursively for every cluster reached through its children, with the
-- current cluster pushed in front of the path.
junctionTreeProperty :: JTree Cluster CPT  -- ^ tree to check
                     -> [Cluster]          -- ^ clusters above the current one, nearest first
                     -> Cluster            -- ^ cluster being checked
                     -> Bool
junctionTreeProperty t path c =
  checkPath c path && all checkBelow (nodeChildren t c)
  where
    checkBelow = junctionTreeProperty t (c : path) . separatorChild t
-- | Check a cluster against the list of clusters above it.
--
-- For each cluster of the list, the variables it shares with the checked
-- cluster must be included in that cluster and in every cluster that
-- precedes it in the list. An empty list is always accepted.
checkPath :: Cluster -> [Cluster] -> Bool
checkPath _ [] = True
checkPath (Cluster c) l = and (zipWith sharedEverywhere ancestorSets prefixes)
  where
    ancestorSets = map (\(Cluster s) -> s) l
    -- Non-empty prefixes: the k-th one holds the first k sets of the list.
    prefixes = drop 1 (inits ancestorSets)
    sharedEverywhere ancestor prefix =
      all (Set.isSubsetOf (Set.intersection c ancestor)) prefix
-- | Start vertices of all the ingoing edges of a vertex.
--
-- Fails (through 'fromJust') when 'ingoing' or any 'startVertex' lookup
-- returns 'Nothing'.
parents :: DirectedGraph g => g a b -> Vertex -> [Vertex]
parents g v = fromJust $ do
  incoming <- ingoing g v
  mapM (startVertex g) incoming
-- | End vertices of all the outgoing edges of a vertex.
--
-- Fails (through 'fromJust') when 'outgoing' or any 'endVertex' lookup
-- returns 'Nothing'.
children :: DirectedGraph g => g a b -> Vertex -> [Vertex]
children g v = fromJust $ do
  leaving <- outgoing g v
  mapM (endVertex g) leaving
-- | Link together all the given vertices in two graphs at once.
--
-- The edges to add are those between ordered pairs of distinct vertices of
-- the list that are not linked in the /origin/ graph as passed in (links
-- added while folding are not taken into account). Each such edge is handed
-- to the first function to update the origin graph and to the second one to
-- update the destination graph.
connectAllNodesWith :: (Graph g, Graph g')
                    => g a b                      -- ^ origin graph, also used to detect missing links
                    -> g' a b                     -- ^ destination graph
                    -> (Edge -> g a b -> g a b)   -- ^ how to add an edge to the origin graph
                    -> (Edge -> g' a b -> g' a b) -- ^ how to add an edge to the destination graph
                    -> [Vertex]                   -- ^ vertices to link pairwise
                    -> (g a b,g' a b)
connectAllNodesWith originGraph dstGraph g f nodes =
  foldr addToBoth (originGraph, dstGraph) missingLinks
  where
    missingLinks =
      [ edge x y
      | x <- nodes
      , y <- nodes
      , x /= y
      , not (isLinkedWithAnEdge originGraph x y)
      ]
    addToBoth e (origin, dst) = (g e origin, f e dst)
-- | Link together all the parents of a vertex, as found in the given graph.
-- The vertex value (third argument) is ignored; the argument order matches
-- the step function expected by 'foldlWithVertex''.
addMissingLinks :: DirectedGraph g => g () b -> Vertex -> b -> g () b
addMissingLinks g v _ =
  snd (connectAllNodesWith g g keep link (parents g v))
  where
    -- The first graph handed to 'connectAllNodesWith' is left untouched;
    -- only the second copy receives the new edges.
    keep _ graph = graph
    link e graph = addEdge e () graph
-- | Copy a graph into an undirected graph: every vertex is re-created with
-- its label and value, then every edge of the source is added with a unit
-- value.
--
-- Fails (through 'fromJust') when a vertex has no label.
convertToUndirected :: (FoldableWithVertex g, Graph g, NamedGraph g, NamedGraph g',UndirectedGraph g')
                    => g () b
                    -> g' () b
convertToUndirected m = foldr linkWithUnit verticesOnly (allEdges m)
  where
    -- Target graph holding all the labeled vertices and no edge yet.
    verticesOnly = foldlWithVertex' copyVertex emptyGraph m
    copyVertex acc v dat = addLabeledVertex (fromJust (vertexLabel m v)) v dat acc
    linkWithUnit e acc = addEdge e () acc
-- | Build the moral graph of a directed graph: the parents of every vertex
-- are linked together, then the result is converted to an undirected graph
-- with 'convertToUndirected'.
--
-- The parents of each vertex are read from the /original/ graph, not from
-- the graph being accumulated. Marriage edges are added in both directions
-- (see 'connectAllNodesWith'), so in the accumulated graph a spouse would
-- show up as a parent of the vertices processed later and would in turn be
-- married to their real parents, producing links that do not belong to the
-- moral graph.
moralGraph :: (NamedGraph g, FoldableWithVertex g, DirectedGraph g)
           => g () b  -- ^ directed graph to moralize
           -> UndirectedSG () b
moralGraph g =
  convertToUndirected (foldlWithVertex' marryParents g g)
  where
    -- acc receives the new edges; parents always come from the input graph g.
    marryParents acc v _ =
      snd (connectAllNodesWith acc acc keep link (parents g v))
    -- The first graph handed to 'connectAllNodesWith' is left untouched;
    -- only the second copy receives the new edges.
    keep _ graph = graph
    link e graph = addEdge e () graph