```{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{- | Algorithms for factor elimination

-}
module Bayes.FactorElimination(
-- * Moral graph
moralGraph
-- * Triangulation
, nodeComparisonForTriangulation
, triangulate
-- * Junction tree
, minimumSpanningTree
, createClusterGraph
, Cluster
, createJunctionTree
, JunctionTree
-- * Shenoy-Shafer message passing
, collect
, distribute
, posterior
-- * Evidence
, clearEvidence
, updateEvidence
-- * Test
, junctionTreeProperty_prop
, createVerticesJunctionTree
, VertexCluster
) where

import Bayes
import qualified Data.Foldable as F
import Data.Maybe(fromJust,mapMaybe,isJust)
import Bayes.Factor hiding (isEmpty)
import Data.Function(on)
import Data.List(minimumBy,maximumBy,inits)
import qualified Data.Set as Set
import qualified Data.Map as Map
import qualified Data.Functor as Functor
import qualified Data.Tree as T

import Test.QuickCheck hiding ((.||.), collect)
import Test.QuickCheck.Arbitrary

--import Debug.Trace
--debug s a = trace (s ++ " " ++ show a ++ "\n") a

{-

Comparison functions for graph triangulation

-}

-- | Number of edges that must be added to turn the neighbours of a vertex
-- into a clique, i.e. the fill-in edges created if the vertex is eliminated.
-- Each missing undirected edge is counted once. A vertex unknown to the
-- graph ('neighbors' returns 'Nothing') yields 0.
numberOfAddedEdges :: UndirectedGraph g
                   => g a b
                   -> Vertex
                   -> Int
numberOfAddedEdges g v =
    case neighbors g v of
        Nothing -> 0
        Just nodes ->
            -- x < y enumerates every unordered pair of neighbours exactly once
            -- (the previous x /= y counted each missing edge twice).
            length [() | x <- nodes, y <- nodes, x < y, not (isLinkedWithAnEdge g x y)]

-- | Weight of a vertex: the 'factorDimension' of the factor stored on it.
-- The vertex must exist in the graph ('fromJust' otherwise fails).
weight :: (UndirectedGraph g, Factor f)
       => g a f
       -> Vertex
       -> Int
weight g v = factorDimension (fromJust (vertexValue g v))

-- | Lexicographic combination of two comparison functions: use the first
-- one, and only when it reports 'EQ' fall back to the second one.
(.||.) :: (a -> a -> Ordering)
       -> (a -> a -> Ordering)
       -> (a -> a -> Ordering)
(.||.) primary secondary x y =
    case primary x y of
        EQ -> secondary x y
        decided -> decided

-- | Node selection comparison function used for triangulating the graph.
-- Vertices are ordered by the number of fill-in edges their elimination
-- would add ('numberOfAddedEdges'); ties are broken by 'weight'.
-- 'triangulate' eliminates the smallest vertex for this order first.
nodeComparisonForTriangulation :: (UndirectedGraph g, Factor f)
                               => g a f
                               -> Vertex
                               -> Vertex
                               -> Ordering
nodeComparisonForTriangulation g = fewestFillIns .||. smallestWeight
  where
    fewestFillIns = compare `on` numberOfAddedEdges g
    smallestWeight = compare `on` weight g

{-

Graph triangulation

-}

-- | A cluster containing only the vertices and not yet the factors.
-- 'fromVertexCluster' gives back the underlying vertex set.
newtype VertexCluster = VertexCluster { fromVertexCluster :: Set.Set Vertex }
    deriving (Eq)

-- | A cluster is displayed as the list of its vertices (ascending order).
instance Show VertexCluster where
    show = show . Set.toList . fromVertexCluster

-- | Triangulate a graph by vertex elimination.
--
-- At each step the smallest vertex according to the criterion is eliminated:
-- its neighbours are connected pairwise (fill-in edges) and the vertex
-- together with its neighbours forms a cluster. Clusters contained in an
-- earlier cluster are then dropped by 'keepMaximalClusters', so the returned
-- clusters are maximal.
--
-- NOTE(review): the criterion is a plain @Vertex -> Vertex -> Ordering@, so it
-- cannot observe the elimination graph as it shrinks. A criterion closed over
-- the initial graph ranks vertices by their initial neighbourhoods only.
triangulate :: Graph g
            => (Vertex -> Vertex -> Ordering) -- ^ Criterion function for triangulation
            -> g () b
            -> ([VertexCluster],g () b) -- ^ Returns the clusters and the triangulated graph
triangulate cmp g =
    -- At start, gsrc and gdst are the same graph.
    -- gsrc is where vertex elimination takes place (vertices are removed).
    -- The fill-in edges are also added to gdst, which is returned.
    let processAllNodes gsrc gdst l
            | hasNoVertices gsrc = (keepMaximalClusters (reverse l), gdst)
            | otherwise =
                let selectedNode = minimumBy cmp (allVertices gsrc)
                    theNeighbors = selectedNode : fromJust (neighbors gsrc selectedNode)
                    -- Make the neighbourhood of the eliminated vertex a clique
                    -- both in the elimination graph and in the result graph.
                    (gsrc', gdst') = connectAllNodesWith gsrc gdst
                                         (\e m -> addEdge e () m)
                                         (\e m -> addEdge e () m)
                                         theNeighbors
                    gsrc'' = removeVertex selectedNode gsrc'
                in
                processAllNodes gsrc'' gdst' (VertexCluster (Set.fromList theNeighbors) : l)
    in
    processAllNodes g g []

-- | Find the first cluster of a list that contains (is a superset of) the
-- given cluster.
-- Returns that cluster, if any, and the list with it removed. When there is no
-- such cluster, the list is returned unchanged.
findContainingCluster :: VertexCluster -- ^ Cluster processed
                      -> [VertexCluster] -- ^ Cluster list where to look for a containing cluster
                      -> (Maybe VertexCluster,[VertexCluster]) -- ^ Return the containing cluster and a new list without the containing cluster
findContainingCluster cluster l =
    let containsCluster s = fromVertexCluster cluster `Set.isSubsetOf` fromVertexCluster s
    in
    case break containsCluster l of
        (_, []) -> (Nothing, l)
        (prefix, found : rest) -> (Just found, prefix ++ rest)

-- | Remove every cluster that is contained in a cluster appearing earlier in
-- the list. When a cluster is absorbed, the containing cluster is moved to the
-- end of the clusters kept so far, which is the absorbed cluster's position.
keepMaximalClusters :: [VertexCluster] -> [VertexCluster]
keepMaximalClusters = reverse . F.foldl' absorb []
  where
    -- The accumulator holds the kept clusters in reverse order.
    absorb keptReversed current =
        case findContainingCluster current (reverse keptReversed) of
            (Nothing, _) -> current : keptReversed
            (Just container, others) -> container : reverse others

-- | Create the cluster graph.
-- There is one vertex per cluster, numbered from 0 in list order and carrying
-- the cluster as its value. Every pair of distinct clusters gets an edge whose
-- value is the size of their separator (number of shared vertices).
createClusterGraph :: UndirectedGraph g
                   => [VertexCluster]
                   -> g Int VertexCluster
createClusterGraph c =
    let numberedClusters = zip c (map Vertex [0..])
        addCluster (cluster, v) g = addVertex v cluster g
        graphWithoutEdges = foldr addCluster emptyGraph numberedClusters
        separatorSize ca cb = Set.size $ Set.intersection (fromVertexCluster ca) (fromVertexCluster cb)
        -- Both orientations of each pair are listed, so edge lookups do not
        -- depend on the order in which the endpoints are given.
        clusterPairs = [(cx,cy) | cx <- numberedClusters, cy <- numberedClusters, cx /= cy]
        addClusterEdge ((ca,va),(cb,vb)) g = addEdge (edge va vb) (separatorSize ca cb) g
    in
    foldr addClusterEdge graphWithoutEdges clusterPairs

{-

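Maximum-weight spanning tree using Prim's algorithm (edge weights are separator sizes; the exported function keeps the name minimumSpanningTree)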

-}

-- | Rose tree with values on edges: every child is paired with the value of
-- the edge linking it to its parent (@b@ is the edge type, @a@ the node type).
data Tree b a
    = Node a [(b, Tree b a)]
    deriving (Eq)

{-

Implementation of show for the tree

-}
-- | Convert a junction tree into a "Data.Tree" tree labelled with the shown
-- cluster of each node, so that it can be rendered with 'T.drawTree'.
standardHaskellTree :: (Show f, Show b) => Tree b (JTNodeValue f) -> T.Tree String
standardHaskellTree node@(Node _ subtrees) =
    T.Node (show (nodeCluster node)) [standardHaskellTree sub | (_, sub) <- subtrees]

-- | Convert a vertex-cluster tree into a "Data.Tree" tree labelled with the
-- shown cluster of each node (edges carry no information and are dropped).
standardVertexTree :: Tree () VertexCluster -> T.Tree String
standardVertexTree (Node cluster subtrees) =
    T.Node (show cluster) (map (standardVertexTree . snd) subtrees)

-- | Difference-list printer for a junction tree.
-- It appends the node's factor, then, for each child edge starting with the
-- last one, the separator followed by the text of that child's subtree.
showFactorsAndEdges :: (Show f, Show b) => Tree b (JTNodeValue f) -> (String -> String)
showFactorsAndEdges (Node value edges) =
    foldr (.) id (map appendEdge edges) . (++ show (nodeValueFactor value))
  where
    -- Append the separator, then everything below it.
    appendEdge (sep, sub) = showFactorsAndEdges sub . (++ show sep)

-- | Junction trees are shown as an ASCII drawing of their clusters followed
-- by the factors and separators.
instance (Show f ,Show b)=> Show (Tree b (JTNodeValue f)) where
    show t = concat
        [ "JUNCTION TREE\n"
        , T.drawTree (standardHaskellTree t)
        , "\n"
        , showFactorsAndEdges t ""
        , "\n------\n"
        ]

-- | Cluster-only trees are shown as an ASCII drawing of their clusters.
instance Show (Tree () VertexCluster) where
    show t = concat ["JUNCTION TREE\n", T.drawTree (standardVertexTree t), "\n"]

-- | Map a function over node values; edge values are kept as they are.
instance Functor.Functor (Tree b) where
    fmap f (Node value edges) =
        Node (f value) [ (label, fmap f sub) | (label, sub) <- edges ]

-- | Grow a spanning tree with Prim's algorithm.
-- The tree is encoded as a list of edges plus the list of vertices already in
-- it. At each step, the edge with the largest value that links a tree vertex
-- to a remaining vertex is added, so the result is a maximum-weight spanning
-- tree. Fails with an explicit error when no edge links the tree to the
-- remaining vertices (graph not connected).
expand :: UndirectedGraph g
       => g Int f
       -> [Edge] -- ^ List of edges
       -> [Vertex] -- ^ Vertices in Tree
       -> [Vertex] -- ^ Vertices to add
       -> [Edge] -- ^ Edges of the spanning tree, most recently added first
expand _ theEdges _ [] = theEdges
expand g theEdges inTree remaining =
    case [(vin, vout) | vin <- inTree, vout <- remaining, isLinkedWithAnEdge g vin vout] of
        [] -> error "expand: no edge links the tree to the remaining vertices (graph not connected?)"
        candidates ->
            let (treeVertex, outVertex) = maximumBy (compare `on` edgeCost) candidates
            in
            expand g (edge treeVertex outVertex : theEdges)
                     (outVertex : inTree)
                     (filter (/= outVertex) remaining)
  where
    edgeCost (va, vb) = fromJust $ edgeValue g (edge va vb)

-- | A tree made of a single node with no children.
leaf :: a -> Tree b a
leaf value = Node value []

-- | Pair an edge value with the subtree it leads to.
treeEdge :: c -> t -> (c, t)
treeEdge = (,)

-- | Create a tree from a root and a map from each vertex to its children.
-- A vertex missing from the map is a leaf. Every edge carries @()@.
-- The map is assumed to describe a tree; a cycle reachable from the root
-- would make the construction infinite.
createTreeFromMap :: Vertex -- ^ Root vertex
                  -> Map.Map Vertex [Vertex] -- ^ Tree branches
                  -> Tree () Vertex
createTreeFromMap root m = grow root
  where
    grow v = Node v [ ((), grow child) | child <- Map.findWithDefault [] v m ]

-- | Spanning tree of a weighted graph, built with Prim's algorithm ('expand').
-- Despite its name, it is a maximum-weight spanning tree: 'expand' always
-- picks the heaviest candidate edge, which is what a junction tree needs
-- (largest separators). The tree is rooted at 'someVertex' and its nodes
-- carry the vertex values. Fails on an empty graph.
--
-- NOTE(review): parent/child direction comes from 'edgeEndPoints' of the
-- edges created by 'expand'. This assumes it returns the endpoints in the
-- order they were given to 'edge'; confirm for undirected graphs.
minimumSpanningTree :: UndirectedGraph g
                    => g Int f
                    -> Tree () f
minimumSpanningTree g = Functor.fmap valueOf vertexTree
  where
    root = fromJust (someVertex g)
    others = filter (/= root) (allVertices g)
    treeEdges = expand g [] [root] others
    branches = Map.fromListWith (++) [ (a, [b]) | (a, b) <- map edgeEndPoints treeEdges ]
    vertexTree = createTreeFromMap root branches
    valueOf = fromJust . vertexValue g

{-

Junction tree algorithm

-}

-- | True when every variable of the factor (through its vertex) belongs to
-- the vertex cluster.
vertexClusterIsContainingFactor :: Factor f => VertexCluster -> f -> Bool
vertexClusterIsContainingFactor (VertexCluster members) f =
    all (`Set.member` members) (map variableVertex (factorVariables f))

-- | Check whether a discrete variable belongs to a cluster.
-- The cluster's set is queried directly, instead of being converted to a list
-- and back to a set.
clusterIsContainingVariable :: DV -> Cluster  -> Bool
clusterIsContainingVariable v (Cluster vars) = Set.member v vars

-- | Separator between two neighbouring clusters. The constructor records which
-- messages have crossed it: none yet, the upward (collect) message, or both
-- the upward and the downward (distribute) messages.
data Separator f
    = NoMessage !Cluster
    | Collect !Cluster !f
    | Distribute !Cluster !f !f -- Upward and downward message
    deriving (Eq)

instance Show f => Show (Separator f) where
    show sep = case sep of
        NoMessage c -> "NoMessage: " ++ show c
        Collect c up -> concat ["Collect: ", show c, "\n", "\n <----- \n", show up, "\n"]
        Distribute c up down ->
            concat ["Distribute: ", show c, "\n <----- \n", show up, "\n", " -----> \n", show down, "\n"]

-- | Evidence attached to a node. It is a factor of the same type as the node
-- potential; a scalar factor 1.0 is used when there is no evidence.
type Evidence f = f

-- | Value of a junction-tree node: its cluster, its evidence, and its potential
-- (the product of the factors assigned to the cluster).
data JTNodeValue f = JTNodeValue !Cluster !(Evidence f) !f
    deriving (Eq, Show)

-- | Cluster of discrete variables.
-- Discrete variables are used instead of vertices because factors are
-- expressed in terms of 'DV', and we need to find which factors must be
-- contained in a given cluster.
newtype Cluster = Cluster (Set.Set DV)
    deriving (Eq, Show)

-- | Variables of a cluster, in ascending order.
fromCluster :: Cluster -> [DV]
fromCluster (Cluster vars) = Set.toList vars

-- | Convert a vertex cluster to a 'DV' cluster, replacing each vertex by the
-- main variable of the factor stored on it. Vertices without a value in the
-- graph are ignored.
vertexClusterToCluster :: (Factor f , Graph g)
                       => g e f
                       -> VertexCluster
                       -> Cluster
vertexClusterToCluster g (VertexCluster vertices) =
    Cluster (Set.fromList [ factorMainVariable f | Just f <- map (vertexValue g) (Set.toList vertices) ])

-- | Vertices contained in a cluster, in ascending order.
clusterVertices :: VertexCluster -> [Vertex]
clusterVertices (VertexCluster vs) = Set.toList vs

-- | Factors stored on the cluster's vertices whose variables all belong to
-- the cluster, in ascending vertex order.
findFactorsForCluster :: (Factor f , Graph g)
                      => BayesianNetwork g f
                      -> VertexCluster
                      -> [f]
findFactorsForCluster g c =
    [ f | Just f <- map (vertexValue g) (clusterVertices c), vertexClusterIsContainingFactor c f ]

-- | The junction tree: nodes hold a cluster with its evidence and potential
-- ('JTNodeValue'); edges hold a 'Separator' with the messages sent through it.
type JunctionTree f = Tree (Separator f) (JTNodeValue f)

-- | Build the node value of a cluster.
-- Candidate factors are those contained in the cluster ('findFactorsForCluster').
-- Only candidates whose vertex is still in the unused set are multiplied into
-- the potential, and those vertices are removed from the set. This places each
-- factor in at most one cluster. The evidence starts as the scalar 1.0.
-- Returns the node value and the updated set of unused vertices.
mkNodePotential :: (Graph g, Factor f, Show f)
                => BayesianNetwork g f
                -> VertexCluster
                -> Set.Set Vertex
                -> (JTNodeValue f, Set.Set Vertex)
mkNodePotential g c unused =
    let candidates = [ (variableVertex (factorMainVariable f), f) | f <- findFactorsForCluster g c ]
        -- A vertex identifies the factor it carries.
        fresh = [ pair | pair@(v, _) <- candidates, v `Set.member` unused ]
        unused' = unused `Set.difference` Set.fromList (map fst fresh)
        potential = factorProduct (map snd fresh)
        value = JTNodeValue (vertexClusterToCluster g c) (factorFromScalar 1.0) potential
    in
    (value, unused')

-- | Evidence potential for a cluster, built with 'evidenceFrom' from the
-- assignments whose variable belongs to the cluster.
evidenceForCluster :: Factor f => DVISet Int -> Cluster -> Maybe (Evidence f)
evidenceForCluster assignments (Cluster clusterVars) =
    evidenceFrom [ a | a <- assignments, instantiationVariable a `Set.member` clusterVars ]

-- | Get the cluster for a node
nodeCluster :: Tree a (JTNodeValue f) -> Cluster
nodeCluster (Node value _) =
    case value of
        JTNodeValue c _ _ -> c

-- | Cluster containing no variable.
emptyCluster :: Cluster
emptyCluster = Cluster Set.empty

-- | Potential (product of the assigned factors) of a node value.
nodeValueFactor :: JTNodeValue f -> f
nodeValueFactor (JTNodeValue _ _ factor) = factor

-- | Evidence of a node value.
nodeValueEvidence :: JTNodeValue f -> Evidence f
nodeValueEvidence (JTNodeValue _ evidence _) = evidence

-- | Replace the evidence of a node value, keeping cluster and potential.
nodeValueWithNewEvidence :: JTNodeValue f -> Evidence f -> JTNodeValue f
nodeValueWithNewEvidence (JTNodeValue c _ factor) newEvidence = JTNodeValue c newEvidence factor

-- | Reset the evidence of a node value to the scalar factor 1.0.
clearNodeValueEvidence value = nodeValueWithNewEvidence value (factorFromScalar 1.0)

-- | Get the cluster for a separator
separatorCluster :: Separator f -> Cluster
separatorCluster sep = case sep of
    NoMessage c -> c
    Collect c _ -> c
    Distribute c _ _ -> c

-- | Upward message of a separator, once the collect phase has reached it.
upMessage :: Separator f -> Maybe f
upMessage sep = case sep of
    Collect _ up -> Just up
    Distribute _ up _ -> Just up
    NoMessage _ -> Nothing

-- | Downward message of a separator, once the distribute phase has reached it.
downMessage :: Separator f -> Maybe f
downMessage sep = case sep of
    Distribute _ _ down -> Just down
    _ -> Nothing

-- | Separator cluster between a parent and a child vertex cluster: the
-- intersection of their vertices, converted to a 'DV' cluster.
-- (The unused local helper of the previous version is removed.)
computeSeparatorCluster :: (Factor f, Graph g)
                        => BayesianNetwork g f
                        -> VertexCluster
                        -> VertexCluster
                        -> Cluster
computeSeparatorCluster g (VertexCluster parentVertices) (VertexCluster childVertices) =
    vertexClusterToCluster g (VertexCluster (Set.intersection childVertices parentVertices))

-- | Pre-order traversal that rebuilds a tree while threading an accumulator.
-- The node function is applied to a node before its children, and children
-- are visited left to right. Each edge is recomputed from the original parent
-- value, the original child value and the old edge value.
dfs :: (n -> n -> e -> e') -- Parent node, child node and their edge -> new edge
    -> (n -> a -> (n', a)) -- Node and current value -> new node and new value
    -> Tree e n  -- Tree to traverse
    -> a -- Start value
    -> (Tree e' n', a) -- New tree and new value
dfs edgef nodef (Node nodevalue kids) current =
    let (newnodevalue, afterNode) = nodef nodevalue current
        (newKids, final) = visit afterNode kids
    in
    (Node newnodevalue newKids, final)
  where
    -- Visit the children in order, passing the accumulator along.
    visit acc [] = ([], acc)
    visit acc ((e, child) : rest) =
        let Node childvalue _ = child
            (child', acc') = dfs edgef nodef child acc
            (rest', accEnd) = visit acc' rest
        in
        ((edgef nodevalue childvalue e, child') : rest', accEnd)

-- | Edge update used by 'setFactors'. The separator of a tree edge is the
-- intersection of the parent and child clusters; no message has been sent yet.
setFactorEdgeUpdate :: (Graph g, Factor f)
                    => BayesianNetwork g f
                    -> VertexCluster
                    -> VertexCluster
                    -> ()
                    -> Separator f
setFactorEdgeUpdate g parentCluster childCluster _ =
    NoMessage (computeSeparatorCluster g parentCluster childCluster)

-- | Node update used by 'setFactors' (see 'mkNodePotential').
setFactorNodeUpdate :: (Graph g, Factor f, Show f)
                    => BayesianNetwork g f
                    -> VertexCluster
                    -> Set.Set Vertex
                    -> (JTNodeValue f, Set.Set Vertex)
setFactorNodeUpdate = mkNodePotential

-- | Turn a cluster tree into a junction tree. The tree is traversed in
-- pre-order with 'dfs', threading the set of vertices whose factor has not
-- been placed yet. Returns the junction tree and the vertices still unplaced.
setFactors :: (Graph g, Factor f, Show f)
           => BayesianNetwork g f -- ^ Bayesian graph
           -> Tree () VertexCluster  -- ^ Cluster tree with no factors
           -> Set.Set Vertex
           -> (JunctionTree f,Set.Set Vertex) -- ^ Initialized junction tree
setFactors g tree unused = dfs (setFactorEdgeUpdate g) (setFactorNodeUpdate g) tree unused

-- | Create a junction tree with only the clusters and no factors.
-- Steps: moralize the graph, triangulate it with the criterion (applied to the
-- moral graph), connect the resulting clusters into a cluster graph weighted by
-- separator size, and extract its spanning tree.
createVerticesJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g)
                           => (UndirectedSG () b -> Vertex -> Vertex -> Ordering) -- ^ Weight function on the moral graph
                           -> g () b -- ^ Input directed graph
                           -> Tree () VertexCluster -- ^ Junction tree
createVerticesJunctionTree cmp g = minimumSpanningTree clusterGraph
  where
    moral = moralGraph g
    clusters = fst (triangulate (cmp moral) moral)
    clusterGraph = createClusterGraph clusters :: UndirectedSG Int VertexCluster

-- | Create a junction tree ready for inference.
-- The cluster tree is built first. Each factor is then placed in the first
-- cluster, in pre-order, that contains all its variables. Finally the collect
-- and distribute phases are run.
createJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g, Factor f, Show f)
                   => (UndirectedSG () f -> Vertex -> Vertex -> Ordering) -- ^ Weight function on the moral graph
                   -> BayesianNetwork g f -- ^ Input directed graph
                   -> JunctionTree f -- ^ Junction tree
createJunctionTree cmp g =
    let clusterTree = createVerticesJunctionTree cmp g
        -- Each vertex carries one factor, so the vertex set tracks the
        -- factors not yet placed in the junction tree.
        unplaced = Set.fromList (allVertices g)
        initialTree = fst (setFactors g clusterTree unplaced)
    in
    distribute Nothing (collect initialTree)

-- | Compute, bottom-up, the upward message sent through a separator.
-- The message is the product of the node potential, its evidence and the
-- upward messages of its children, projected onto the separator cluster.
-- Returns the separator in the 'Collect' state and the updated subtree.
collectMessages :: Factor f => (Separator f , JunctionTree f) -> (Separator f , JunctionTree f)
collectMessages (separator, Node nc kids) =
    let sc = separatorCluster separator
        updatedKids = map collectMessages kids
        toMultiply
            | null kids = [nodeValueFactor nc, nodeValueEvidence nc]
            | otherwise = nodeValueEvidence nc : nodeValueFactor nc : mapMaybe (upMessage . fst) updatedKids
        message = factorProjectTo (fromCluster sc) (factorProduct toMultiply)
    in
    (Collect sc message, Node nc updatedKids)

-- | Collect phase of the junction tree: every separator receives its upward
-- message. The root is treated as hanging from an empty separator.
collect :: Factor f => JunctionTree f -> JunctionTree f
collect t = snd (collectMessages (NoMessage emptyCluster, t))

-- | True when two junction-tree nodes hold different clusters.
notSameCluster :: Tree a (JTNodeValue f) -> Tree b (JTNodeValue g) -> Bool
notSameCluster a b = not (nodeCluster a == nodeCluster b)

-- | Distribute phase of the junction tree.
-- The down message sent to a child is the product of the node potential, its
-- evidence, the message received from the parent (scalar 1.0 at the root) and
-- the upward messages of all the other children, projected onto the separator
-- cluster. All separators must be in the 'Collect' state, so 'collect' has to
-- run first.
distribute :: Factor f => Maybe (Separator f) -> JunctionTree f -> JunctionTree f
distribute _ n@(Node _ []) = n
distribute down (Node nc kids) = Node nc (map sendDown kids)
  where
    fromParent = case down of
        Nothing -> factorFromScalar 1.0
        Just sep -> fromJust (downMessage sep)
    -- Upward messages of every child except the one being sent to.
    fromSiblingsOf target =
        fromJust (mapM (upMessage . fst) (filter ((target `notSameCluster`) . snd) kids))
    sendDown (Collect sc up, child) =
        let potential = factorProduct (nodeValueFactor nc : nodeValueEvidence nc : fromParent : fromSiblingsOf child)
            sep = Distribute sc up (factorProjectTo (fromCluster sc) potential)
        in
        (sep, distribute (Just sep) child)
    sendDown _ = error "Distribute message can only update a collect phase message"

-- | Depth-first (pre-order) search for the first subtree satisfying the predicate.
-- It is returned with the edge leading to it. For the root, that edge is the
-- initial edge argument; for other nodes it is @Just@ the edge from the parent.
findInTree :: (Tree edge a -> Bool) -> Maybe edge -> Tree edge a -> Maybe (Maybe edge,Tree edge a)
findInTree cmp e n@(Node _ subtrees)
    | cmp n = Just (e, n)
    | otherwise = firstFound [ findInTree cmp (Just e') sub | (e', sub) <- subtrees ]
  where
    -- First successful search, without evaluating the later ones.
    firstFound = foldr (\result rest -> maybe rest Just result) Nothing

-- | Compute the marginal posterior (if some evidence is set on the junction tree)
-- otherwise compute just the marginal prior.
-- The computation uses the first cluster (in pre-order) containing the
-- variable. It multiplies the parent's down message (scalar 1.0 at the root or
-- if none was sent), the evidence, the potential and the children's up messages,
-- then projects onto the variable and normalises.
-- Returns 'Nothing' when no cluster contains the variable. The tree is
-- expected to have gone through 'collect' and 'distribute'.
posterior :: Factor f => JunctionTree f -> DV -> Maybe f
posterior t v = do
    (maybeEdge, Node n l) <- findInTree (clusterIsContainingVariable v . nodeCluster) Nothing t
    let fromParent = maybe (factorFromScalar 1.0) id (maybeEdge >>= downMessage)
        fromChildren = fromJust (mapM (upMessage . fst) l)
        joint = factorProduct (fromParent : nodeValueEvidence n : nodeValueFactor n : fromChildren)
    return (normedFactor (factorProjectTo [v] joint))

-- | Apply a node modification to every node of the tree.
-- The function should only change the node value, not the children. On a
-- leaf, its result is returned as is. On an inner node, the children of its
-- result are processed recursively.
applyEvidenceWith :: (JunctionTree f -> JunctionTree f) -- ^ Node modification function. Only change node value. Not the children
                  -> JunctionTree f -- ^ Input tree
                  -> JunctionTree f
applyEvidenceWith nodeChange n@(Node _ []) = nodeChange n
applyEvidenceWith nodeChange n =
    let Node value kids = nodeChange n
    in
    Node value [ (sep, applyEvidenceWith nodeChange child) | (sep, child) <- kids ]

-- | Change the evidence of the top node of a tree (children untouched).
-- The new evidence comes from the assignments on the node's cluster. The
-- previous evidence is kept when 'evidenceForCluster' returns 'Nothing'.
evidenceWith :: Factor f => DVISet Int -> JunctionTree f -> JunctionTree f
evidenceWith assignments t@(Node n l) =
    Node (maybe n (nodeValueWithNewEvidence n) (evidenceForCluster assignments (nodeCluster t))) l

-- | Reset the evidence of the top node of a tree to the scalar factor 1.0;
-- children are untouched.
clearNodeEvidence t = case t of
    Node value kids -> Node (clearNodeValueEvidence value) kids

-- | Remove the evidence from every node, then rerun collect and distribute.
clearEvidence :: Factor f => JunctionTree f -> JunctionTree f
clearEvidence t = distribute Nothing (collect (applyEvidenceWith clearNodeEvidence t))

-- | Set the evidence of every node from the assignments, then rerun collect
-- and distribute.
updateEvidence :: Factor f => DVISet Int -> JunctionTree f -> JunctionTree f
updateEvidence assignments t =
    distribute Nothing (collect (applyEvidenceWith (evidenceWith assignments) t))

-- | Used to implement quickcheck.
-- The junction tree property: for clusters CA and CB, the intersection of CA
-- and CB is included in every cluster on the path from CA to CB.
--
-- Every node, leaves included, is checked against all of its ancestors with
-- 'checkPath'. @path@ holds the ancestors of the current node with the nearest
-- (parent) first. 'checkPath' needs that order so that, for each ancestor, it
-- inspects exactly the clusters between the node and that ancestor. (The
-- previous version skipped leaves and passed the ancestors root-first.)
--
-- NOTE(review): only pairs in an ancestor/descendant relation are checked.
-- Pairs in different branches, whose path goes through a common ancestor, are not.
junctionTreeProperty :: [VertexCluster] -> Tree () VertexCluster -> Bool
junctionTreeProperty path (Node c l) =
    checkPath c path && all (junctionTreeProperty (c : path) . snd) l

-- | QuickCheck property: for non-empty, connected graphs with at least one
-- edge, the cluster tree built with the fill-in criterion satisfies
-- 'junctionTreeProperty'.
junctionTreeProperty_prop :: DirectedSG () String -> Property
junctionTreeProperty_prop g =
    precondition ==> junctionTreeProperty [] (createVerticesJunctionTree byAddedEdges g)
  where
    precondition = not (isEmpty g) && not (hasNoEdges g) && connectedGraph g
    byAddedEdges ug = compare `on` numberOfAddedEdges ug

-- | Check that, for each ancestor P in the list, the intersection of C and P
-- is included in every cluster of the list up to and including P.
-- With the list ordered from the nearest ancestor (parent) to the farthest,
-- these are exactly the clusters between C and P.
checkPath :: VertexCluster -> [VertexCluster] -> Bool
checkPath (VertexCluster c) ancestors =
    and [ all (Set.isSubsetOf (Set.intersection c p)) between
        | (p, between) <- zip ancestorSets (tail (inits ancestorSets)) ]
  where
    ancestorSets = map fromVertexCluster ancestors
{-

Moral graph

-}
-- | Get the parents of a vertex: start vertices of its ingoing edges.
-- Fails (via 'fromJust') when 'ingoing' or 'startVertex' returns 'Nothing'.
parents :: DirectedGraph g => g a b -> Vertex -> [Vertex]
parents g v = fromJust (mapM (startVertex g) =<< ingoing g v)

-- | Get the children of a vertex: end vertices of its outgoing edges.
-- Fails (via 'fromJust') when 'outgoing' or 'endVertex' returns 'Nothing'.
children :: DirectedGraph g => g a b -> Vertex -> [Vertex]
children g v = fromJust (mapM (endVertex g) =<< outgoing g v)

-- | Connect all the nodes which are not connected, applying one function to
-- the origin graph and one to the destination graph for each new connection.
-- The missing edges are computed once, from the origin graph. Both orderings
-- of every unlinked pair are included.
-- The origin and dest graph must share the same vertices.
connectAllNodesWith :: (Graph g, Graph g')
                    => g a b -- ^ Graph containing the nodes
                    -> g' a b -- ^ Graph to be modified
                    -> (Edge -> g a b -> g a b) -- ^ Function used to modify the source graph
                    -> (Edge -> g' a b -> g' a b) -- ^ Function used to modify a new graph
                    -> [Vertex]  -- ^ List of nodes to connect
                    -> (g a b,g' a b) -- ^ Result graph
connectAllNodesWith originGraph dstGraph onOrigin onDst nodes = (origin', dst')
  where
    missingEdges = [edge x y | x <- nodes, y <- nodes, x /= y, not (isLinkedWithAnEdge originGraph x y)]
    step e (o, d) = (onOrigin e o, onDst e d)
    (origin', dst') = foldr step (originGraph, dstGraph) missingEdges

-- | Connect every pair of parents of a vertex that is not already linked.
-- The graph is used both to find the missing links and as the graph receiving
-- the new edges. The vertex value argument is ignored.
addMissingLinks :: DirectedGraph g => Vertex -> b -> g () b -> g () b
addMissingLinks v _ g =
    snd (connectAllNodesWith g g (\_ m -> m) (\e m -> addEdge e () m) (parents g v))

-- | Convert the graph to an undirected form
convertToUndirected :: (FoldableWithVertex g, Graph g, NamedGraph g, NamedGraph g',UndirectedGraph g')
=> g  () b
-> g' () b
convertToUndirected m =
let addVertexWithLabel v dat g =
let theName = fromJust \$ vertexLabel m v
in