{- | Algorithms for factor elimination

-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
module Bayes.FactorElimination(
    -- * Moral graph
      moralGraph
    -- * Triangulation
    , nodeComparisonForTriangulation
    , numberOfAddedEdges
    , weight
    , weightedEdges
    , triangulate
    -- * Junction tree
    , createClusterGraph
    , Cluster
    , createJunctionTree
    , createUninitializedJunctionTree
    , JunctionTree
    , displayTreeValues
    -- * Shenoy-Shafer message passing
    , collect 
    , distribute
    , posterior 
    -- * Evidence
    , changeEvidence
    -- * Test 
    , junctionTreeProperty_prop
    , junctionTreeAllClusters_prop
    , VertexCluster
    -- * For debug 
    , junctionTreeProperty
    , maximumSpanningTree
    , fromVertexCluster
    , triangulatedebug
    ) where

import Bayes
import qualified Data.Foldable as F
import Data.Maybe(fromJust,mapMaybe,isJust)
import Control.Monad(mapM,guard)
import Bayes.Factor hiding (isEmpty)
import Data.Function(on)
import Data.List(minimumBy,maximumBy,inits,foldl',nub,(\\))
import qualified Data.Set as Set
import qualified Data.Map as Map
import qualified Data.Functor as Functor
import qualified Data.Tree as T 
import Bayes.FactorElimination.JTree
import Control.Applicative((<$>))
import Bayes.VariableElimination(marginal)

import Test.QuickCheck hiding ((.||.), collect)
import Test.QuickCheck.Arbitrary

import Bayes.Factor.CPT -- This import is only used for quickcheck tests

--import Debug.Trace
--debug s a = trace (s ++ "\n" ++ show a ++ "\n") a

{-
 
Comparison functions for graph triangulation

-}

-- | Number of edges added when connecting all neighbors
numberOfAddedEdges :: UndirectedGraph g 
                   => g a b 
                   -> Vertex 
                   -> Integer 
numberOfAddedEdges g v = 
    let nodes = fromJust $ neighbors g v
    in 
    fromIntegral $ length [edge x y | x <- nodes, y <- nodes, x /= y, not (isLinkedWithAnEdge g x y)]

weightedEdges :: (UndirectedGraph g, Factor f) 
              => g a f 
              -> Vertex 
              -> Integer 
weightedEdges g v = 
    let nodes = fromJust $ neighbors g v
    in 
    sum [weight g x * weight g y | x <- nodes, y <- nodes, x /= y, not (isLinkedWithAnEdge g x y)]

-- | Weight of a node
weight :: (UndirectedGraph g, Factor f)
       => g a f 
       -> Vertex 
       -> Integer 
weight g v = 
    fromIntegral $ factorDimension . fromJust . vertexValue g $ v

(.||.) :: (a -> a -> Ordering)
       -> (a -> a -> Ordering) 
       -> (a -> a -> Ordering)
f .||. g = 
    \a b -> case f a b of
              EQ -> g a b 
              r -> r

-- | Node selection comparison function used for triangulating the graph
nodeComparisonForTriangulation :: (UndirectedGraph g, Factor f)
                               => g a f
                               -> Vertex 
                               -> Vertex 
                               -> Ordering 
nodeComparisonForTriangulation g = (compare `on` (weightedEdges g))

{-

Graph triangulation

-}

-- | A cluster containing only the vertices and not yet the factors
newtype VertexCluster = VertexCluster (Set.Set Vertex) deriving(Eq,Ord)

fromVertexCluster (VertexCluster s) = s

instance Show VertexCluster where 
    show (VertexCluster s) = show . Set.toList $ s

-- | Triangulate a graph using a cost function
-- The result is the triangulated graph and the list of clusters
-- which may not be maximal.
triangulate :: Graph g
            => (Vertex -> Vertex -> Ordering) -- ^ Criterion function for triangulation
            -> g () b
            -> [VertexCluster] -- ^ Returns the clusters and the triangulated graph
triangulate cmp gr = removeNodes cmp gr []
 where 
  removeNodes cmp g l | hasNoVertices g = keepMaximalClusters (reverse l)
                      | otherwise = 
                          let selectedNode = minimumBy cmp (allVertices g)
                              theNeighbors = fromJust $ neighbors g selectedNode
                              g' = removeVertex selectedNode . connectAllNonAdjacentNodes theNeighbors $ g 
                              newCluster = VertexCluster . Set.fromList $ (selectedNode:theNeighbors)
                          in 
                          removeNodes cmp g' (newCluster:l)

triangulatedebug :: Graph g
            => (Vertex -> Vertex -> Ordering) -- ^ Criterion function for triangulation
            -> g () b
            -> ([VertexCluster],[g () b]) -- ^ Returns the clusters and the triangulated graph
triangulatedebug cmp gr = removeNodes cmp gr [] []
 where 
  removeNodes cmp g l gl | hasNoVertices g = (reverse l,reverse gl)
                         | otherwise = 
                             let selectedNode = minimumBy cmp (allVertices g)
                                 theNeighbors = fromJust $ neighbors g selectedNode
                                 g' = removeVertex selectedNode . connectAllNonAdjacentNodes theNeighbors $ g 
                                 newCluster = VertexCluster . Set.fromList $ (selectedNode:theNeighbors)
                             in 
                             removeNodes cmp g' (newCluster:l) (g:gl)


-- | Find for a containing cluster. 
findContainingCluster :: VertexCluster -- ^ Cluster processed
                      -> [VertexCluster] -- ^ Cluster list where to look for a containing cluster
                      -> (Maybe VertexCluster,[VertexCluster]) -- ^ Return the containing cluster and a new list without the containing cluster
findContainingCluster cluster l = 
  let  clusterIsNotASubsetOf s = (Set.isSubsetOf (fromVertexCluster cluster) (fromVertexCluster s))
       (prefix,suffix) = break clusterIsNotASubsetOf l
  in 
  case suffix of 
    [] -> (Nothing,l)
    _ -> (Just (head suffix),prefix ++ tail suffix)


-- | Remove clusters already contained in a previous clusters
keepMaximalClusters :: [VertexCluster] -> [VertexCluster]
keepMaximalClusters [] = []
keepMaximalClusters l = checkIfMaximal [] (head l) (tail l)
 where 
  checkIfMaximal reversedPrefix current [] = 
    case findContainingCluster current (reverse reversedPrefix) of 
      (Nothing,_) -> reverse (current:reversedPrefix) 
      (Just r,l) -> reverse (r:reverse l)
  checkIfMaximal reversedPrefix current suffix = 
    case findContainingCluster current (reverse reversedPrefix) of 
      (Nothing,_) -> checkIfMaximal (current:reversedPrefix) (head suffix) (tail suffix)
      (Just r,l) -> checkIfMaximal (r:reverse l) (head suffix) (tail suffix)

-- | Convert the clusters from vertex to 'DV' clusters
vertexClusterToCluster :: (Factor f , Graph g)
                       => g e f 
                       -> VertexCluster 
                       -> Cluster 
vertexClusterToCluster g c = 
  let vertices = Set.toList . fromVertexCluster $ c
      variables = map factorMainVariable . mapMaybe (vertexValue g) $ vertices
  in 
  Cluster . Set.fromList $ variables


-- | Create the cluster graph
createClusterGraph :: (UndirectedGraph g, Factor f, Graph g')
                   => g' e f
                   -> [VertexCluster] 
                   -> g Int Cluster
createClusterGraph bn c =
  let numberedClusters = zip c (map Vertex [0..])
      addCluster g (c,v)  = addVertex v (vertexClusterToCluster bn c) g
      graphWithoutEdges = foldl' addCluster emptyGraph numberedClusters
      separatorSize ca cb = Set.size $ Set.intersection (fromVertexCluster ca) (fromVertexCluster cb)
      allEdges = [(cx,cy) | cx <- numberedClusters, cy <- numberedClusters, cx /= cy]
      addClusterEdge g ((ca,va),(cb,vb)) = addEdge (edge va vb) (separatorSize ca cb) g
  in 
  foldl' addClusterEdge graphWithoutEdges allEdges


{-

Maximum spanning tree using Prim's algorithm
  
-}

-- | Get all possible edges between the leaves and the remaining nodes
possibilities :: (Ord c , UndirectedGraph g) 
              => g Int c -- ^ Original graph to get the edge value 
              -> JTree c f -- ^ Tree to get the vertex for a leaf
              -> [Vertex] -- ^ Vertices to add to the tree
              -> [c] -- ^ List of leaves
              -> [(Vertex,c,Int)] -- ^ Found edge to add
possibilities g currentT remaining leavesClusters = do 
  rv <- remaining
  lv <- leavesClusters
  let NodeValue lvVertex _ _ = nodeValue currentT lv
  guard (isLinkedWithAnEdge g rv lvVertex)
  let ev = fromJust $ edgeValue g (edge rv lvVertex)
  return $ (rv,lv,ev)

-- | Find the max edge to add to the tree
findMax :: (UndirectedGraph g, Ord c, Factor f,Show c)
        => g Int c -- ^ Graph
        -> [Vertex] -- ^ Nodes to add 
        -> JTree c f
        -> ([Vertex],(Vertex,c),c) 
findMax g remaining currentT = 
  let leavesClusters = treeNodes currentT
      edgeValue (_,_,e) = e
      (rf,lf,ef) = maximumBy (compare `on` edgeValue) (possibilities g currentT remaining leavesClusters)
      remaining' = filter (/= rf) remaining 
      foundCluster = fromJust $ vertexValue g rf
  in
  (remaining', (rf, foundCluster), lf)

-- | Implementing the Prim's algorithm for minimum spanning tree
maximumSpanningTree :: (UndirectedGraph g, IsCluster c, Factor f, Ord c, Show c, Show f) 
                    => g Int c 
                    -> JTree c f
maximumSpanningTree g = 
    let rootNodeVertex = fromJust $ someVertex g 
        rootNodeValue = fromJust $ vertexValue g rootNodeVertex
        startTree = singletonTree rootNodeValue rootNodeVertex [] [] 
        remainingVertices = filter (/= rootNodeVertex) (allVertices g) 
    in 
    buildTree g remainingVertices startTree 

buildTree :: (UndirectedGraph g , IsCluster c, Factor f, Ord c, Show c, Show f)
          => g Int c 
          -> [Vertex]
          -> JTree c f 
          -> JTree c f
buildTree g [] currentT = currentT 
buildTree g l currentT = 
    let (l',(foundElemVertex,foundElemValue),leaf) = findMax g l currentT
        sep = mkSeparator foundElemValue leaf
        newTree = addSeparator leaf sep foundElemValue . 
                  addNode foundElemValue foundElemVertex [] [] $ currentT
    in 
    buildTree g l' newTree
   
{-

Junction tree algorithm

-}


-- | Create a junction tree with only the clusters and no factors
createUninitializedJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g, Factor f, Show f)
                                => (UndirectedSG () f -> Vertex -> Vertex -> Ordering) -- ^ Weight function on the moral graph
                                -> g () f -- ^ Input directed graph
                                -> JunctionTree f -- ^ Junction tree
createUninitializedJunctionTree cmp g =  
  let theMoralGraph = moralGraph g
      clusters = triangulate (cmp theMoralGraph) theMoralGraph
      g'' = createClusterGraph g clusters :: UndirectedSG Int Cluster
  in 
  maximumSpanningTree g''

-- | Create a function tree
createJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g, Factor f, Show f)
                  => (UndirectedSG () f -> Vertex -> Vertex -> Ordering) -- ^ Weight function on the moral graph
                  -> BayesianNetwork g f -- ^ Input directed graph
                  -> JunctionTree f -- ^ Junction tree
createJunctionTree cmp g = 
  let cTree = createUninitializedJunctionTree cmp g 
      -- A vertex is linked with a factor so vertex is used as the identifier
      newTree = setFactors g cTree
  in 
  distribute . collect $ newTree


-- | Compute the marginal posterior (if some evidence is set on the junction tree)
  -- otherwise compute just the marginal prior.
posterior :: (BayesianDiscreteVariable dv, Factor f) => JunctionTree f -> dv -> Maybe f
posterior t someDv = 
  let v = dv someDv
  in
  case snd $ traverseTree (findClusterFor v) Nothing t of 
    Nothing -> Nothing
    Just c -> let NodeValue ver f e = nodeValue t c 
                  d = maybe (factorFromScalar 1.0) id $ downMessage t =<< (nodeParent t c)
                  u = map (upMessage t) (nodeChildren t c)
                  allFactors = d:u ++ f ++ e
                  variablesToRemove = (nub (concatMap factorVariables allFactors)) \\ [v]
                  unNormalized = marginal allFactors variablesToRemove [v] []
              in 
              Just $ factorDivide unNormalized (factorNorm unNormalized)

-- | Find a cluster containing the variable
findClusterFor :: DV 
               -> Maybe Cluster
               -> Cluster -- ^ Current cluster
               -> NodeValue f -- ^ Current value
               -> Action (Maybe Cluster) (NodeValue f)
findClusterFor dv s c@(Cluster sc) v = 
  case Set.member dv sc of 
    False -> Skip s 
    True -> Stop (Just c)


junctionTreeProperty_prop :: DirectedSG () CPT -> Property 
junctionTreeProperty_prop g = (not . isEmpty) g && (not . hasNoEdges) g && connectedGraph g ==> 
  let cmp ug = (compare `on` (numberOfAddedEdges ug))
      t = createUninitializedJunctionTree cmp g
  in
  junctionTreeProperty t [] (root t)

junctionTreeAllClusters_prop :: DirectedSG () CPT -> Property 
junctionTreeAllClusters_prop g = (not . isEmpty) g && (not . hasNoEdges) g && connectedGraph g ==> 
      let theMoralGraph = moralGraph g
          cmp ug = (compare `on` (numberOfAddedEdges ug))
          clusters = triangulate (cmp theMoralGraph) theMoralGraph
          g'' = createClusterGraph g clusters :: UndirectedSG Int Cluster
          jt = maximumSpanningTree g'' :: JunctionTree CPT
          treeClusters = treeNodes jt 
          sa = Set.fromList (map (vertexClusterToCluster g) clusters) 
          sb = Set.fromList treeClusters 
      in 
      Set.isSubsetOf sa sb && Set.isSubsetOf sb sa

junctionTreeProperty :: JTree Cluster CPT -> [Cluster] -> Cluster -> Bool
junctionTreeProperty t path c = 
  let cl = map (separatorChild t) . nodeChildren t $ c
  in
  checkPath c path && all (junctionTreeProperty t (c:path)) cl 


-- | Check that the intersection of C with any parent in included in all cluster between the parent and C.
checkPath :: Cluster -> [Cluster] -> Bool 
checkPath _ [] = True
checkPath (Cluster c) l = 
  let clusterSet (Cluster s) = s -- x
      parentSets = map clusterSet l -- Example a b c d where a is the root
      allIntersectionsWithParents = map (Set.intersection c) parentSets -- a ^ x, b ^ x , c ^ x , d ^ x
      pathsToEachParent = tail . inits $ parentSets -- a, ab, abc, abcd
      isSubsetOfAllParents i path = all (Set.isSubsetOf i) path
  in    
  and $ zipWith isSubsetOfAllParents allIntersectionsWithParents pathsToEachParent
{-

Moral graph

-}
-- | Get the parents of a vertex
parents :: DirectedGraph g => g a b -> Vertex -> [Vertex]
parents g v = fromJust $ ingoing g v >>= mapM (startVertex g) 

-- | Get the children of a vertex
children :: DirectedGraph g => g a b -> Vertex -> [Vertex]
children g v = fromJust $ outgoing g v >>= mapM (endVertex g) 

-- | Connect all the nodes which are not connected and apply the function f for each new connection
-- The origin and dest graph must share the same vertex.
connectAllNonAdjacentNodes :: (Graph g) 
                           => [Vertex]  -- ^ List of nodes to connect
                           -> g () b -- ^ Graph containing the nodes
                           -> g () b
connectAllNonAdjacentNodes nodes originGraph   =  
    let addEmptyEdge g e = addEdge e () g
    in 
    foldl' addEmptyEdge originGraph [edge x y | x <- nodes, y <- nodes, x /= y, not (isLinkedWithAnEdge originGraph x y)]
   
-- | Add the missing parent links
addMissingLinks :: DirectedGraph g => g () b -> Vertex -> b -> g () b
addMissingLinks g v _ = connectAllNonAdjacentNodes (parents g v) g 


-- | Convert the graph to an undirected form
convertToUndirected :: (FoldableWithVertex g, Graph g, NamedGraph g, NamedGraph g',UndirectedGraph g')
                    => g  () b 
                    -> g' () b 
convertToUndirected m = 
    let addVertexWithLabel g v dat  = 
           let theName = fromJust $ vertexLabel m v
           in 
           addLabeledVertex theName v dat g
        newDiscreteGraph = foldlWithVertex' addVertexWithLabel emptyGraph m
        addEmptyEdge edge g = addEdge edge () g
    in 
    foldr addEmptyEdge newDiscreteGraph . allEdges $ m

-- | For the junction tree construction, only the vertices are needed during the intermediate steps.
-- So, the moral graph is returned without any vertex data.
moralGraph :: (NamedGraph g, FoldableWithVertex g, DirectedGraph g) 
           => g () b -> UndirectedSG () b 
moralGraph g = 
    convertToUndirected  . foldlWithVertex' addMissingLinks g $ g