{- | Algorithms for variable elimination

-}
module Bayes.VariableElimination(
 -- * Inferences
   priorMarginal
 , posteriorMarginal
 -- * Interaction graph and elimination order
 , interactionGraph
 , degreeOrder
 , minDegreeOrder
 , minFillOrder
 , allVariables
 , EliminationOrder
 ) where

import Bayes
import Bayes.Factor
import Data.List(partition,minimumBy,(\\),find)
import Data.Maybe(fromJust)
import Data.Function(on)
import qualified Data.Map as M

--import Debug.Trace 

--debug s a = trace (s  ++ "\n" ++ show a ++ "\n") a

-- | Elimination order: the variables listed in the order in which they are
-- eliminated (the head of the list is eliminated first).
type EliminationOrder = [DV]

-- | Get all variables from a Bayesian Network: the main variable of the
-- factor stored at each vertex, in the order 'allVertexValues' yields them.
allVariables :: (Graph g, Factor f) 
             => BayesianNetwork g f 
             -> [DV]
allVariables = map factorMainVariable . allVertexValues

-- | Used for bucket elimination. Factors are organized by the first variable
-- (in elimination order) that they contain.
--
-- The first component is the pending order: the variables that can still
-- receive a factor (it excludes the variable currently being eliminated).
-- The second component maps each variable to the factors of its bucket.
type Buckets f = (EliminationOrder,M.Map DV [f])

-- | Distribute the factors of the network into buckets.
--
-- The variables to eliminate are put first, followed by the remaining
-- variables, so a partial elimination order is completed by the remaining
-- ones. Each factor goes to the bucket of the first variable of that
-- combined order it contains; factors containing none of those variables
-- are not kept. Every variable of the combined order gets a (possibly
-- empty) bucket.
--
-- The returned pending order omits the first variable of the combined
-- order, i.e. the one that will be eliminated first.
createBuckets ::  (Graph g, Factor f, Show f) 
              => BayesianNetwork g f -- ^ Bayesian Network
              -> EliminationOrder -- ^ Variables to eliminate
              -> EliminationOrder -- ^ Remaining variables
              -> Buckets f 
createBuckets g e r = 
  let s = allVertexValues g
      -- We put the selected variables for elimination in the right order at the beginning
      -- Which means the function can work with a partial order which is completed with other
      -- variables by default.
      theOrder = e ++ r
      -- Claim, for dv, every not-yet-assigned factor that contains it
      addDVToBucket dv (rf, m) =
        let (fk, remaining) = partition (`containsVariable` dv) rf
        in 
        (remaining, M.insert dv fk m)
      -- foldr over the reversed order processes theOrder front to back,
      -- so earlier variables claim their factors first
      (_, b) = foldr addDVToBucket (s, M.empty) (reverse theOrder)
  in
  -- 'drop 1' instead of 'tail': an empty order yields an empty pending
  -- list rather than a crash
  (drop 1 theOrder, b)

-- | Get the factors stored in the bucket of a variable.
--
-- Fails with a descriptive error when the variable has no bucket: it was
-- not part of the order given to 'createBuckets', or its bucket has already
-- been removed (as 'updateBucket' does for non-scalar results).
getBucket :: DV 
          -> Buckets f 
          -> [f]
getBucket dv (_,m) =
  M.findWithDefault missing dv m
 where
  missing = error "Bayes.VariableElimination.getBucket: no bucket for this variable"

-- | Store the factor produced by eliminating @dv@ and advance the pending
-- order by one variable.
--
-- A scalar factor replaces the content of @dv@'s bucket. Any other factor
-- is moved: @dv@'s bucket is deleted and the factor is added to the bucket
-- of the first pending variable it contains (see 'addBucket').
updateBucket :: Factor f => DV -> f -> Buckets f -> Buckets f 
updateBucket dv f b@(e,m)
  -- 'drop 1' instead of 'tail': eliminating the last variable leaves an
  -- empty pending order, which must not be a latent crash
  | isScalarFactor f = (drop 1 e, M.insert dv [f] m)
  | otherwise =
      let (e', m') = addBucket f (removeFromBucket dv b)
      in
      (drop 1 e', m')

-- | Add a factor to the bucket of the first pending variable it contains.
--
-- When the factor contains none of the pending variables it is discarded
-- and the buckets are returned unchanged.
addBucket :: Factor f => f -> Buckets f -> Buckets f
addBucket f (e,b) = 
  case find (f `containsVariable`) e of 
    Nothing -> (e,b)
    -- 'insertWith'' was deprecated and later removed from containers;
    -- 'insertWith' prepends the new factor to the bucket in the same way
    Just bucket -> (e, M.insertWith (++) bucket [f] b)

-- | Drop the bucket of a variable; the pending order is left untouched.
removeFromBucket :: DV -> Buckets f -> Buckets f 
removeFromBucket dv = fmap (M.delete dv)

-- | Compute the posterior marginal @P(Q | e)@ by bucket elimination.
--
-- The variables of the first order are marginalized out, in that order.
-- @Q@ denotes the remaining variables. The evidence @e@ is given as
-- instantiations, each turned into a factor with 'factorFromInstantiation'.
-- The product of the leftover factors (@P(Q, e)@) is divided by its norm
-- (@P(e)@), giving @P(Q | e)@.
posteriorMarginal :: (Graph g, Factor f, Show f) 
                  => BayesianNetwork g f -- ^ Bayesian Network
                  -> EliminationOrder -- ^ Ordering of variables to marginalize
                  -> EliminationOrder -- ^ Ordering of remaining variables
                  -> [DVI Int] -- ^ Evidence: instantiations of some variables
                  -> f
posteriorMarginal n p r assignment = 
  -- The elimination order lists the variables to eliminate,
  -- but the algorithm also needs the remaining variables
  let (pending, initial) = createBuckets n p r
      assignmentFactors = map factorFromInstantiation assignment
      -- Route evidence with the *full* order: the pending order returned by
      -- 'createBuckets' no longer contains the first variable to eliminate,
      -- so evidence on that variable would otherwise be silently dropped.
      (_, withEvidence) = foldr addBucket (p ++ r, initial) assignmentFactors
      bucket' = (pending, withEvidence)
      -- foldr over the reversed list eliminates the variables of p in order
      (_,resultBucket) = foldr marginalizeOneVariable bucket' (reverse p)
      resultFactor = factorProduct . concat . M.elems $ resultBucket
      -- The norm is P(e) and result factor is P(Q,e)
      norm = factorNorm resultFactor
  in
  -- We get P(Q | e)
  resultFactor `factorDivide` norm 
 where 
  -- Multiply the factors of dv's bucket, project dv out and store the result
  marginalizeOneVariable dv currentBucket = 
    let prod = factorProduct (getBucket dv currentBucket)
        f' = factorProjectOut [dv] prod
    in
    updateBucket dv f' currentBucket

-- | Compute the prior marginal of the kept variables: the variables of the
-- first order are marginalized out and no evidence is applied. This is
-- 'posteriorMarginal' with an empty evidence list, so the result is
-- normalized the same way.
priorMarginal :: (Graph g, Factor f, Show f) 
              => BayesianNetwork g f -- ^ Bayesian Network
              -> EliminationOrder -- ^ Ordering of variables to marginalize
              -> EliminationOrder -- ^ Ordering of remaining to keep in result
              -> f
priorMarginal network toEliminate toKeep =
  posteriorMarginal network toEliminate toKeep noEvidence
 where
  noEvidence = []

-- | Compute the interaction graph of the BayesianNetwork.
--
-- Every variable appearing in a factor becomes a vertex (labelled with the
-- variable). Every two distinct variables sharing a factor are linked by an
-- edge (added in both directions).
interactionGraph :: (FoldableWithVertex g,Factor f, UndirectedGraph g')
                 => BayesianNetwork g f
                 -> g' () DV
interactionGraph g = 
  foldrWithVertex addFactor emptyGraph g 
 where
  addFactor _ factor graph = 
    let allvars = factorVariables factor
        pairs = [(x,y) | x <- allvars, y <- allvars , x /= y]
        addVar v acc = addVertex (variableVertex v) v acc
        addLink (va,vb) acc =
          addEdge (edge (variableVertex va) (variableVertex vb)) () acc
        -- Register every variable first, so a factor with a single variable
        -- (e.g. an isolated root node) still contributes its vertex
        withVars = foldr addVar graph allvars
    in 
    foldr addLink withVars pairs

-- | Number of neighbors of a variable in an interaction graph.
-- The variable's vertex must be present ('fromJust' on the lookup).
nbNeighbors :: UndirectedSG () DV 
            -> DV 
            -> Int 
nbNeighbors g = length . fromJust . neighbors g . variableVertex

-- | Number of missing links between the neighbors of a variable, i.e. the
-- fill-in its elimination would create.
--
-- Ordered pairs @(x, y)@ of distinct neighbors are counted. Assuming
-- 'isLinkedWithAnEdge' is symmetric, each missing link is therefore counted
-- twice, which does not change which variable minimizes the metric.
-- The variable's vertex must be present ('fromJust' on the lookup).
nbMissingLinks :: UndirectedSG () DV  
               -> DV 
               -> Int 
nbMissingLinks g dv = 
  let around = fromJust (neighbors g (variableVertex dv))
      unlinked x y = x /= y && not (isLinkedWithAnEdge g x y)
  in
  length [ () | x <- around, y <- around, unlinked x y ]

-- | Compute the degree (width) of an elimination order.
--
-- The variables are eliminated from the interaction graph in the given
-- order, head first. Before a variable's vertex is removed, its neighbors
-- are pairwise connected (fill-in edges). The result is the largest number
-- of neighbors a variable had at the moment it was eliminated.
degreeOrder :: (FoldableWithVertex g, Factor f, Graph g)
            => BayesianNetwork g f
            -> EliminationOrder 
            -> Int 
degreeOrder g p =
  let ig = interactionGraph g :: UndirectedSG () DV
      -- foldr over the reversed order eliminates the head of p first, as
      -- in 'posteriorMarginal'; a plain foldr would process p backwards
      (_, width) = foldr eliminate (ig, 0) (reverse p)
  in 
  width
 where 
  connect (va,vb) graph = addEdge (edge va vb) () graph
  eliminate dv (graph, w) = 
    let v = variableVertex dv
        ns = fromJust $ neighbors graph v
        fill = [(x,y) | x <- ns, y <- ns , x /= y, not (isLinkedWithAnEdge graph x y)]
        graph' = removeVertex v (foldr connect graph fill)
    in
    (graph', max w (length ns))
 
-- | Find an elimination order greedily minimizing a metric.
--
-- At each step the candidate variable with the smallest metric in the
-- current interaction graph is chosen. Its neighbors are then pairwise
-- connected (fill-in edges, as elimination would create them), and its
-- vertex is removed before the next choice is made.
eliminationOrderForMetric :: (Graph g, Factor f, FoldableWithVertex g, UndirectedGraph g')
                          => (g' () DV -> DV -> Int)
                          -> BayesianNetwork g f 
                          -> EliminationOrder 
eliminationOrderForMetric metric g = 
  let ig = interactionGraph g
  in 
  pick ig (allVertexValues ig)
 where
  pick _ [] = []
  pick graph candidates =
    let scored = [ (v, metric graph v) | v <- candidates ]
        (optimalNode, _) = minimumBy (compare `on` snd) scored
        graph' = eliminate (variableVertex optimalNode) graph
    in
    optimalNode : pick graph' (candidates \\ [optimalNode])
  -- Connect all neighbors of v, then remove v
  eliminate v graph =
    let ns = case neighbors graph v of
               Just xs -> xs
               Nothing -> []
        fill = [ (x, y) | x <- ns, y <- ns, x /= y, not (isLinkedWithAnEdge graph x y) ]
        connect (x, y) acc = addEdge (edge x y) () acc
    in
    removeVertex v (foldr connect graph fill)

-- | Greedy elimination order that, at each step, picks the variable with
-- the fewest neighbors in the interaction graph.
minDegreeOrder :: (Graph g, Factor f, FoldableWithVertex g)
               => BayesianNetwork g f 
               -> EliminationOrder 
minDegreeOrder network = eliminationOrderForMetric nbNeighbors network

-- | Greedy elimination order that, at each step, picks the variable whose
-- neighbors have the fewest missing links between them (minimum fill-in).
minFillOrder :: (Graph g, Factor f, FoldableWithVertex g)
             => BayesianNetwork g f 
             -> EliminationOrder 
minFillOrder network = eliminationOrderForMetric nbMissingLinks network