{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
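{- | Algorithms for factor elimination on Bayesian networks: moral graph
construction, graph triangulation, junction tree creation and
Shenoy-Shafer message passing (collect, distribute, posterior).

-}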
module Bayes.FactorElimination(
    -- * Moral graph
      moralGraph
    -- * Triangulation
    , nodeComparisonForTriangulation
    , numberOfAddedEdges
    , weight
    , weightedEdges
    , triangulate
    -- * Junction tree
    , createClusterGraph
    , Cluster
    , createJunctionTree
    , createUninitializedJunctionTree
    , JunctionTree
    -- * Shenoy-Shafer message passing
    , collect 
    , distribute
    , posterior 
    -- * Evidence
    , changeEvidence
    -- * Test 
    , junctionTreeProperty_prop
    , VertexCluster
    -- * For debug 
    , junctionTreeProperty
    , maximumSpanningTree
    , fromVertexCluster
    ) where

import Bayes
import qualified Data.Foldable as F
import Data.Maybe(fromJust,mapMaybe,isJust)
import Control.Monad(mapM,guard)
import Bayes.Factor hiding (isEmpty)
import Data.Function(on)
import Data.List(minimumBy,maximumBy,inits,foldl')
import qualified Data.Set as Set
import qualified Data.Map as Map
import qualified Data.Functor as Functor
import qualified Data.Tree as T 
import Bayes.FactorElimination.JTree
import Control.Applicative((<$>))

import Test.QuickCheck hiding ((.||.), collect)
import Test.QuickCheck.Arbitrary

--import Debug.Trace
--debug s a = trace (s ++ " " ++ show a ++ "\n") a

{-
 
Comparison functions for graph triangulation

-}

-- | Number of edges that eliminating a vertex would add, i.e. how many
-- pairs of its neighbours are not already adjacent.
--
-- Pairs are enumerated as ordered pairs @(x, y)@ with @x /= y@, so every
-- missing edge is counted twice. Within this module the value is only used
-- to compare vertices, where the factor of two does not matter.
--
-- The neighbour list is taken with 'fromJust', so the vertex must belong to
-- the graph.
numberOfAddedEdges :: UndirectedGraph g
                   => g a b
                   -> Vertex
                   -> Int
numberOfAddedEdges g v =
  length (filter notAdjacent orderedPairs)
 where
  ns = fromJust (neighbors g v)
  orderedPairs = [(x, y) | x <- ns, y <- ns, x /= y]
  notAdjacent (x, y) = not (isLinkedWithAnEdge g x y)

-- | Weighted count of the edges that eliminating a vertex would add.
--
-- For every ordered pair @(x, y)@ of distinct, non-adjacent neighbours the
-- product @'weight' g x * 'weight' g y@ is summed. Each missing edge
-- therefore contributes twice, once per ordering of its endpoints.
weightedEdges :: (UndirectedGraph g, Factor f)
              => g a f
              -> Vertex
              -> Int
weightedEdges g v =
  sum $ do
    x <- ns
    y <- ns
    guard (x /= y && not (isLinkedWithAnEdge g x y))
    return (weight g x * weight g y)
 where
  ns = fromJust (neighbors g v)

-- | Weight of a vertex: the 'factorDimension' of the factor stored on it.
--
-- The vertex value is taken with 'fromJust', so the vertex must carry a
-- value in the graph.
weight :: (UndirectedGraph g, Factor f)
       => g a f
       -> Vertex
       -> Int
weight g v = factorDimension (fromJust (vertexValue g v))

-- | Lexicographic combination of two comparison functions: the second one
-- is only consulted when the first one returns 'EQ'.
(.||.) :: (a -> a -> Ordering)
       -> (a -> a -> Ordering)
       -> (a -> a -> Ordering)
(primary .||. secondary) a b =
  case primary a b of
    EQ -> secondary a b
    decided -> decided

-- | Comparison used to pick the next vertex to eliminate during
-- triangulation. Vertices are ordered first by 'numberOfAddedEdges' and,
-- on ties, by 'weight'.
nodeComparisonForTriangulation :: (UndirectedGraph g, Factor f)
                               => g a f
                               -> Vertex
                               -> Vertex
                               -> Ordering
nodeComparisonForTriangulation g = byAddedEdges .||. byWeight
 where
  byAddedEdges a b = compare (numberOfAddedEdges g a) (numberOfAddedEdges g b)
  byWeight a b = compare (weight g a) (weight g b)

{-

Graph triangulation

-}

-- | A cluster containing only the vertices and not yet the factors
newtype VertexCluster = VertexCluster (Set.Set Vertex) deriving(Eq)

-- | The set of vertices of a 'VertexCluster'.
fromVertexCluster :: VertexCluster -> Set.Set Vertex
fromVertexCluster (VertexCluster s) = s

-- Shows the vertices as a list, in ascending order ('Set.toList').
instance Show VertexCluster where
    show (VertexCluster s) = show . Set.toList $ s



instance IsCluster Cluster where
  -- Keep only the instantiations whose variable belongs to the cluster.
  overlappingEvidence (Cluster vars) evidence =
    [i | i <- evidence, instantiationVariable i `Set.member` vars]
  -- Variables of the cluster, in ascending order.
  clusterVariables (Cluster vars) = Set.elems vars
  -- The separator of two clusters is their intersection.
  mkSeparator (Cluster left) (Cluster right) = Cluster (left `Set.intersection` right)

-- | Triangulate a graph by vertex elimination.
--
-- At each step the vertex that is smallest according to the criterion is
-- chosen. Its neighbours (in the graph being eliminated) are connected
-- pairwise, and the fill-in edges are also added to the result graph. The
-- vertex together with its neighbours is recorded as a cluster, and the
-- vertex is removed. Once no vertex is left, clusters contained in an
-- earlier cluster are dropped with 'keepMaximalClusters'.
triangulate :: Graph g
            => (Vertex -> Vertex -> Ordering) -- ^ Criterion function for triangulation
            -> g () b
            -> ([VertexCluster],g () b) -- ^ Returns the clusters and the triangulated graph
triangulate cmp g = eliminate g g []
 where
  -- work   : graph from which vertices are removed one at a time
  -- filled : original graph receiving every fill-in edge
  -- acc    : clusters found so far, most recent first
  eliminate work filled acc
    | hasNoVertices work = (keepMaximalClusters (reverse acc), filled)
    | otherwise =
        let chosen = minimumBy cmp (allVertices work)
            family = chosen : fromJust (neighbors work chosen)
            unitEdge e gr = addEdge e () gr
            (work', filled') = connectAllNodesWith work filled unitEdge unitEdge family
        in eliminate (removeVertex chosen work') filled' (VertexCluster (Set.fromList family) : acc)


-- | Find the first cluster of a list that contains a given cluster.
--
-- Returns the first cluster whose vertex set is a superset of the processed
-- cluster's vertex set. The list is returned with that element removed; the
-- other elements keep their relative order. When no cluster contains it,
-- the result is 'Nothing' and the list is returned unchanged.
findContainingCluster :: VertexCluster -- ^ Cluster processed
                      -> [VertexCluster] -- ^ Cluster list where to look for a containing cluster
                      -> (Maybe VertexCluster,[VertexCluster]) -- ^ Return the containing cluster and a new list without the containing cluster
findContainingCluster cluster l =
  let -- True when s contains the processed cluster
      containsCluster s = Set.isSubsetOf (fromVertexCluster cluster) (fromVertexCluster s)
      (prefix,suffix) = break containsCluster l
  in
  case suffix of
    [] -> (Nothing,l)
    (found:rest) -> (Just found,prefix ++ rest)


-- | Remove every cluster that is contained in an earlier cluster of the list.
--
-- The list is scanned from left to right. If the current cluster is a subset
-- of a cluster already kept, the current cluster is dropped. The first such
-- containing cluster is then moved to the end of the clusters kept so far.
-- Otherwise the current cluster is kept. Later clusters are never used to
-- drop earlier ones.
keepMaximalClusters :: [VertexCluster] -> [VertexCluster]
keepMaximalClusters = reverse . foldl' step []
 where
  -- keptRev: clusters kept so far, most recent first
  step keptRev current =
    case findContainingCluster current (reverse keptRev) of
      (Nothing, _) -> current : keptRev
      (Just container, others) -> container : reverse others

-- | Convert a cluster of vertices into a 'Cluster' of variables.
--
-- Each vertex is replaced by the 'factorMainVariable' of the factor stored
-- on it in the graph. Vertices without a value in the graph are skipped.
vertexClusterToCluster :: (Factor f , Graph g)
                       => g e f
                       -> VertexCluster
                       -> Cluster
vertexClusterToCluster g c =
  Cluster (Set.fromList variables)
 where
  variables = [ factorMainVariable f
              | Just f <- map (vertexValue g) (Set.toList (fromVertexCluster c)) ]


-- | Create the cluster graph.
--
-- Every cluster becomes a vertex, numbered from 0 in list order, whose value
-- is the variable cluster from 'vertexClusterToCluster'. Every two distinct
-- clusters are linked by an edge labelled with the number of vertices they
-- share, so the graph is complete. Edges of size 0 are included. Each
-- unordered pair is passed to 'addEdge' twice, once per orientation.
createClusterGraph :: (UndirectedGraph g, Factor f, Graph g')
                   => g' e f
                   -> [VertexCluster]
                   -> g Int Cluster
createClusterGraph bn clusters =
  foldl' linkPair (foldl' insertCluster emptyGraph numbered) pairs
 where
  numbered = zip clusters (map Vertex [0..])
  insertCluster gr (vc, vtx) = addVertex vtx (vertexClusterToCluster bn vc) gr
  pairs = [(p, q) | p <- numbered, q <- numbered, p /= q]
  overlap a b = Set.size (fromVertexCluster a `Set.intersection` fromVertexCluster b)
  linkPair gr ((ca, va), (cb, vb)) = addEdge (edge va vb) (overlap ca cb) gr


{-

Maximum spanning tree using Prim's algorithm
  
-}

-- | All candidate edges between a vertex not yet in the tree and a node of
-- the tree.
--
-- A candidate is produced for each pair that is linked in the graph. It is
-- the triple (remaining vertex, tree node, edge label).
possibilities :: (Ord c , UndirectedGraph g)
              => g Int c -- ^ Original graph to get the edge value
              -> JTree c (Vertex,f) -- ^ Tree to get the vertex for a leaf
              -> [Vertex] -- ^ Vertices to add to the tree
              -> [c] -- ^ List of leaves
              -> [(Vertex,c,Int)] -- ^ Found edge to add
possibilities g currentT remaining leavesClusters =
  [ (rv, lv, fromJust (edgeValue g (edge rv treeVertex)))
  | rv <- remaining
  , lv <- leavesClusters
  , let NodeValue (treeVertex, _) _ = nodeValue currentT lv
  , isLinkedWithAnEdge g rv treeVertex
  ]

-- | Select the heaviest edge linking a vertex outside the tree to a tree
-- node (one step of Prim's algorithm).
--
-- All nodes of the current tree are candidates. The result contains:
--
-- * the remaining vertices, without the chosen one;
-- * the chosen vertex together with its cluster;
-- * the tree node it must be attached to.
--
-- 'maximumBy' fails when no remaining vertex is linked to the tree.
findMax :: (UndirectedGraph g, Ord c, Factor f)
        => g Int c -- ^ Graph
        -> [Vertex] -- ^ Nodes to add
        -> JTree c (Vertex,f)
        -> ([Vertex],(Vertex,c),c)
findMax g remaining currentT =
  (filter (/= bestVertex) remaining, (bestVertex, fromJust (vertexValue g bestVertex)), bestTreeNode)
 where
  candidates = possibilities g currentT remaining (treeNodes currentT)
  (bestVertex, bestTreeNode, _) = maximumBy (compare `on` (\(_, _, w) -> w)) candidates

-- | Drop the 'Vertex' component paired with every factor. The change is
-- applied to both the node map and the separator map of the tree.
removeVertices :: JTree c (Vertex,f) -> JTree c f
removeVertices tree =
  tree { nodeValueMap = fmap stripNode (nodeValueMap tree)
       , separatorValueMap = fmap stripSeparator (separatorValueMap tree)
       }
 where
  stripNode (NodeValue (_, fa) (_, fb)) = NodeValue fa fb
  stripSeparator sep = case sep of
    SeparatorValue (_, up) (Just (_, down)) -> SeparatorValue up (Just down)
    SeparatorValue (_, up) Nothing -> SeparatorValue up Nothing
    EmptySeparator -> EmptySeparator

-- | Maximum-weight spanning tree of a cluster graph, built with Prim's
-- algorithm. Edge labels are the weights.
--
-- The tree is rooted at the vertex returned by 'someVertex', so the graph
-- must not be empty. Every tree node starts with unit factors
-- (@factorFromScalar 1.0@).
maximumSpanningTree :: (UndirectedGraph g, IsCluster c, Factor f, Ord c)
                    => g Int c
                    -> JTree c f
maximumSpanningTree g =
  removeVertices (buildTree g others initialTree)
 where
  rootVertex = fromJust (someVertex g)
  rootCluster = fromJust (vertexValue g rootVertex)
  unit = factorFromScalar 1.0
  initialTree = singletonTree rootCluster (rootVertex, unit) (rootVertex, unit)
  others = filter (/= rootVertex) (allVertices g)

-- | Grow the tree with Prim's algorithm until no vertex is left to add.
--
-- On each step the vertex chosen by 'findMax' is added as a node with unit
-- factors. It is connected to its tree node through the separator given by
-- 'mkSeparator'.
buildTree :: (UndirectedGraph g , IsCluster c, Factor f, Ord c)
          => g Int c
          -> [Vertex]
          -> JTree c (Vertex,f)
          -> JTree c (Vertex,f)
buildTree _ [] currentT = currentT
buildTree g pending currentT =
  buildTree g pending' grown
 where
  (pending', (newVertex, newCluster), attachTo) = findMax g pending currentT
  unit = factorFromScalar 1.0
  withNode = addNode newCluster (newVertex, unit) (newVertex, unit) currentT
  grown = addSeparator attachTo (mkSeparator newCluster attachTo) newCluster withNode
   
{-

Junction tree algorithm

-}


-- | Create a junction tree with only the clusters and no factors.
--
-- Pipeline:
--
-- 1. Build the moral graph.
-- 2. Triangulate it with the given comparison.
-- 3. Build the complete cluster graph.
-- 4. Keep its maximum spanning tree.
createUninitializedJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g, Factor f)
                                => (UndirectedSG () f -> Vertex -> Vertex -> Ordering) -- ^ Weight function on the moral graph
                                -> g () f -- ^ Input directed graph
                                -> JunctionTree f -- ^ Junction tree
createUninitializedJunctionTree cmp g =
  maximumSpanningTree clusterGraph
 where
  moral = moralGraph g
  -- NOTE(review): cmp is partially applied to the initial moral graph, so
  -- the elimination heuristic does not see fill-in edges or vertices already
  -- removed during triangulation.
  (clusters, _) = triangulate (cmp moral) moral
  clusterGraph = createClusterGraph g clusters :: UndirectedSG Int Cluster

-- | Create a junction tree for a Bayesian network.
--
-- Steps:
--
-- 1. Build the cluster skeleton with 'createUninitializedJunctionTree'.
-- 2. Assign the network's factors to it with 'setFactors'.
-- 3. Run one 'collect' pass followed by one 'distribute' pass.
createJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g, Factor f, Show f)
                  => (UndirectedSG () f -> Vertex -> Vertex -> Ordering) -- ^ Weight function on the moral graph
                  -> BayesianNetwork g f -- ^ Input Bayesian network
                  -> JunctionTree f -- ^ Junction tree
createJunctionTree cmp g =
  distribute (collect withFactors)
 where
  skeleton = createUninitializedJunctionTree cmp g
  -- A vertex is linked with a factor, so the vertex is used as the identifier
  withFactors = setFactors g skeleton


-- | Marginal posterior of a variable, or the marginal prior when no evidence
-- is set on the junction tree.
--
-- 'traverseTree' is run with 'findClusterFor' to locate a cluster containing
-- the variable; 'Nothing' is returned when none is found. At that cluster,
-- the following factors are multiplied:
--
-- * the two node factors;
-- * the message from the parent separator, or a unit factor when none is
--   available;
-- * the messages from the child separators.
--
-- The product is projected onto the variable and divided by its
-- 'factorNorm'.
posterior :: Factor f => JunctionTree f -> DV -> Maybe f
posterior t v =
  fmap marginalAt (snd (traverseTree (findClusterFor v) Nothing t))
 where
  marginalAt c =
    let NodeValue f e = nodeValue t c
        fromParent = maybe (factorFromScalar 1.0) id (nodeParent t c >>= downMessage t)
        fromChildren = map (upMessage t) (nodeChildren t c)
        joint = factorProjectTo [v] (factorProduct (f : e : fromParent : fromChildren))
    in factorDivide joint (factorNorm joint)

-- | Step function for 'traverseTree' used to locate a cluster containing a
-- variable. It returns 'Stop' with the cluster when the cluster contains the
-- variable, and 'Skip' with the unchanged state otherwise.
findClusterFor :: DV
               -> Maybe Cluster
               -> Cluster -- ^ Current cluster
               -> NodeValue f -- ^ Current value
               -> Action (Maybe Cluster) (NodeValue f)
findClusterFor dv s c@(Cluster sc) _
  | dv `Set.member` sc = Stop (Just c)
  | otherwise = Skip s


-- | QuickCheck property: for non-empty, connected graphs with at least one
-- edge, the uninitialized junction tree satisfies 'junctionTreeProperty'.
-- The tree is built using 'numberOfAddedEdges' as the elimination criterion.
junctionTreeProperty_prop :: DirectedSG () CPT -> Property
junctionTreeProperty_prop g =
  nonTrivial ==> junctionTreeProperty tree [] (root tree)
 where
  nonTrivial = not (isEmpty g) && not (hasNoEdges g) && connectedGraph g
  byAddedEdges ug = compare `on` numberOfAddedEdges ug
  tree = createUninitializedJunctionTree byAddedEdges g

-- | Check 'checkPath' for a cluster and, recursively, for every cluster
-- below it. The path holds the ancestors of the current cluster, closest
-- ancestor first.
junctionTreeProperty :: JTree Cluster CPT -> [Cluster] -> Cluster -> Bool
junctionTreeProperty t path c =
  checkPath c path && all (junctionTreeProperty t (c : path)) childClusters
 where
  childClusters = [separatorChild t sep | sep <- nodeChildren t c]


-- | Running intersection check for one cluster.
--
-- The ancestor list is ordered from the closest ancestor (head) to the root
-- (last element), as built by 'junctionTreeProperty' with @(c : path)@. For
-- each ancestor @p@, the intersection of the cluster with @p@ must be
-- included in every cluster on the path from the cluster up to and
-- including @p@.
checkPath :: Cluster -> [Cluster] -> Bool
checkPath _ [] = True
checkPath (Cluster c) l =
  let clusterSet (Cluster s) = s
      -- e.g. [a, b, c, d]: a is the closest ancestor, d is the root
      parentSets = map clusterSet l
      -- x^a, x^b, x^c, x^d where x is the checked cluster
      allIntersectionsWithParents = map (Set.intersection c) parentSets
      -- [a], [a,b], [a,b,c], [a,b,c,d]: clusters between x and each ancestor
      pathsToEachParent = tail . inits $ parentSets
      isSubsetOfAllParents i path = all (Set.isSubsetOf i) path
  in
  and $ zipWith isSubsetOfAllParents allIntersectionsWithParents pathsToEachParent
{-

Moral graph

-}
-- | Get the parents of a vertex: the start vertices of its ingoing edges.
-- Uses 'fromJust', so it fails if 'ingoing' or any 'startVertex' lookup
-- returns 'Nothing'.
parents :: DirectedGraph g => g a b -> Vertex -> [Vertex]
parents g v = fromJust $ do
  inEdges <- ingoing g v
  mapM (startVertex g) inEdges

-- | Get the children of a vertex: the end vertices of its outgoing edges.
-- Uses 'fromJust', so it fails if 'outgoing' or any 'endVertex' lookup
-- returns 'Nothing'.
children :: DirectedGraph g => g a b -> Vertex -> [Vertex]
children g v = fromJust $ do
  outEdges <- outgoing g v
  mapM (endVertex g) outEdges

-- | Connect every pair of the given vertices that is not yet linked in the
-- origin graph. For each such edge, one function is applied to the origin
-- graph and another to the destination graph.
--
-- Adjacency is always tested against the unmodified origin graph. Pairs are
-- ordered (@x /= y@), so an unordered pair can yield both @edge x y@ and
-- @edge y x@. Both graphs must contain the vertices.
connectAllNodesWith :: (Graph g, Graph g')
                    => g a b -- ^ Graph containing the nodes
                    -> g' a b -- ^ Graph to be modified
                    -> (Edge -> g a b -> g a b) -- ^ Function used to modify the source graph
                    -> (Edge -> g' a b -> g' a b) -- ^ Function used to modify a new graph
                    -> [Vertex]  -- ^ List of nodes to connect
                    -> (g a b,g' a b) -- ^ Result graph
connectAllNodesWith originGraph dstGraph onOrigin onDst nodes = (origin', dst')
 where
  missing = [ edge x y
            | x <- nodes, y <- nodes, x /= y
            , not (isLinkedWithAnEdge originGraph x y) ]
  applyBoth e (o, d) = (onOrigin e o, onDst e d)
  (origin', dst') = foldr applyBoth (originGraph, dstGraph) missing

-- | Moralization step: connect every pair of parents of a vertex with an
-- edge labelled @()@. The vertex value (third argument) is ignored, which
-- lets this be used as a 'foldlWithVertex'' step.
--
-- NOTE(review): parents are read from the graph being extended, so edges
-- added for previously processed vertices are visible here as parents.
addMissingLinks :: DirectedGraph g => g () b -> Vertex -> b -> g () b
addMissingLinks g v _ =
  snd (connectAllNodesWith g g (\_ unchanged -> unchanged) (\e acc -> addEdge e () acc) (parents g v))


-- | Convert the graph to an undirected form.
--
-- Every vertex is copied with its label and value into a fresh graph. The
-- label is taken with 'fromJust'. Then every edge of the source is added
-- with label @()@.
convertToUndirected :: (FoldableWithVertex g, Graph g, NamedGraph g, NamedGraph g',UndirectedGraph g')
                    => g  () b
                    -> g' () b
convertToUndirected m =
  foldr addUnitEdge verticesOnly (allEdges m)
 where
  verticesOnly = foldlWithVertex' copyVertex emptyGraph m
  copyVertex acc v dat = addLabeledVertex (fromJust (vertexLabel m v)) v dat acc
  addUnitEdge e acc = addEdge e () acc

-- | Moral graph of a directed graph.
--
-- The parents of every vertex are connected pairwise, then the result is
-- converted to an undirected graph. Vertex labels and values are kept;
-- every edge carries @()@.
--
-- Parents are always read from the original graph @g@, never from the
-- accumulated graph. Marriage edges are added as directed edges in both
-- directions. Read from the accumulator, they would turn one parent into a
-- parent of another and marry it with that vertex's own parents. The result
-- would then contain extra, fold-order dependent edges.
moralGraph :: (NamedGraph g, FoldableWithVertex g, DirectedGraph g)
           => g () b -> UndirectedSG () b
moralGraph g =
  convertToUndirected (foldlWithVertex' marryParents g g)
 where
  -- Add the missing links between the parents (in g) of v to the accumulator.
  -- Adjacency is tested on the accumulator, so links already present are not
  -- added again.
  marryParents acc v _ =
    snd (connectAllNodesWith acc acc (\_ m -> m) (\e m -> addEdge e () m) (parents g v))