{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{- | Algorithms for factor elimination

-}
module Bayes.FactorElimination(
-- * Moral graph
moralGraph
-- * Triangulation
, nodeComparisonForTriangulation
, weight
, weightedEdges
, triangulate
-- * Junction tree
, createClusterGraph
, Cluster
, createJunctionTree
, createUninitializedJunctionTree
, JunctionTree
, displayTreeValues
-- * Shenoy-Shafer message passing
, collect
, distribute
, posterior
-- * Evidence
, changeEvidence
-- * Test
, junctionTreeProperty_prop
, junctionTreeAllClusters_prop
, VertexCluster
-- * For debug
, junctionTreeProperty
, maximumSpanningTree
, fromVertexCluster
, triangulatedebug
) where

import Bayes
import qualified Data.Foldable as F
import Data.Maybe(fromJust,mapMaybe,isJust)
import Bayes.Factor hiding (isEmpty)
import Data.Function(on)
import Data.List(minimumBy,maximumBy,inits,foldl',nub,(\\))
import qualified Data.Set as Set
import qualified Data.Map as Map
import qualified Data.Functor as Functor
import qualified Data.Tree as T
import Bayes.FactorElimination.JTree
import Control.Applicative((<\$>))
import Bayes.VariableElimination(marginal)

import Test.QuickCheck hiding ((.||.), collect)
import Test.QuickCheck.Arbitrary

--import Debug.Trace
--debug s a = trace (s ++ "\n" ++ show a ++ "\n") a

{-

Comparison functions for graph triangulation

-}

-- | Number of edges added when connecting all neighbors of a vertex.
--
-- Every ordered pair @(x, y)@ of distinct neighbours that is not already
-- linked by an edge is counted, so a missing link between two neighbours
-- contributes once per direction in which 'isLinkedWithAnEdge' reports it
-- as missing.
--
-- Fails (through 'fromJust') when 'neighbors' returns 'Nothing' for the
-- vertex.
numberOfAddedEdges :: UndirectedGraph g
                   => g a b
                   -> Vertex
                   -> Integer
numberOfAddedEdges g v =
    let nodes = fromJust $ neighbors g v
    in
    fromIntegral $ length [edge x y | x <- nodes, y <- nodes, x /= y, not (isLinkedWithAnEdge g x y)]

-- | Sum, over every ordered pair of distinct neighbours of the vertex that
-- are not linked by an edge, of the product of the two neighbours' weights
-- (see 'weight').
--
-- Fails (through 'fromJust') when 'neighbors' returns 'Nothing' for the
-- vertex.
weightedEdges :: (UndirectedGraph g, Factor f)
              => g a f
              -> Vertex
              -> Integer
weightedEdges g v = sum (map pairWeight missingPairs)
  where
    adjacent = fromJust (neighbors g v)
    -- Ordered pairs of distinct neighbours with no edge between them.
    missingPairs =
      [ (p, q)
      | p <- adjacent
      , q <- adjacent
      , p /= q
      , not (isLinkedWithAnEdge g p q)
      ]
    pairWeight (p, q) = weight g p * weight g q

-- | Weight of a node: the 'factorDimension' of the factor stored at the
-- vertex, converted to an 'Integer'.
--
-- Fails (through 'fromJust') when 'vertexValue' returns 'Nothing' for the
-- vertex.
weight :: (UndirectedGraph g, Factor f)
       => g a f
       -> Vertex
       -> Integer
weight g v = fromIntegral (factorDimension storedFactor)
  where
    storedFactor = fromJust (vertexValue g v)

-- | Lexicographic combination of two comparison functions: the second
-- comparison is only consulted when the first one answers 'EQ'.
(.||.) :: (a -> a -> Ordering)
       -> (a -> a -> Ordering)
       -> (a -> a -> Ordering)
(primary .||. secondary) x y
  | firstResult == EQ = secondary x y
  | otherwise = firstResult
  where
    firstResult = primary x y

-- | Comparison used to select the next node during triangulation: nodes are
-- ordered by 'numberOfAddedEdges' first and, when that is equal, by 'weight'.
nodeComparisonForTriangulation :: (UndirectedGraph g, Factor f)
                               => g a f
                               -> Vertex
                               -> Vertex
                               -> Ordering
nodeComparisonForTriangulation g = byAddedEdges .||. byWeight
  where
    byAddedEdges = compare `on` numberOfAddedEdges g
    byWeight = compare `on` weight g

{-

Graph triangulation

-}

-- | A cluster described only by the set of graph vertices it contains; the
-- factors are not attached yet. Equality and ordering are those of the
-- underlying set.
newtype VertexCluster = VertexCluster (Set.Set Vertex) deriving(Eq,Ord)

-- | Unwrap a 'VertexCluster' to the set of vertices it contains.
fromVertexCluster :: VertexCluster -> Set.Set Vertex
fromVertexCluster (VertexCluster s) = s

-- | A cluster is shown as the list of its vertices (as produced by
-- 'Set.toList').
instance Show VertexCluster where
    show = show . Set.toList . fromVertexCluster

-- | Triangulate a graph by node elimination, using a cost function.
--
-- While the graph still has vertices, the vertex that is minimal for the
-- criterion is selected, its neighbours are connected with
-- 'connectAllNonAdjacentNodes' and the vertex is removed. Each selected
-- vertex together with its neighbours forms a cluster. The clusters are
-- handed, in elimination order, to 'keepMaximalClusters'.
--
-- Only the clusters are returned; the triangulated graph itself is not.
triangulate :: Graph g
            => (Vertex -> Vertex -> Ordering) -- ^ Criterion function for triangulation
            -> g () b -- ^ Graph to triangulate
            -> [VertexCluster] -- ^ Clusters produced by the elimination
triangulate cmp gr = eliminate gr []
  where
    -- 'found' accumulates the clusters in reverse elimination order.
    eliminate g found
      | hasNoVertices g = keepMaximalClusters (reverse found)
      | otherwise = eliminate reduced (cluster : found)
      where
        chosen = minimumBy cmp (allVertices g)
        adjacent = fromJust (neighbors g chosen)
        reduced = removeVertex chosen (connectAllNonAdjacentNodes adjacent g)
        cluster = VertexCluster (Set.fromList (chosen : adjacent))

-- | Debugging variant of 'triangulate'.
--
-- Runs the same elimination loop but does not filter the clusters with
-- 'keepMaximalClusters', and additionally records the graph as it was just
-- before each elimination step. Both lists are in elimination order.
triangulatedebug :: Graph g
                 => (Vertex -> Vertex -> Ordering) -- ^ Criterion function for triangulation
                 -> g () b -- ^ Graph to triangulate
                 -> ([VertexCluster],[g () b]) -- ^ Clusters and the graph before each elimination step
triangulatedebug cmp gr = eliminate gr [] []
  where
    -- Both accumulators are built in reverse elimination order.
    eliminate g found seen
      | hasNoVertices g = (reverse found, reverse seen)
      | otherwise = eliminate reduced (cluster : found) (g : seen)
      where
        chosen = minimumBy cmp (allVertices g)
        adjacent = fromJust (neighbors g chosen)
        reduced = removeVertex chosen (connectAllNonAdjacentNodes adjacent g)
        cluster = VertexCluster (Set.fromList (chosen : adjacent))

-- | Look for the first cluster of a list that contains a given cluster.
--
-- A cluster of the list contains the processed cluster when the vertex set
-- of the processed cluster is a subset of its vertex set. When such a cluster
-- is found it is returned together with the list from which it has been
-- removed (the order of the other elements is kept). Otherwise the result is
-- 'Nothing' and the unchanged list.
findContainingCluster :: VertexCluster -- ^ Cluster processed
                      -> [VertexCluster] -- ^ Cluster list where to look for a containing cluster
                      -> (Maybe VertexCluster,[VertexCluster]) -- ^ Return the containing cluster and a new list without the containing cluster
findContainingCluster cluster l =
    case break containsCluster l of
      (_, []) -> (Nothing, l)
      (before, found : after) -> (Just found, before ++ after)
  where
    -- True when the candidate is a superset of the processed cluster.
    containsCluster candidate =
      fromVertexCluster cluster `Set.isSubsetOf` fromVertexCluster candidate

-- | Remove the clusters that are already contained in a previous cluster.
--
-- The clusters are processed from left to right. A cluster that is not
-- contained in any cluster kept so far is appended to the kept clusters.
-- A cluster that is contained in a kept cluster is dropped, and the first
-- kept cluster containing it is moved to the end of the kept clusters.
keepMaximalClusters :: [VertexCluster] -> [VertexCluster]
keepMaximalClusters = reverse . foldl' absorb []
  where
    -- The accumulator holds the clusters kept so far, most recent first.
    absorb keptReversed current =
      case findContainingCluster current (reverse keptReversed) of
        (Nothing, _) -> current : keptReversed
        (Just container, others) -> container : reverse others

-- | Convert a cluster of vertices into a 'Cluster' of variables.
--
-- Each vertex of the cluster is looked up in the graph and replaced by the
-- 'factorMainVariable' of the factor stored there. Vertices for which
-- 'vertexValue' returns 'Nothing' are ignored.
vertexClusterToCluster :: (Factor f , Graph g)
                       => g e f
                       -> VertexCluster
                       -> Cluster
vertexClusterToCluster g (VertexCluster members) =
    Cluster (Set.fromList mainVariables)
  where
    storedFactors = mapMaybe (vertexValue g) (Set.toList members)
    mainVariables = map factorMainVariable storedFactors

-- | Create the cluster graph.
--
-- Every vertex cluster becomes one vertex of the result (numbered from 0 in
-- list order) whose value is the corresponding 'Cluster' of variables, as
-- computed by 'vertexClusterToCluster' on the first argument. Every pair of
-- distinct clusters is then linked by an edge whose value is the size of the
-- intersection of their vertex sets.
createClusterGraph :: (UndirectedGraph g, Factor f, Graph g')
                   => g' e f
                   -> [VertexCluster]
                   -> g Int Cluster
createClusterGraph bn c =
    let numberedClusters = zip c (map Vertex [0..])
        -- NOTE(review): this helper was referenced but not defined in the
        -- reviewed text; the argument order of 'addVertex' is assumed to
        -- mirror 'addEdge' (key, value, graph) — confirm.
        addCluster g (vc,v) = addVertex v (vertexClusterToCluster bn vc) g
        graphWithoutEdges = foldl' addCluster emptyGraph numberedClusters
        separatorSize ca cb = Set.size $ Set.intersection (fromVertexCluster ca) (fromVertexCluster cb)
        -- Both orientations of every pair are enumerated.
        allEdges = [(cx,cy) | cx <- numberedClusters, cy <- numberedClusters, cx /= cy]
        addClusterEdge g ((ca,va),(cb,vb)) = addEdge (edge va vb) (separatorSize ca cb) g
    in
    foldl' addClusterEdge graphWithoutEdges allEdges

{-

Maximum spanning tree using Prim's algorithm

-}

-- | Get all possible edges between the leaves and the remaining nodes.
--
-- For every vertex still to add and every listed tree node, the result
-- contains the vertex, the tree node and the value of the graph edge joining
-- the vertex to the vertex of that tree node. The edge is expected to exist:
-- 'fromJust' fails when 'edgeValue' returns 'Nothing'.
possibilities :: (Ord c , UndirectedGraph g)
              => g Int c -- ^ Original graph to get the edge value
              -> JTree c f -- ^ Tree to get the vertex for a leaf
              -> [Vertex] -- ^ Vertices to add to the tree
              -> [c] -- ^ List of leaves
              -> [(Vertex,c,Int)] -- ^ Found edge to add
possibilities g currentT remaining leavesClusters =
    [ (candidate, leaf, linkValue candidate leaf)
    | candidate <- remaining
    , leaf <- leavesClusters
    ]
  where
    linkValue candidate leaf =
      let NodeValue leafVertex _ _ = nodeValue currentT leaf
      in fromJust (edgeValue g (edge candidate leafVertex))

-- | Find the max edge to add to the tree.
--
-- Among all the candidates given by 'possibilities' (over every node of the
-- current tree), the one with the greatest edge value is selected. The result
-- is the list of vertices still to add (the selected vertex removed), the
-- selected vertex with its value in the graph, and the tree node it must be
-- attached to.
findMax :: (UndirectedGraph g, Ord c, Factor f,Show c)
        => g Int c -- ^ Graph
        -> [Vertex] -- ^ Nodes to add
        -> JTree c f
        -> ([Vertex],(Vertex,c),c)
findMax g remaining currentT =
    (stillRemaining, (bestVertex, bestCluster), bestLeaf)
  where
    candidates = possibilities g currentT remaining (treeNodes currentT)
    linkWeight (_, _, w) = w
    (bestVertex, bestLeaf, _) = maximumBy (compare `on` linkWeight) candidates
    stillRemaining = filter (/= bestVertex) remaining
    bestCluster = fromJust (vertexValue g bestVertex)

-- | Build a maximum spanning tree of the graph (Prim's algorithm).
--
-- The tree is seeded with the vertex returned by 'someVertex' and then grown
-- by 'buildTree' with all the other vertices. Fails (through 'fromJust') when
-- 'someVertex' returns 'Nothing'.
maximumSpanningTree :: (UndirectedGraph g, IsCluster c, Factor f, Ord c, Show c, Show f)
                    => g Int c
                    -> JTree c f
maximumSpanningTree g = buildTree g otherVertices seedTree
  where
    rootVertex = fromJust (someVertex g)
    rootCluster = fromJust (vertexValue g rootVertex)
    seedTree = singletonTree rootCluster rootVertex [] []
    otherVertices = filter (/= rootVertex) (allVertices g)

-- | Grow a tree until no vertex is left to add.
--
-- At each step 'findMax' selects the vertex to add and the tree node it is
-- attached to; the new node is added to the tree and linked to that tree
-- node through a separator built with 'mkSeparator'.
buildTree :: (UndirectedGraph g , IsCluster c, Factor f, Ord c, Show c, Show f)
          => g Int c
          -> [Vertex]
          -> JTree c f
          -> JTree c f
buildTree _ [] currentT = currentT
buildTree g pending currentT = buildTree g pending' grownTree
  where
    (pending', (newVertex, newCluster), attachTo) = findMax g pending currentT
    separator = mkSeparator newCluster attachTo
    withNode = addNode newCluster newVertex [] [] currentT
    grownTree = addSeparator attachTo separator newCluster withNode

{-

Junction tree algorithm

-}

-- | Create a junction tree with only the clusters and no factors.
--
-- The directed graph is moralized, the moral graph is triangulated with the
-- given criterion, the resulting clusters are turned into a cluster graph and
-- the tree is the 'maximumSpanningTree' of that cluster graph.
createUninitializedJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g, Factor f, Show f)
                                => (UndirectedSG () f -> Vertex -> Vertex -> Ordering) -- ^ Weight function on the moral graph
                                -> g () f -- ^ Input directed graph
                                -> JunctionTree f -- ^ Junction tree
createUninitializedJunctionTree cmp g = maximumSpanningTree clusterGraph
  where
    moralized = moralGraph g
    clusters = triangulate (cmp moralized) moralized
    clusterGraph = createClusterGraph g clusters :: UndirectedSG Int Cluster

-- | Create a junction tree.
--
-- The tree of clusters is built with 'createUninitializedJunctionTree', the
-- factors of the network are attached with 'setFactors', and then 'collect'
-- followed by 'distribute' are run on the tree.
createJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g, Factor f, Show f)
                   => (UndirectedSG () f -> Vertex -> Vertex -> Ordering) -- ^ Weight function on the moral graph
                   -> BayesianNetwork g f -- ^ Input directed graph
                   -> JunctionTree f -- ^ Junction tree
createJunctionTree cmp g = distribute (collect treeWithFactors)
  where
    clusterTree = createUninitializedJunctionTree cmp g
    -- A vertex is linked with a factor so vertex is used as the identifier
    treeWithFactors = setFactors g clusterTree

-- | Compute the marginal posterior (if some evidence is set on the junction tree)
-- otherwise compute just the marginal prior.
--
-- The tree is traversed with 'findClusterFor' to select a cluster; the result
-- is 'Nothing' when the traversal yields no cluster. Otherwise the factors of
-- that cluster are gathered with the down message coming from its parent
-- (a scalar factor 1.0 when there is none) and the up messages of its
-- children, every variable but the requested one is marginalized out, and the
-- result is divided by its norm.
posterior :: Factor f => JunctionTree f -> DV -> Maybe f
posterior t v = fmap marginalIn (snd (traverseTree (findClusterFor v) Nothing t))
  where
    marginalIn c = factorDivide unNormalized (factorNorm unNormalized)
      where
        -- The two factor lists stored in the node value
        -- (the second is presumably the evidence — confirm against 'NodeValue').
        NodeValue _ nodeFactors extraFactors = nodeValue t c
        fromParent = maybe (factorFromScalar 1.0) id (nodeParent t c >>= downMessage t)
        fromChildren = map (upMessage t) (nodeChildren t c)
        allFactors = fromParent : fromChildren ++ nodeFactors ++ extraFactors
        toRemove = nub (concatMap factorVariables allFactors) \\ [v]
        unNormalized = marginal allFactors toRemove [v] []

-- | Find a cluster containing the variable.
--
-- Traversal step: stops with the current cluster as soon as it contains the
-- variable, otherwise skips with the state left unchanged.
findClusterFor :: DV
               -> Maybe Cluster
               -> Cluster -- ^ Current cluster
               -> NodeValue f -- ^ Current value
               -> Action (Maybe Cluster) (NodeValue f)
findClusterFor dv s c@(Cluster sc) _
  | dv `Set.member` sc = Stop (Just c)
  | otherwise = Skip s

-- | QuickCheck property: for a non-empty, connected directed graph with at
-- least one edge, the uninitialized junction tree built with the
-- 'numberOfAddedEdges' criterion satisfies 'junctionTreeProperty' from its
-- root.
junctionTreeProperty_prop :: DirectedSG () CPT -> Property
junctionTreeProperty_prop g =
    usable ==> junctionTreeProperty tree [] (root tree)
  where
    usable = not (isEmpty g) && not (hasNoEdges g) && connectedGraph g
    byAddedEdges ug = compare `on` numberOfAddedEdges ug
    tree = createUninitializedJunctionTree byAddedEdges g

-- | QuickCheck property: for a non-empty, connected directed graph with at
-- least one edge, the set of nodes of the spanning tree is exactly the set of
-- clusters produced by the triangulation of the moral graph.
junctionTreeAllClusters_prop :: DirectedSG () CPT -> Property
junctionTreeAllClusters_prop g =
    usable ==> (expected `Set.isSubsetOf` actual && actual `Set.isSubsetOf` expected)
  where
    usable = not (isEmpty g) && not (hasNoEdges g) && connectedGraph g
    moralized = moralGraph g
    byAddedEdges ug = compare `on` numberOfAddedEdges ug
    clusters = triangulate (byAddedEdges moralized) moralized
    clusterGraph = createClusterGraph g clusters :: UndirectedSG Int Cluster
    jt = maximumSpanningTree clusterGraph :: JunctionTree CPT
    expected = Set.fromList (map (vertexClusterToCluster g) clusters)
    actual = Set.fromList (treeNodes jt)

-- | Check 'checkPath' for a cluster against the clusters already visited on
-- the way to it, then recursively for every child cluster, with the current
-- cluster pushed at the front of the visited list.
junctionTreeProperty :: JTree Cluster CPT -> [Cluster] -> Cluster -> Bool
junctionTreeProperty t path c =
    checkPath c path && all checkChild childClusters
  where
    childClusters = [separatorChild t link | link <- nodeChildren t c]
    checkChild = junctionTreeProperty t (c : path)

-- | Check that the intersection of C with any parent is included in all
-- clusters between the parent and C.
--
-- The list is read from its head: as called from 'junctionTreeProperty' the
-- head is the most recently visited cluster. For the list @a b c d@ and the
-- cluster @x@, the checks are @a ^ x@ inside @a@, @b ^ x@ inside @a@ and @b@,
-- @c ^ x@ inside @a@, @b@ and @c@, and so on.
checkPath :: Cluster -> [Cluster] -> Bool
checkPath _ [] = True
checkPath (Cluster current) ancestors =
    and [ all (Set.isSubsetOf shared) upTo
        | (shared, upTo) <- zip sharedWithAncestors pathPrefixes
        ]
  where
    unwrap (Cluster s) = s
    ancestorSets = map unwrap ancestors
    -- a ^ x, b ^ x, c ^ x, d ^ x
    sharedWithAncestors = map (Set.intersection current) ancestorSets
    -- a, ab, abc, abcd
    pathPrefixes = drop 1 (inits ancestorSets)
{-

Moral graph

-}
-- | Get the parents of a vertex: the start vertices of its ingoing edges.
--
-- Fails (through 'fromJust') when 'ingoing' returns 'Nothing' or when the
-- start vertex of one of the edges cannot be found.
parents :: DirectedGraph g => g a b -> Vertex -> [Vertex]
parents g v = fromJust $ do
    incoming <- ingoing g v
    mapM (startVertex g) incoming

-- | Get the children of a vertex: the end vertices of its outgoing edges.
--
-- Fails (through 'fromJust') when 'outgoing' returns 'Nothing' or when the
-- end vertex of one of the edges cannot be found.
children :: DirectedGraph g => g a b -> Vertex -> [Vertex]
children g v = fromJust $ do
    leaving <- outgoing g v
    mapM (endVertex g) leaving

-- | Connect all the listed nodes which are not yet connected, adding an edge
-- with value @()@ for each new connection.
--
-- Whether two nodes are already linked is tested against the graph given as
-- argument, not against the graph being built, so an edge is added for both
-- orderings @(x, y)@ and @(y, x)@ of a pair that is not linked.
connectAllNonAdjacentNodes :: Graph g
                           => [Vertex]  -- ^ List of nodes to connect
                           -> g () b -- ^ Graph containing the nodes
                           -> g () b
connectAllNonAdjacentNodes nodes originGraph =
    let addEmptyEdge g e = addEdge e () g
    in
    foldl' addEmptyEdge originGraph [edge x y | x <- nodes, y <- nodes, x /= y, not (isLinkedWithAnEdge originGraph x y)]

-- | Takes a directed graph, a vertex and a value of the graph's vertex-value
-- type @b@ (presumably the value stored at that vertex — confirm against the
-- caller), and returns a graph of the same type.
--
-- NOTE(review): no defining equation for 'addMissingLinks' follows this
-- signature in the visible source — confirm that the definition exists.
addMissingLinks :: DirectedGraph g => g () b -> Vertex -> b -> g () b

-- | Convert the graph to an undirected form
convertToUndirected :: (FoldableWithVertex g, Graph g, NamedGraph g, NamedGraph g',UndirectedGraph g')
=> g  () b
-> g' () b
convertToUndirected m =
let addVertexWithLabel g v dat  =
let theName = fromJust \$ vertexLabel m v
in