{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{- | Algorithms for factor elimination

-}
module Bayes.FactorElimination(
    -- * Moral graph
      moralGraph
    -- * Triangulation
    , nodeComparisonForTriangulation
    , numberOfAddedEdges
    , triangulate
    -- * Junction tree
    , minimumSpanningTree
    , createClusterGraph
    , Cluster
    , createJunctionTree
    , JunctionTree
    -- * Shenoy-Shafer message passing
    , collect 
    , distribute
    , posterior 
    -- * Evidence
    , clearEvidence
    , updateEvidence
    -- * Test 
    , junctionTreeProperty_prop
    , createVerticesJunctionTree
    , VertexCluster
    ) where

import Bayes
import qualified Data.Foldable as F
import Data.Maybe(fromJust,mapMaybe,isJust)
import Control.Monad(mapM)
import Bayes.Factor hiding (isEmpty)
import Data.Function(on)
import Data.List(minimumBy,maximumBy,inits)
import qualified Data.Set as Set
import qualified Data.Map as Map
import qualified Data.Functor as Functor
import qualified Data.Tree as T 

import Test.QuickCheck hiding ((.||.), collect)
import Test.QuickCheck.Arbitrary

--import Debug.Trace
--debug s a = trace (s ++ " " ++ show a ++ "\n") a

{-
 
Comparison functions for graph triangulation

-}

-- | Number of fill-in edges added when connecting all the neighbors of a
-- vertex together, i.e. the number of unordered pairs of distinct neighbors
-- which are not already linked.
--
-- Each missing edge is counted once (pairs are taken with @x < y@); counting
-- ordered pairs would report twice the real number of edges. Fails (via
-- 'fromJust') if the vertex is not in the graph.
numberOfAddedEdges :: UndirectedGraph g 
                   => g a b 
                   -> Vertex 
                   -> Int 
numberOfAddedEdges g v =
  let nodes = fromJust $ neighbors g v
  in length [() | x <- nodes, y <- nodes, x < y, not (isLinkedWithAnEdge g x y)]

-- | Weight of a vertex: the dimension of the factor stored on it.
-- Fails (via 'fromJust') if the vertex carries no value.
weight :: (UndirectedGraph g, Factor f)
       => g a f 
       -> Vertex 
       -> Int 
weight g v = factorDimension (fromJust (vertexValue g v))

-- | Lexicographic combination of two comparison functions: the second one is
-- only consulted when the first one returns 'EQ'.
(.||.) :: (a -> a -> Ordering)
       -> (a -> a -> Ordering) 
       -> (a -> a -> Ordering)
(primary .||. secondary) x y =
  case primary x y of
    EQ -> secondary x y
    decided -> decided

-- | Node selection comparison function used for triangulating the graph.
-- Vertices are ordered by the number of fill-in edges their elimination adds
-- ('numberOfAddedEdges'); ties are broken by the dimension of their factor.
nodeComparisonForTriangulation :: (UndirectedGraph g, Factor f)
                               => g a f
                               -> Vertex 
                               -> Vertex 
                               -> Ordering 
nodeComparisonForTriangulation g = byFillIn .||. byWeight
  where
    byFillIn = compare `on` numberOfAddedEdges g
    byWeight = compare `on` weight g

{-

Graph triangulation

-}

-- | A cluster containing only the vertices and not yet the factors.
newtype VertexCluster = VertexCluster (Set.Set Vertex)
  deriving (Eq)

-- | Underlying vertex set of a 'VertexCluster'.
fromVertexCluster :: VertexCluster -> Set.Set Vertex
fromVertexCluster (VertexCluster vs) = vs

-- | Shows the vertices as a list, in ascending order.
instance Show VertexCluster where 
  show (VertexCluster vs) = show (Set.toAscList vs)

-- | Triangulate a graph by greedy vertex elimination.
--
-- At each step the smallest vertex according to the criterion is chosen;
-- the vertex and its neighbours are pairwise connected (in the working graph
-- and in the result graph), recorded as a cluster, and the vertex is then
-- removed from the working graph. Clusters contained in another one are
-- finally filtered with 'keepMaximalClusters'.
triangulate :: Graph g
            => (Vertex -> Vertex -> Ordering) -- ^ Criterion function for triangulation
            -> g () b
            -> ([VertexCluster],g () b) -- ^ Returns the clusters and the triangulated graph
triangulate cmp g = eliminate g g []
  where
    -- work:   graph where vertices are eliminated (shrinks at each step)
    -- result: copy of the input graph that only receives the fill-in edges
    -- found:  clusters recorded so far, most recent first
    eliminate work result found
      | hasNoVertices work = (keepMaximalClusters (reverse found), result)
      | otherwise =
          let chosen = minimumBy cmp (allVertices work)
              family = chosen : fromJust (neighbors work chosen)
              withUnit e gr = addEdge e () gr
              (work', result') = connectAllNodesWith work result withUnit withUnit family
          in eliminate (removeVertex chosen work') result' (VertexCluster (Set.fromList family) : found)


-- | Find a cluster containing the given one (i.e. a superset of it).
findContainingCluster :: VertexCluster -- ^ Cluster processed
                      -> [VertexCluster] -- ^ Cluster list where to look for a containing cluster
                      -> (Maybe VertexCluster,[VertexCluster]) -- ^ The first containing cluster (if any) and the list without it
findContainingCluster cluster l =
  let containsCluster s = fromVertexCluster cluster `Set.isSubsetOf` fromVertexCluster s
  in case break containsCluster l of
       (_, []) -> (Nothing, l)
       (prefix, found : rest) -> (Just found, prefix ++ rest)


-- | Remove clusters already contained in a previous cluster.
--
-- Clusters are scanned in order. When the current cluster is a subset of an
-- earlier kept cluster (see 'findContainingCluster'), the current cluster is
-- dropped and the containing cluster is moved to the end of the kept list
-- (the position of the dropped one). Otherwise the current cluster is kept.
keepMaximalClusters :: [VertexCluster] -> [VertexCluster]
keepMaximalClusters [] = []
-- head/tail are safe here: the empty list is handled by the equation above.
keepMaximalClusters l = checkIfMaximal [] (head l) (tail l)
 where 
  -- reversedPrefix: clusters kept so far, most recent first
  -- current: cluster being examined
  -- third argument: clusters still to examine
  -- Last cluster: build the final list in forward order.
  checkIfMaximal reversedPrefix current [] = 
    case findContainingCluster current (reverse reversedPrefix) of 
      (Nothing,_) -> reverse (current:reversedPrefix) 
      -- l is the kept list without r; r is appended at its end.
      (Just r,l) -> reverse (r:reverse l)
  -- suffix is non-empty here (the [] case is matched above).
  checkIfMaximal reversedPrefix current suffix = 
    case findContainingCluster current (reverse reversedPrefix) of 
      (Nothing,_) -> checkIfMaximal (current:reversedPrefix) (head suffix) (tail suffix)
      -- Drop current; move its containing cluster r to the most recent position.
      (Just r,l) -> checkIfMaximal (r:reverse l) (head suffix) (tail suffix)


-- | Create the complete cluster graph: one vertex per cluster (numbered from
-- 0 in list order) and an edge between every pair of distinct clusters,
-- weighted by the number of vertices the two clusters share.
createClusterGraph :: UndirectedGraph g
                   => [VertexCluster] 
                   -> g Int VertexCluster
createClusterGraph clusters = foldr linkPair withVertices pairs
  where
    numbered = zip clusters (map Vertex [0..])
    withVertices = foldr (\(cl, vx) gr -> addVertex vx cl gr) emptyGraph numbered
    -- Both orientations of every pair are generated
    pairs = [(p, q) | p <- numbered, q <- numbered, p /= q]
    shared a b = Set.size (fromVertexCluster a `Set.intersection` fromVertexCluster b)
    linkPair ((ca, va), (cb, vb)) gr = addEdge (edge va vb) (shared ca cb) gr


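{-

Spanning tree of the cluster graph using Prim's algorithm.
Although the entry point is named 'minimumSpanningTree', at each step the
edge with the LARGEST weight (separator size) is chosen.
  
-}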

-- | Rose tree with values on edges: every child is paired with the value of
-- the edge linking it to its parent.
data Tree b a = Node a [(b, Tree b a)]
  deriving (Eq)

{-

Implementation of show for the tree
 
-}
-- | Convert a junction tree into a "Data.Tree" tree labelled with the shown clusters.
standardHaskellTree :: (Show f, Show b) => Tree b (JTNodeValue f) -> T.Tree String 
standardHaskellTree t@(Node _ subtrees) =
  T.Node (show (nodeCluster t)) [ standardHaskellTree child | (_, child) <- subtrees ]

-- | Convert a vertex-cluster tree into a "Data.Tree" tree labelled with the shown clusters.
standardVertexTree :: Tree () VertexCluster -> T.Tree String 
standardVertexTree (Node cluster subtrees) =
  T.Node (show cluster) [ standardVertexTree child | (_, child) <- subtrees ]
  
-- | Difference-list rendering of the node factors and edge values of a tree.
-- The node factor is appended first; then, starting from the LAST child,
-- each edge value followed by the rendering of the corresponding subtree.
showFactorsAndEdges :: (Show f, Show b) => Tree b (JTNodeValue f) -> (String -> String) 
showFactorsAndEdges (Node value subtrees) =
  foldr (.) id [ showFactorsAndEdges child . (++ show sep) | (sep, child) <- subtrees ]
    . (++ show (nodeValueFactor value))

-- | Debug display: the cluster tree followed by the factors and edge values.
instance (Show f, Show b) => Show (Tree b (JTNodeValue f)) where 
  show t = concat
    [ "JUNCTION TREE\n"
    , T.drawTree (standardHaskellTree t)
    , "\n"
    , showFactorsAndEdges t ""
    , "\n------\n"
    ]

-- | Debug display of a tree of vertex clusters.
instance Show (Tree () VertexCluster) where 
  show t = concat ["JUNCTION TREE\n", T.drawTree (standardVertexTree t), "\n"]

-- | Map over node values; edge values are left unchanged.
instance Functor.Functor (Tree b) where 
  fmap f (Node a l) = Node (f a) [ (e, fmap f child) | (e, child) <- l ]

-- | Expand a tree (encoded as a list of edges)
-- by adding vertices and keeping track of the vertices which have
-- already been added (Prim's algorithm).
--
-- At each step, among all linked pairs (vertex in the tree, vertex not yet
-- in the tree), the pair whose edge has the LARGEST cost (edge value) is
-- chosen, so the result is a maximum-weight spanning tree.
-- The new edge is built as @edge treeVertex outVertex@: its first end point
-- is the vertex that was already in the tree.
--
-- 'maximumBy' fails if, at some step, no remaining vertex is linked to the
-- tree (graph not connected). 'edgeCost' fails if an edge has no value.
expand :: UndirectedGraph g 
       => g Int f 
       -> [Edge] -- ^ List of edges
       -> [Vertex] -- ^ Vertices in Tree
       -> [Vertex] -- ^ Vertices to add
       -> [Edge] -- ^ Edges of the spanning tree, most recently added first
expand g theEdges inTree remaining | null remaining = theEdges
                                   | otherwise = 
                                        -- Candidate edges: every linked (in tree, outside tree) pair
                                        let (treeVertex,outVertex) = maximumBy (compare `on` (edgeCost g)) $ [(vin,vout) | vin <- inTree, vout <-remaining,isLinkedWithAnEdge g vin vout]
                                        in 
                                        expand g (edge treeVertex outVertex : theEdges) (outVertex : inTree)
                                          (filter (/= outVertex) remaining)

  where 
    -- Weight of the edge linking the two vertices
    edgeCost g (va,vb) = fromJust $ edgeValue g (edge va vb)

-- | Tree made of a single node.
leaf :: a -> Tree b a
leaf x = Node x []

-- | Pair an edge value with a subtree.
treeEdge :: c -> t -> (c, t)
treeEdge = (,)

-- | Create a tree based upon a description with edges: the map associates a
-- vertex with the list of its children. Vertices absent from the map are
-- leaves. The map must describe a tree (a cycle would recurse forever).
createTreeFromMap :: Vertex -- ^ Root vertex
                  -> Map.Map Vertex [Vertex] -- ^ Tree branches
                  -> Tree () Vertex 
createTreeFromMap root branches = grow root
  where
    grow v = Node v [ ((), grow child) | child <- Map.findWithDefault [] v branches ]
                   
-- | Spanning tree of a weighted graph using Prim's algorithm.
--
-- Despite the name, 'expand' keeps the edges of LARGEST weight, so this is a
-- maximum-weight spanning tree. The root is the vertex returned by
-- 'someVertex'; the tree nodes hold the vertex values of the graph.
-- Fails (via 'fromJust') on a graph without vertices.
minimumSpanningTree :: UndirectedGraph g 
                    => g Int f 
                    -> Tree () f 
minimumSpanningTree g = 
  let startRoot = fromJust $ someVertex g 
      remainingVertices = filter (/= startRoot) (allVertices g)
      foundEdges = expand g [] [startRoot] remainingVertices
      -- parent -> children map. NOTE(review): assumes edgeEndPoints returns the
      -- end points in the order given to 'edge' (tree vertex first) — confirm.
      m = Map.fromListWith (++) . map ((\(a,b) -> (a,[b])) . edgeEndPoints) $ foundEdges
      theTree = createTreeFromMap startRoot m
  in 
  -- Replace vertex identifiers with the values stored in the graph
  Functor.fmap (fromJust . vertexValue g) theTree
      
   
{-

Junction tree algorithm

-}

-- | Check if all variables of a factor are included in a cluster
-- (variables are compared through their vertices).
vertexClusterIsContainingFactor :: Factor f => VertexCluster -> f -> Bool 
vertexClusterIsContainingFactor c f =
  all (`Set.member` fromVertexCluster c) [ variableVertex dv | dv <- factorVariables f ]

-- | Check if a discrete variable belongs to a cluster.
clusterIsContainingVariable :: DV -> Cluster  -> Bool 
clusterIsContainingVariable v (Cluster vars) = Set.member v vars

-- | Separator between a node and one of its children. The constructor tells
-- how many messages went through it:
--
-- * 'NoMessage': none yet;
-- * 'Collect': the upward message, built by 'collect';
-- * 'Distribute': the upward message and the downward message built by 'distribute'.
data Separator f
  = NoMessage !Cluster
  | Collect !Cluster !f
  | Distribute !Cluster !f !f -- Upward and downward message
  deriving (Eq)

-- | Debug display of a separator and the messages it holds.
instance Show f => Show (Separator f) where 
  show sep = case sep of
    NoMessage c -> "NoMessage: " ++ show c
    Collect c u -> concat ["Collect: ", show c, "\n", "\n <----- \n", show u, "\n"]
    Distribute c u d -> concat ["Distribute: ", show c, "\n <----- \n", show u, "\n", " -----> \n", show d, "\n"]


-- | Evidence set on a node. It is just a factor: this synonym only documents
-- intent and is not distinguished from @f@ by the type checker.
type Evidence f = f

-- | Value of a junction tree node: its cluster, the evidence for the cluster
-- (the scalar factor 1.0 when none is set) and the node potential (product of
-- the network factors assigned to the cluster).
data JTNodeValue f = JTNodeValue !Cluster !(Evidence f) !f
  deriving (Eq, Show)

-- | Cluster of discrete variables.
-- Discrete variables instead of vertices are needed because the
-- factors are using 'DV' and we need to find
-- which factors must be contained in a given cluster.
newtype Cluster = Cluster (Set.Set DV)
  deriving (Eq, Show)

-- | Variables of a cluster, in ascending order.
fromCluster :: Cluster -> [DV]
fromCluster (Cluster vars) = Set.toAscList vars

-- | Convert a vertex cluster into a 'DV' cluster: each vertex is replaced by
-- the main variable of the factor it holds. Vertices without a value are skipped.
vertexClusterToCluster :: (Factor f , Graph g)
                       => g e f 
                       -> VertexCluster 
                       -> Cluster 
vertexClusterToCluster g c =
  Cluster (Set.fromList [ factorMainVariable f | Just f <- map (vertexValue g) (Set.toList (fromVertexCluster c)) ])

-- | Vertices contained in a cluster, in ascending order.
clusterVertices :: VertexCluster -> [Vertex]
clusterVertices (VertexCluster vs) = Set.toAscList vs

-- | Factors held by the vertices of a cluster whose variables all belong to the cluster.
findFactorsForCluster :: (Factor f , Graph g)
                      => BayesianNetwork g f
                      -> VertexCluster
                      -> [f]
findFactorsForCluster g c =
  [ f | v <- clusterVertices c, Just f <- [vertexValue g v], vertexClusterIsContainingFactor c f ]

-- | The junction tree: nodes hold a 'JTNodeValue' (cluster, evidence and
-- potential) and each edge to a child holds the 'Separator' between the two
-- clusters with the messages that went through it.
type JunctionTree f = Tree (Separator f) (JTNodeValue f)

-- | Build the node value of a cluster.
--
-- The potential is the product of the factors found for the cluster whose
-- vertex is still in the given set of unused vertices; those vertices are
-- removed from the returned set so each factor is used by one cluster only.
-- The evidence is initialised to the scalar factor 1.0.
mkNodePotential :: (Graph g, Factor f, Show f)
                => BayesianNetwork g f 
                -> VertexCluster 
                -> Set.Set Vertex
                -> (JTNodeValue f, Set.Set Vertex)
mkNodePotential g c set =
  let -- Candidate factors paired with the vertex holding them
      candidates = [ (variableVertex (factorMainVariable f), f) | f <- findFactorsForCluster g c ]
      -- Only the factors not already assigned to another cluster
      fresh = [ vf | vf@(v, _) <- candidates, v `Set.member` set ]
      stillUnused = set `Set.difference` Set.fromList [ v | (v, _) <- fresh ]
      potential = factorProduct [ f | (_, f) <- fresh ]
  in (JTNodeValue (vertexClusterToCluster g c) (factorFromScalar 1.0) potential, stillUnused)

-- | Generate the evidence potential for a given cluster from the assignments
-- whose variable belongs to the cluster (the result of 'evidenceFrom').
evidenceForCluster :: Factor f => DVISet Int -> Cluster -> Maybe (Evidence f)
evidenceForCluster assignments (Cluster vars) =
  evidenceFrom [ a | a <- assignments, instantiationVariable a `Set.member` vars ]


-- | Cluster stored at the root of a tree.
nodeCluster :: Tree a (JTNodeValue f) -> Cluster 
nodeCluster (Node value _) = case value of
  JTNodeValue c _ _ -> c

-- | Cluster without any variable (used as the separator above the root).
emptyCluster :: Cluster 
emptyCluster = Cluster (Set.fromList [])

-- | Potential stored in a node value.
nodeValueFactor :: JTNodeValue f -> f
nodeValueFactor (JTNodeValue _ _ potential) = potential

-- | Evidence stored in a node value.
nodeValueEvidence :: JTNodeValue f -> Evidence f
nodeValueEvidence (JTNodeValue _ ev _) = ev

-- | Replace the evidence of a node value.
nodeValueWithNewEvidence :: JTNodeValue f -> Evidence f -> JTNodeValue f
nodeValueWithNewEvidence (JTNodeValue cl _ potential) ev' = JTNodeValue cl ev' potential

-- | Reset the evidence of a node value to the scalar factor 1.0.
clearNodeValueEvidence :: Factor f => JTNodeValue f -> JTNodeValue f
clearNodeValueEvidence (JTNodeValue cl _ potential) = JTNodeValue cl (factorFromScalar 1.0) potential

-- | Get the cluster for a separator, whatever its message state.
separatorCluster :: Separator f -> Cluster 
separatorCluster sep = case sep of
  NoMessage c -> c
  Collect c _ -> c
  Distribute c _ _ -> c


-- | Upward message of a separator, if the collect phase went through it.
upMessage :: Separator f -> Maybe f
upMessage sep = case sep of
  Collect _ u -> Just u
  Distribute _ u _ -> Just u
  NoMessage _ -> Nothing

-- | Downward message of a separator, if the distribute phase went through it.
downMessage :: Separator f -> Maybe f
downMessage sep = case sep of
  Distribute _ _ d -> Just d
  _ -> Nothing

-- | Cluster of the separator between a parent and a child cluster: the
-- variables (see 'vertexClusterToCluster') of the vertices shared by both.
computeSeparatorCluster :: (Factor f, Graph g) 
                        => BayesianNetwork g f 
                        -> VertexCluster 
                        -> VertexCluster
                        -> Cluster
computeSeparatorCluster g parent child =
  let sharedVertices = Set.intersection (fromVertexCluster child) (fromVertexCluster parent)
  in vertexClusterToCluster g (VertexCluster sharedVertices)

-- | Depth-first (pre-order) rebuild of a tree while threading a state value.
--
-- The node function is applied to a node before its children; the state it
-- returns is passed to the first child subtree, the state returned by that
-- subtree to the next child, and so on left to right. The final state is
-- returned with the new tree.
-- The edge function receives the ORIGINAL (not yet updated) parent and child
-- values together with the edge value.
dfs :: (n -> n -> e -> e') -- ^ Parent value, child value and their edge -> new edge value
    -> (n -> a -> (n', a)) -- ^ Node and current state -> new node value and new state
    -> Tree e n  -- ^ Tree to traverse
    -> a -- ^ Start value
    -> (Tree e' n', a) -- ^ New tree and final state
dfs edgef nodef n@(Node nodevalue []) current = 
  let (newnodevalue, newval) = nodef nodevalue current 
  in 
  (Node newnodevalue [],newval) 
dfs edgef nodef n@(Node nodevalue children) current =
  let (newnodevalue, newval) = nodef nodevalue current 
      applyEdgeFunction (e,Node childvalue _) = edgef nodevalue childvalue e
      applyToChildren childrenNode val = dfs edgef nodef childrenNode val
      edges' = map applyEdgeFunction children
      -- Left-to-right traversal of the children threading the state;
      -- r accumulates the rebuilt subtrees in reverse order.
      recurseOnChildren s r [] = (s,reverse r)
      recurseOnChildren s r (a:l) = 
        let (a',s') = applyToChildren a s
        in 
        recurseOnChildren s' (a':r) l 
      (lastval,newSubTrees) = recurseOnChildren newval [] (map snd children)
  in 
  (Node newnodevalue (zip edges' newSubTrees),lastval)

-- | Edge update used by 'setFactors': the edge becomes an empty separator
-- ('NoMessage') over the variables shared by the parent and child clusters.
setFactorEdgeUpdate :: (Graph g, Factor f) 
                    => BayesianNetwork g f 
                    -> VertexCluster 
                    -> VertexCluster
                    -> () 
                    -> Separator f
setFactorEdgeUpdate net parentCluster childCluster _ = NoMessage separator
  where
    separator = computeSeparatorCluster net parentCluster childCluster

-- | Node update used by 'setFactors': builds the node potential and removes
-- the vertices whose factor was used from the set of unused vertices.
setFactorNodeUpdate :: (Graph g, Factor f, Show f) 
                    => BayesianNetwork g f 
                    -> VertexCluster
                    -> Set.Set Vertex 
                    -> (JTNodeValue f, Set.Set Vertex)
setFactorNodeUpdate = mkNodePotential

-- | Assign the network factors to the clusters of a cluster tree.
-- The tree is visited depth-first (parent before children), so a factor goes
-- to the first visited cluster able to hold it. Also returns the set of
-- vertices whose factor has not been placed.
setFactors :: (Graph g, Factor f, Show f)
           => BayesianNetwork g f -- ^ Bayesian graph
           -> Tree () VertexCluster  -- ^ Cluster tree with no factors
           -> Set.Set Vertex
           -> (JunctionTree f,Set.Set Vertex) -- ^ Initialized junction tree
setFactors g tree unused = dfs (setFactorEdgeUpdate g) (setFactorNodeUpdate g) tree unused


-- | Create a junction tree with only the clusters and no factors:
-- moralize, triangulate, build the cluster graph and keep a maximum-weight
-- spanning tree of it.
--
-- NOTE(review): the criterion is evaluated on the moral graph, not on the
-- graph shrinking during elimination.
createVerticesJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g)
                           => (UndirectedSG () b -> Vertex -> Vertex -> Ordering) -- ^ Weight function on the moral graph
                           -> g () b -- ^ Input directed graph
                           -> Tree () VertexCluster -- ^ Junction tree
createVerticesJunctionTree cmp g = minimumSpanningTree clusterGraph
  where
    moral = moralGraph g
    clusters = fst (triangulate (cmp moral) moral)
    clusterGraph = createClusterGraph clusters :: UndirectedSG Int VertexCluster

-- | Create a junction tree for a Bayesian network.
--
-- Builds the cluster tree, assigns each factor to a cluster and runs both
-- message passing phases ('collect' then 'distribute').
createJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g, Factor f, Show f)
                  => (UndirectedSG () f -> Vertex -> Vertex -> Ordering) -- ^ Weight function on the moral graph
                  -> BayesianNetwork g f -- ^ Input directed graph
                  -> JunctionTree f -- ^ Junction tree
createJunctionTree cmp g = distribute Nothing (collect initialTree)
  where
    clusterTree = createVerticesJunctionTree cmp g
    -- A vertex holds one factor, so vertices identify the factors still to be
    -- placed in a cluster. NOTE(review): the set of unplaced ones is discarded.
    (initialTree, _) = setFactors g clusterTree (Set.fromList (allVertices g))


-- | Collect phase for a subtree and the separator linking it to its parent.
--
-- The subtrees are processed first; the node potential is the product of the
-- node evidence, the node factor and the upward messages of the children,
-- projected on the separator cluster to give the upward message. The
-- separator is always replaced by a 'Collect' one, whatever its previous state.
collectMessages :: Factor f => (Separator f , JunctionTree f) -> (Separator f , JunctionTree f)
collectMessages (separator, Node nc []) = 
  let sc = separatorCluster separator
      newPotential = factorProduct [nodeValueFactor nc,nodeValueEvidence nc] 
      newMessage = factorProjectTo (fromCluster sc) newPotential
  in
  (Collect sc newMessage, Node nc []) -- Leaf: message built from the node factor and evidence only
collectMessages (separator,(Node nc l)) = 
  let sc = separatorCluster separator
      messagesFromSubTrees = map collectMessages l 
      -- Children always come back with 'Collect' separators, so every upMessage is Just
      newPotential = factorProduct (nodeValueEvidence nc:nodeValueFactor nc:(mapMaybe (upMessage . fst) messagesFromSubTrees))
      newMessage = factorProjectTo (fromCluster sc) newPotential 
  in 
  (Collect sc newMessage, Node nc messagesFromSubTrees)

-- | Collect phase of the junction tree. The root is given a separator with
-- an empty cluster; the message computed for it is discarded.
collect :: Factor f => JunctionTree f -> JunctionTree f 
collect t = snd (collectMessages (NoMessage emptyCluster, t))

-- | True when the root clusters of two trees differ.
notSameCluster :: Tree a (JTNodeValue f) -> Tree b (JTNodeValue f1) -> Bool
notSameCluster a b = not (nodeCluster a == nodeCluster b)

-- | Distribute phase of the junction tree.
--
-- Must be run after 'collect': every child separator must be a 'Collect'
-- one. The root is called with 'Nothing'; other nodes receive the separator
-- linking them to their parent (a 'Distribute' separator).
--
-- The message sent to a child is the product of the node factor, the node
-- evidence, the message received from the parent and the upward messages of
-- all the OTHER children, projected on the child separator. Siblings are
-- excluded by their position in the child list (not by comparing clusters),
-- so two children with equal clusters are still handled correctly.
distribute :: Factor f => Maybe (Separator f) -> JunctionTree f -> JunctionTree f 
distribute _ n@(Node _ []) = n
distribute down (Node nc l) = 
  let receivedDownMessage = maybe (factorFromScalar 1.0) (fromJust . downMessage) down
      indexedChildren = zip [0 :: Int ..] l
      -- Upward messages of every child except the one at position k
      upMessagesForSendingTo k =
        fromJust . mapM (upMessage . fst . snd) . filter ((/= k) . fst) $ indexedChildren
      newPotential k = factorProduct (nodeValueFactor nc:nodeValueEvidence nc:receivedDownMessage:upMessagesForSendingTo k)
      newMessage sc k = factorProjectTo (fromCluster sc) (newPotential k)
      distributeMessage (k, (Collect sc dm, child)) = 
        let newSeparator = Distribute sc dm (newMessage sc k)
        in 
        (newSeparator, distribute (Just newSeparator) child)
      distributeMessage _ = error "Distribute message can only update a collect phase message"
  in 
  Node nc (map distributeMessage indexedChildren)

-- | Depth-first (pre-order) search in a tree. Returns the first node
-- satisfying the predicate together with the edge linking it to its parent
-- (the given edge for the root).
findInTree :: (Tree edge a -> Bool) -> Maybe edge -> Tree edge a -> Maybe (Maybe edge,Tree edge a)
findInTree cmp e n@(Node _ l)
  | cmp n = Just (e, n)
  | otherwise = firstFound [ findInTree cmp (Just e') child | (e', child) <- l ]
  where
    -- First Just of a lazily evaluated list of results
    firstFound = foldr (\r rest -> maybe rest Just r) Nothing


-- | Compute the marginal posterior of a variable (the marginal prior when no
-- evidence is set on the junction tree).
--
-- The first cluster containing the variable (pre-order search) is used.
-- Returns 'Nothing' when no cluster contains the variable. The separator
-- messages must already be computed ('collect' and 'distribute'): a missing
-- upward message makes 'fromJust' fail.
posterior :: Factor f => JunctionTree f -> DV -> Maybe f
posterior t v = do
  (parentSeparator, Node n kids) <- findInTree (clusterIsContainingVariable v . nodeCluster) Nothing t
  let fromParent = maybe (factorFromScalar 1.0) id (parentSeparator >>= downMessage)
      fromChildren = fromJust (mapM (upMessage . fst) kids)
      joint = factorProduct (fromParent : nodeValueEvidence n : nodeValueFactor n : fromChildren)
  return (normedFactor (factorProjectTo [v] joint))

-- | Apply a node modification to every node of the tree, top-down.
-- The recursion continues on the children of the MODIFIED node.
applyEvidenceWith :: (JunctionTree f -> JunctionTree f) -- ^ Node modification function. Only change node value. Not the children
                  -> JunctionTree f -- ^ Input tree
                  -> JunctionTree f
applyEvidenceWith nodeChange tree =
  case tree of
    Node _ [] -> nodeChange tree
    _ ->
      let Node value' kids' = nodeChange tree
      in Node value' [ (sep, applyEvidenceWith nodeChange child) | (sep, child) <- kids' ]

-- | Change the evidence for the root node of a tree.
-- When 'evidenceForCluster' returns 'Nothing' the node keeps its previous
-- evidence. NOTE(review): whether that happens for clusters without any
-- assigned variable depends on 'evidenceFrom' — confirm the intended semantics.
evidenceWith :: Factor f => DVISet Int -> JunctionTree f -> JunctionTree f
evidenceWith assignments t@(Node n l) =
  Node (maybe n (nodeValueWithNewEvidence n) (evidenceForCluster assignments (nodeCluster t))) l

-- | Reset the evidence of the root node of a tree to the scalar factor 1.0.
clearNodeEvidence :: Factor f => Tree e (JTNodeValue f) -> Tree e (JTNodeValue f)
clearNodeEvidence (Node value kids) = Node (clearNodeValueEvidence value) kids

-- | Remove evidence in the junction tree and recompute all the messages.
clearEvidence :: Factor f => JunctionTree f -> JunctionTree f
clearEvidence t = distribute Nothing (collect (applyEvidenceWith clearNodeEvidence t))

-- | Update evidence in the tree and recompute all the messages.
updateEvidence :: Factor f => DVISet Int -> JunctionTree f -> JunctionTree f
updateEvidence assignments t =
  distribute Nothing (collect (applyEvidenceWith (evidenceWith assignments) t))

-- | Used to implement quickcheck.
-- The junction tree property is the property that CA intersection CB is included in all clusters in the path
-- from CA to CB.
--
-- Every node, leaves included, is checked against each of its ancestors with
-- 'checkPath'. The first argument lists the ancestors of the node, nearest first.
-- NOTE(review): only ancestor/descendant pairs are checked; pairs of nodes in
-- different branches are not.
junctionTreeProperty :: [VertexCluster] -> Tree () VertexCluster -> Bool
junctionTreeProperty path (Node c l) =
  checkPath c (reverse path) && all (junctionTreeProperty (c:path) . snd) l

-- | QuickCheck property: for non-empty connected graphs with edges, the tree
-- built by 'createVerticesJunctionTree' (ordering vertices by fill-in only)
-- satisfies 'junctionTreeProperty'.
junctionTreeProperty_prop :: DirectedSG () String -> Property 
junctionTreeProperty_prop g =
  nonTrivial ==> junctionTreeProperty [] (createVerticesJunctionTree byFillIn g)
  where
    nonTrivial = not (isEmpty g) && not (hasNoEdges g) && connectedGraph g
    byFillIn ug = compare `on` numberOfAddedEdges ug

-- | Check that the intersection of C with any ancestor is included in every
-- cluster on the path between that ancestor and C.
--
-- The ancestor list is ordered from the root down to the parent of C, so the
-- clusters between the ancestor at position k and C are the suffix of the
-- list starting at position k (the ancestor itself up to the parent).
checkPath :: VertexCluster -> [VertexCluster] -> Bool 
checkPath c l = 
  let parentSets = map fromVertexCluster l
      allIntersections = map (Set.intersection (fromVertexCluster c)) parentSets
      -- Non-empty suffixes: [a_k, ..., parent] for each ancestor a_k
      suffixes [] = []
      suffixes s@(_:rest) = s : suffixes rest
      pathsFromEachParent = suffixes parentSets
      isSubsetOfAllClusters i clusters = all (Set.isSubsetOf i) clusters
  in    
  and $ zipWith isSubsetOfAllClusters allIntersections pathsFromEachParent
{-

Moral graph

-}
-- | Get the parents of a vertex (start vertices of its ingoing edges).
-- Fails (via 'fromJust') if the vertex or one of the edges is unknown.
parents :: DirectedGraph g => g a b -> Vertex -> [Vertex]
parents g v = fromJust $ do
  inEdges <- ingoing g v
  mapM (startVertex g) inEdges

-- | Get the children of a vertex (end vertices of its outgoing edges).
-- Fails (via 'fromJust') if the vertex or one of the edges is unknown.
children :: DirectedGraph g => g a b -> Vertex -> [Vertex]
children g v = fromJust $ do
  outEdges <- outgoing g v
  mapM (endVertex g) outEdges

-- | Connect all the nodes which are not connected and apply the function f for each new connection.
-- The origin and dest graph must share the same vertex.
-- Links are tested in the ORIGINAL source graph for every ordered pair of
-- distinct nodes, so both orientations of a missing link may be added.
connectAllNodesWith :: (Graph g, Graph g') 
                    => g a b -- ^ Graph containing the nodes
                    -> g' a b -- ^ Graph to be modified
                    -> (Edge -> g a b -> g a b) -- ^ Function used to modify the source graph
                    -> (Edge -> g' a b -> g' a b) -- ^ Function used to modify a new graph
                    -> [Vertex]  -- ^ List of nodes to connect
                    -> (g a b,g' a b) -- ^ Result graph
connectAllNodesWith originGraph dstGraph g f nodes =
  foldr connect (originGraph, dstGraph) missingEdges
  where
    missingEdges = [edge x y | x <- nodes, y <- nodes, x /= y, not (isLinkedWithAnEdge originGraph x y)]
    connect e (src, dst) = (g e src, f e dst)

-- | Add the missing links between the parents of a vertex ("marry" them).
-- The vertex value is ignored; intended for use with 'foldrWithVertex'.
addMissingLinks :: DirectedGraph g => Vertex -> b -> g () b -> g () b
addMissingLinks v _ g =
  snd (connectAllNodesWith g g (\_ gr -> gr) (\e gr -> addEdge e () gr) (parents g v))


-- | Convert the graph to an undirected form: every vertex is copied with its
-- label and value, every edge is added with a unit value.
-- Fails (via 'fromJust') if a vertex has no label.
convertToUndirected :: (FoldableWithVertex g, Graph g, NamedGraph g, NamedGraph g',UndirectedGraph g')
                    => g  () b 
                    -> g' () b 
convertToUndirected m = foldr linkUnit withVertices (allEdges m)
  where
    withVertices = foldrWithVertex copyVertex emptyGraph m
    copyVertex v dat acc = addLabeledVertex (fromJust (vertexLabel m v)) v dat acc
    linkUnit e acc = addEdge e () acc

-- | Moral graph of a directed graph: the parents of every vertex are linked
-- together, then the graph is made undirected.
-- For the junction tree construction, only the vertices are needed during the intermediate steps.
-- So, the moral graph is returned without any edge data.
moralGraph :: (NamedGraph g, FoldableWithVertex g, DirectedGraph g) 
           => g () b -> UndirectedSG () b 
moralGraph g = convertToUndirected (foldrWithVertex addMissingLinks g g)