module Bayes.FactorElimination(
    
      moralGraph
    
    , nodeComparisonForTriangulation
    , numberOfAddedEdges
    , weight
    , weightedEdges
    , triangulate
    
    , createClusterGraph
    , Cluster
    , createJunctionTree
    , createUninitializedJunctionTree
    , JunctionTree
    , displayTreeValues
    
    , collect 
    , distribute
    , posterior 
    
    , changeEvidence
    
    , junctionTreeProperty_prop
    , junctionTreeAllClusters_prop
    , VertexCluster
    
    , junctionTreeProperty
    , maximumSpanningTree
    , fromVertexCluster
    , triangulatedebug
    ) where
import Bayes
import qualified Data.Foldable as F
import Data.Maybe(fromJust,mapMaybe,isJust)
import Control.Monad(mapM,guard)
import Bayes.Factor hiding (isEmpty)
import Data.Function(on)
import Data.List(minimumBy,maximumBy,inits,foldl',nub,(\\))
import qualified Data.Set as Set
import qualified Data.Map as Map
import qualified Data.Functor as Functor
import qualified Data.Tree as T 
import Bayes.FactorElimination.JTree
import Control.Applicative((<$>))
import Bayes.VariableElimination(marginal)
import Test.QuickCheck hiding ((.||.), collect)
import Test.QuickCheck.Arbitrary
import Bayes.VariableElimination.Buckets(IsBucketItem(..))
import Bayes.Factor.CPT 
-- | Count the ordered pairs of distinct neighbours of a vertex that are not
-- linked by an edge in the graph.
--
-- Every pair is visited in both orders, so a missing link presumably counts
-- twice (assuming 'isLinkedWithAnEdge' is symmetric on undirected graphs).
-- Fails through 'fromJust' when 'neighbors' returns 'Nothing' for the vertex.
numberOfAddedEdges :: UndirectedGraph g 
                   => g a b 
                   -> Vertex 
                   -> Integer 
numberOfAddedEdges g v = toInteger (length unlinkedPairs)
  where
    adjacent = fromJust (neighbors g v)
    -- Only the number of pairs matters, so plain tuples are enough here.
    unlinkedPairs =
      [ (x, y)
      | x <- adjacent
      , y <- adjacent
      , x /= y
      , not (isLinkedWithAnEdge g x y)
      ]
-- | Sum, over the ordered pairs of distinct neighbours of a vertex that are
-- not linked by an edge, of the product of the two neighbours' 'weight's.
--
-- Fails through 'fromJust' when 'neighbors' returns 'Nothing' for the vertex.
weightedEdges :: (UndirectedGraph g, Factor f) 
              => g a f 
              -> Vertex 
              -> Integer 
weightedEdges g v = sum (map pairWeight unlinkedPairs)
  where
    adjacent = fromJust (neighbors g v)
    unlinkedPairs =
      [ (x, y)
      | x <- adjacent
      , y <- adjacent
      , x /= y
      , not (isLinkedWithAnEdge g x y)
      ]
    pairWeight (x, y) = weight g x * weight g y
-- | The 'factorDimension' of the factor stored at a vertex, as an 'Integer'.
--
-- Fails through 'fromJust' when the vertex has no value in the graph.
weight :: (UndirectedGraph g, Factor f)
       => g a f 
       -> Vertex 
       -> Integer 
weight g v = fromIntegral (factorDimension storedFactor)
  where
    storedFactor = fromJust (vertexValue g v)
-- | Combine two comparison functions lexicographically: the second one is
-- consulted only when the first one answers 'EQ'.
(.||.) :: (a -> a -> Ordering)
       -> (a -> a -> Ordering) 
       -> (a -> a -> Ordering)
(f .||. g) a b
  | primary == EQ = g a b
  | otherwise     = primary
  where
    primary = f a b
-- | Order two vertices of a graph by their 'weightedEdges' score.
nodeComparisonForTriangulation :: (UndirectedGraph g, Factor f)
                               => g a f
                               -> Vertex 
                               -> Vertex 
                               -> Ordering 
nodeComparisonForTriangulation g va vb =
    compare (weightedEdges g va) (weightedEdges g vb)
-- | A cluster expressed as a set of graph vertices.
newtype VertexCluster = VertexCluster (Set.Set Vertex) deriving(Eq,Ord)

-- | Unwrap the set of vertices of a 'VertexCluster'.
fromVertexCluster :: VertexCluster -> Set.Set Vertex
fromVertexCluster (VertexCluster s) = s

-- | A cluster is displayed as the list of its vertices in ascending order.
instance Show VertexCluster where 
    show = show . Set.toList . fromVertexCluster
-- | Eliminate the vertices of a graph one by one and return the clusters
-- produced along the way, filtered by 'keepMaximalClusters'.
--
-- At each step the minimum vertex according to the comparison function is
-- chosen, its neighbours are connected together with
-- 'connectAllNonAdjacentNodes', the vertex is removed, and the cluster made
-- of the vertex and its neighbours is recorded.
triangulate :: Graph g
            => (Vertex -> Vertex -> Ordering) 
            -> g () b
            -> [VertexCluster] 
triangulate cmp gr = eliminate gr []
 where 
  -- The accumulator holds the clusters found so far, most recent first.
  eliminate g found
    | hasNoVertices g = keepMaximalClusters (reverse found)
    | otherwise       = eliminate shrunk (cluster : found)
    where
      pivot   = minimumBy cmp (allVertices g)
      around  = fromJust (neighbors g pivot)
      shrunk  = removeVertex pivot (connectAllNonAdjacentNodes around g)
      cluster = VertexCluster (Set.fromList (pivot : around))
-- | Same elimination loop as 'triangulate', but for inspection: it returns
-- every cluster in elimination order (without the 'keepMaximalClusters'
-- filtering) together with the graph as it was before each elimination step.
triangulatedebug :: Graph g
            => (Vertex -> Vertex -> Ordering) 
            -> g () b
            -> ([VertexCluster],[g () b]) 
triangulatedebug cmp gr = eliminate gr [] []
 where 
  -- Both accumulators are built most recent first and reversed at the end.
  eliminate g found seen
    | hasNoVertices g = (reverse found, reverse seen)
    | otherwise       = eliminate shrunk (cluster : found) (g : seen)
    where
      pivot   = minimumBy cmp (allVertices g)
      around  = fromJust (neighbors g pivot)
      shrunk  = removeVertex pivot (connectAllNonAdjacentNodes around g)
      cluster = VertexCluster (Set.fromList (pivot : around))
-- | Look for the first cluster of a list that contains the given cluster.
--
-- Returns @(Just found, rest)@ where @rest@ is the list without @found@
-- (order preserved), or @(Nothing, l)@ with the untouched list when no
-- cluster of the list is a superset of the given one.
findContainingCluster :: VertexCluster 
                      -> [VertexCluster] 
                      -> (Maybe VertexCluster,[VertexCluster]) 
findContainingCluster cluster l = 
  case break containsCluster l of
    (_, [])                 -> (Nothing, l)
    (before, found : after) -> (Just found, before ++ after)
  where
    wanted = fromVertexCluster cluster
    containsCluster candidate = wanted `Set.isSubsetOf` fromVertexCluster candidate
-- | Drop every cluster that is contained in a cluster appearing earlier in
-- the list.
--
-- Clusters are examined from left to right. When a cluster is contained in
-- an already kept one (see 'findContainingCluster'), it is discarded and the
-- containing cluster is moved to the end of the kept clusters.
keepMaximalClusters :: [VertexCluster] -> [VertexCluster]
keepMaximalClusters = reverse . foldl' absorb []
 where 
  -- The accumulator is the list of kept clusters, most recent first.
  absorb keptReversed current =
    case findContainingCluster current (reverse keptReversed) of
      (Nothing, _)          -> current : keptReversed
      (Just bigger, others) -> bigger : reverse others
-- | Convert a cluster of vertices into a 'Cluster' holding the main variable
-- ('factorMainVariable') of the factor stored at each vertex.
--
-- Vertices for which 'vertexValue' returns 'Nothing' are skipped.
vertexClusterToCluster :: (Factor f , Graph g)
                       => g e f 
                       -> VertexCluster 
                       -> Cluster 
vertexClusterToCluster g c = Cluster (Set.fromList mainVariables)
  where
    members = Set.toList (fromVertexCluster c)
    mainVariables = [ factorMainVariable f | Just f <- map (vertexValue g) members ]
-- | Build the cluster graph of a list of vertex clusters.
--
-- Each cluster becomes a vertex (numbered from 0 in list order) whose value
-- is the 'Cluster' computed by 'vertexClusterToCluster' from the first
-- argument. An edge is added for every ordered pair of distinct numbered
-- clusters; its value is the number of vertices shared by the two clusters.
createClusterGraph :: (UndirectedGraph g, Factor f, Graph g')
                   => g' e f
                   -> [VertexCluster] 
                   -> g Int Cluster
createClusterGraph bn c = foldl' linkClusters isolatedClusters orderedPairs
  where
    numbered = zip c (map Vertex [0..])
    insertCluster acc (vc, v) = addVertex v (vertexClusterToCluster bn vc) acc
    isolatedClusters = foldl' insertCluster emptyGraph numbered
    sharedCount ca cb =
      Set.size (Set.intersection (fromVertexCluster ca) (fromVertexCluster cb))
    orderedPairs = [ (p, q) | p <- numbered, q <- numbered, p /= q ]
    linkClusters acc ((ca, va), (cb, vb)) =
      addEdge (edge va vb) (sharedCount ca cb) acc
-- | Enumerate the ways of attaching a remaining vertex to the current tree.
--
-- For every remaining vertex and every given tree cluster whose vertex
-- (first field of its 'NodeValue') is linked to it in the graph, yield the
-- remaining vertex, the tree cluster and the value of the linking edge.
possibilities :: (Ord c , UndirectedGraph g) 
              => g Int c 
              -> JTree c f 
              -> [Vertex] 
              -> [c] 
              -> [(Vertex,c,Int)] 
possibilities g currentT remaining leavesClusters =
  [ (rv, lv, fromJust (edgeValue g (edge rv lvVertex)))
  | rv <- remaining
  , lv <- leavesClusters
  , let NodeValue lvVertex _ _ = nodeValue currentT lv
  , isLinkedWithAnEdge g rv lvVertex
  ]
-- | Select the next vertex to attach to the tree being built.
--
-- Among all the candidates returned by 'possibilities' (using every node of
-- the current tree), the one with the greatest edge value is chosen with
-- 'maximumBy', which fails when there is no candidate. The result is the
-- remaining vertices without the chosen one, the chosen vertex with its
-- cluster, and the tree cluster it must be attached to.
findMax :: (UndirectedGraph g, Ord c, Factor f,Show c)
        => g Int c 
        -> [Vertex] 
        -> JTree c f
        -> ([Vertex],(Vertex,c),c) 
findMax g remaining currentT = (stillRemaining, (chosen, chosenCluster), attachTo)
  where
    candidates = possibilities g currentT remaining (treeNodes currentT)
    linkValue (_, _, w) = w
    (chosen, attachTo, _) = maximumBy (compare `on` linkValue) candidates
    stillRemaining = filter (/= chosen) remaining
    chosenCluster = fromJust (vertexValue g chosen)
-- | Build a tree spanning the cluster graph.
--
-- The tree starts from the vertex returned by 'someVertex' (this fails on a
-- graph for which 'someVertex' is 'Nothing') and is grown by 'buildTree',
-- which repeatedly attaches the candidate with the greatest edge value.
maximumSpanningTree :: (UndirectedGraph g, IsCluster c, Factor f, Ord c, Show c, Show f) 
                    => g Int c 
                    -> JTree c f
maximumSpanningTree g = buildTree g others seed
  where
    rootVertex  = fromJust (someVertex g)
    rootCluster = fromJust (vertexValue g rootVertex)
    seed        = singletonTree rootCluster rootVertex [] []
    others      = filter (/= rootVertex) (allVertices g)
-- | Attach the remaining vertices of the cluster graph to the tree, one at a
-- time, until none is left.
--
-- Each step asks 'findMax' for the next vertex and its attachment point,
-- adds the new node to the tree and then links it to the attachment point
-- through a separator built with 'mkSeparator'.
buildTree :: (UndirectedGraph g , IsCluster c, Factor f, Ord c, Show c, Show f)
          => g Int c 
          -> [Vertex]
          -> JTree c f 
          -> JTree c f
buildTree _ [] currentT = currentT 
buildTree g l currentT = buildTree g rest grown
  where
    (rest, (newVertex, newCluster), anchor) = findMax g l currentT
    separator = mkSeparator newCluster anchor
    withNode  = addNode newCluster newVertex [] [] currentT
    grown     = addSeparator anchor separator newCluster withNode
   
-- | Build the structure of a junction tree for a directed graph.
--
-- The graph is moralized ('moralGraph'), triangulated ('triangulate') with
-- the comparison function applied to the moral graph, turned into a cluster
-- graph ('createClusterGraph') and finally reduced to a tree with
-- 'maximumSpanningTree'. Every node is created with empty lists as its
-- last two arguments; this function itself attaches no factors.
createUninitializedJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g, Factor f, Show f)
                                => (UndirectedSG () f -> Vertex -> Vertex -> Ordering) 
                                -> g () f 
                                -> JunctionTree f 
createUninitializedJunctionTree cmp g = maximumSpanningTree clusterGraph
  where
    moralized      = moralGraph g
    vertexClusters = triangulate (cmp moralized) moralized
    clusterGraph   = createClusterGraph g vertexClusters :: UndirectedSG Int Cluster
-- | Build a junction tree for a Bayesian network.
--
-- The tree structure comes from 'createUninitializedJunctionTree'; the
-- factors of the network are then installed with 'setFactors' and the tree
-- goes through 'collect' followed by 'distribute'.
createJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g, Factor f, IsBucketItem f, Show f)
                  => (UndirectedSG () f -> Vertex -> Vertex -> Ordering) 
                  -> BayesianNetwork g f 
                  -> JunctionTree f 
createJunctionTree cmp g = distribute (collect populated)
  where
    skeleton  = createUninitializedJunctionTree cmp g
    populated = setFactors g skeleton
  
-- | Compute the normalized marginal of a variable from a junction tree.
--
-- The tree is searched with 'findClusterFor' for a cluster containing the
-- variable; 'Nothing' is returned when the search yields no cluster.
-- Otherwise the factors gathered around that cluster (the down message
-- looked up through the node's parent, or the scalar factor 1.0 when there
-- is none, the up messages of its children and the two factor lists stored
-- in its 'NodeValue') are marginalized onto the variable with 'marginal' and
-- divided by their 'factorNorm'.
posterior :: (BayesianDiscreteVariable dv, Factor f, IsBucketItem f) => JunctionTree f -> dv -> Maybe f
posterior t someDv = fmap marginalIn located
  where
    v = dv someDv
    located = snd (traverseTree (findClusterFor v) Nothing t)
    marginalIn c = factorDivide unNormalized (factorNorm unNormalized)
      where
        NodeValue _ f e = nodeValue t c
        fromParent   = maybe (factorFromScalar 1.0) id (downMessage t =<< nodeParent t c)
        fromChildren = map (upMessage t) (nodeChildren t c)
        allFactors   = fromParent : fromChildren ++ f ++ e
        -- Every variable mentioned by the factors except the queried one.
        variablesToRemove = nub (concatMap factorVariables allFactors) \\ [v]
        unNormalized = marginal allFactors variablesToRemove [v] []
-- | Tree traversal step looking for a cluster that contains a variable.
--
-- Answers @Stop (Just cluster)@ when the variable belongs to the visited
-- cluster and @Skip@ with the unchanged accumulator otherwise. The node
-- value is ignored.
findClusterFor :: DV 
               -> Maybe Cluster
               -> Cluster 
               -> NodeValue f 
               -> Action (Maybe Cluster) (NodeValue f)
findClusterFor wanted soFar cluster@(Cluster members) _
  | Set.member wanted members = Stop (Just cluster)
  | otherwise                 = Skip soFar
-- | QuickCheck property: for a non-empty, connected directed graph with at
-- least one edge, the tree built by 'createUninitializedJunctionTree' (using
-- 'numberOfAddedEdges' to order the vertices) satisfies
-- 'junctionTreeProperty' from its root.
junctionTreeProperty_prop :: DirectedSG () CPT -> Property 
junctionTreeProperty_prop g = usable ==> junctionTreeProperty tree [] (root tree)
  where
    usable = not (isEmpty g) && not (hasNoEdges g) && connectedGraph g
    byAddedEdges ug = compare `on` numberOfAddedEdges ug
    tree = createUninitializedJunctionTree byAddedEdges g
-- | QuickCheck property: for a non-empty, connected directed graph with at
-- least one edge, the nodes of the tree returned by 'maximumSpanningTree'
-- are exactly the clusters produced by the triangulation of the moral graph.
junctionTreeAllClusters_prop :: DirectedSG () CPT -> Property 
junctionTreeAllClusters_prop g = usable ==> sameClusters
  where
    usable = not (isEmpty g) && not (hasNoEdges g) && connectedGraph g
    moralized = moralGraph g
    byAddedEdges ug = compare `on` numberOfAddedEdges ug
    vertexClusters = triangulate (byAddedEdges moralized) moralized
    clusterGraph = createClusterGraph g vertexClusters :: UndirectedSG Int Cluster
    tree = maximumSpanningTree clusterGraph :: JunctionTree CPT
    expected = Set.fromList (map (vertexClusterToCluster g) vertexClusters)
    actual = Set.fromList (treeNodes tree)
    -- Mutual inclusion, i.e. both sets hold the same clusters.
    sameClusters = Set.isSubsetOf expected actual && Set.isSubsetOf actual expected
-- | Check 'checkPath' for a cluster against the path of clusters leading to
-- it (closest ancestor first), then recursively for all the clusters reached
-- through its children with the path extended by the cluster itself.
junctionTreeProperty :: JTree Cluster CPT -> [Cluster] -> Cluster -> Bool
junctionTreeProperty t path c = checkPath c path && all descend childClusters
  where
    childClusters = [ separatorChild t s | s <- nodeChildren t c ]
    descend = junctionTreeProperty t (c : path)
-- | Check a cluster against the list of clusters on its path.
--
-- For every cluster of the list, the variables it shares with the checked
-- cluster must belong to that cluster and to all the clusters listed before
-- it. An empty path is always accepted.
checkPath :: Cluster -> [Cluster] -> Bool 
checkPath _ [] = True
checkPath (Cluster c) l =
  and [ all (Set.isSubsetOf shared) upTo
      | (shared, upTo) <- zip sharedWithEach prefixes
      ]
  where
    ancestorSets   = map (\(Cluster s) -> s) l
    sharedWithEach = map (Set.intersection c) ancestorSets
    -- Non-empty prefixes: the k-th one holds the first k sets of the path.
    prefixes       = tail (inits ancestorSets)
-- | Start vertices of all the ingoing edges of a vertex.
--
-- Fails through 'fromJust' when 'ingoing' or any 'startVertex' lookup
-- returns 'Nothing'.
parents :: DirectedGraph g => g a b -> Vertex -> [Vertex]
parents g v = fromJust $ do
    incoming <- ingoing g v
    mapM (startVertex g) incoming
-- | End vertices of all the outgoing edges of a vertex.
--
-- Fails through 'fromJust' when 'outgoing' or any 'endVertex' lookup
-- returns 'Nothing'.
children :: DirectedGraph g => g a b -> Vertex -> [Vertex]
children g v = fromJust $ do
    leaving <- outgoing g v
    mapM (endVertex g) leaving
-- | Add an edge between the given vertices wherever one is missing.
--
-- Adjacency is tested against the graph received as argument, not against
-- the graph being built, so for two unlinked vertices both @edge x y@ and
-- @edge y x@ are added.
connectAllNonAdjacentNodes :: (Graph g) 
                           => [Vertex]  
                           -> g () b 
                           -> g () b
connectAllNonAdjacentNodes nodes originGraph = foldl' link originGraph missing
  where
    link acc e = addEdge e () acc
    missing =
      [ edge x y
      | x <- nodes
      , y <- nodes
      , x /= y
      , not (isLinkedWithAnEdge originGraph x y)
      ]
   
-- | Link together all the parents of a vertex that are not already linked.
-- The vertex value is ignored.
--
-- The parents are looked up in the graph given as first argument. When this
-- function is used as a fold step over an accumulating graph, links added by
-- earlier steps are therefore part of the graph that 'parents' inspects.
addMissingLinks :: DirectedGraph g => g () b -> Vertex -> b -> g () b
addMissingLinks g v _ = connectAllNonAdjacentNodes itsParents g
  where
    itsParents = parents g v
-- | Rebuild a graph with another graph type, here an undirected one.
--
-- Every vertex is copied with its label and value, then every edge of the
-- source graph is added with a unit value. Fails through 'fromJust' when a
-- vertex of the source graph has no label.
convertToUndirected :: (FoldableWithVertex g, Graph g, NamedGraph g, NamedGraph g',UndirectedGraph g')
                    => g  () b 
                    -> g' () b 
convertToUndirected m = foldr insertEdge verticesOnly (allEdges m)
  where
    copyVertex acc v dat = addLabeledVertex (fromJust (vertexLabel m v)) v dat acc
    verticesOnly = foldlWithVertex' copyVertex emptyGraph m
    insertEdge e acc = addEdge e () acc
-- | Moral graph of a directed graph: for every vertex, all its parents are
-- linked together, then the result is converted to an undirected graph with
-- 'convertToUndirected'.
--
-- The parents of each vertex are read from the original directed graph, not
-- from the graph being accumulated. 'connectAllNonAdjacentNodes' adds links
-- in both orientations, so in the accumulated graph two vertices linked by
-- an earlier step look like parents of each other. Reading parents there
-- would link vertices that share no child in the original graph, and the
-- result would depend on the order in which the vertices are folded.
moralGraph :: (NamedGraph g, FoldableWithVertex g, DirectedGraph g) 
           => g () b -> UndirectedSG () b 
moralGraph g = convertToUndirected (foldlWithVertex' marryParents g g)
  where
    -- Fold step: the accumulator receives the new links, while the parents
    -- always come from the untouched input graph.
    marryParents acc v _ = connectAllNonAdjacentNodes (parents g v) acc