{- | Junction Trees 

The Tree data structures are not working very well with message passing algorithms. So, junction trees are using
a different representation

-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Bayes.FactorElimination.JTree(
      IsCluster(..)
    , Cluster(..)
    , JTree(..)
    , JunctionTree(..)
    , Sep
    , setFactors
    , distribute 
    , collect
    , fromCluster
    , changeEvidence
    , nodeIsMemberOfTree
    , singletonTree
    , addNode 
    , addSeparator
    , leaves
    , nodeValue
    , NodeValue(..)
    , SeparatorValue(..)
    , downMessage
    , upMessage 
    , nodeParent 
    , nodeChildren
    , traverseTree
    , separatorChild
    , treeNodes
    , treeValues
    , displayTreeValues
    , Action(..)
    ) where 

import qualified Data.Map as Map
import qualified Data.Tree as Tree
import Data.Maybe(fromJust,mapMaybe)
import qualified Data.Set as Set
import Data.Monoid
import Data.List((\\), intersect,partition, foldl',minimumBy,nub)
import Bayes.PrivateTypes 
import Bayes.Factor
import Bayes
import Data.Function(on)
import Bayes.VariableElimination(marginal)
import Data.Binary
import Bayes.VariableElimination.Buckets(IsBucketItem(..))

--import Debug.Trace 
--debug s a = trace (s ++ " " ++ show a ++ "\n") a

type UpMessage a = a 
type DownMessage a = Maybe a

-- | Separator value
data SeparatorValue a = SeparatorValue !(UpMessage a) !(DownMessage a)
                      | EmptySeparator -- ^ Use to track the progress in the collect phase
                      deriving(Eq)

instance Show a => Show (SeparatorValue a) where 
    show EmptySeparator = ""
    show (SeparatorValue u Nothing) = "u(" ++ show u ++ ")"
    show (SeparatorValue u (Just d)) = "u(" ++ show u ++ ") d(" ++ show d ++ ")"

type FactorValues a = [a]
type EvidenceValues a = [a]

-- | Node value
data NodeValue a = NodeValue !Vertex !(FactorValues a) !(EvidenceValues a) deriving(Eq)

instance Show a => Show (NodeValue a) where 
    show (NodeValue v f e) = "f(" ++ show f ++ ") e(" ++ show e ++ ")"

newtype Sep = Sep Int deriving(Eq,Ord,Show,Num,Binary)

-- | Junction tree.
-- 'c' is the node / separator identifier (for instance a set of 'DV')
-- a are the values for a node or separator
-- Cluster are unique sor the cluster value is also the cluster key
-- Separator values are not unique. Two different seperators can be the same
-- cluster. So, separator unicity is enforced with a number
data JTree  c f = JTree {  root :: !c
                        -- | Leaves of the tree
                        ,  leavesSet :: !(Set.Set c)
                        -- | The children of a node are separators
                        ,  childrenMap :: !(Map.Map c [Sep])
                        -- | Parent of a node
                        ,  parentMap :: !(Map.Map c Sep)
                        -- | Parent of a separator
                        ,  separatorParentMap :: !(Map.Map Sep c)
                        -- | The child of a seperator is a node
                        ,  separatorChildMap :: !(Map.Map Sep c)
                        -- | Values for nodes and seperators
                        ,  nodeValueMap :: !(Map.Map c (NodeValue f))
                        ,  separatorValueMap :: !(Map.Map Sep (SeparatorValue f))
                        ,  separatorCurrentKey :: !Sep
                        ,  separatorClusterMap :: !(Map.Map Sep c)
                        } deriving(Eq)

instance FactorContainer (JTree Cluster) where 
  changeFactor f t = 
    let changeNodeValue (NodeValue v fa ev) = NodeValue v (changeFactor f fa) ev
    in
    distribute .
    collect $
    t { nodeValueMap = Map.map changeNodeValue (nodeValueMap t)
      , separatorValueMap = Map.map (const EmptySeparator) (separatorValueMap t)
      }

-- | Create a singleton tree with just one root node
singletonTree r rootVertex factorValue evidenceValue = 
    let t = JTree r Set.empty Map.empty Map.empty Map.empty Map.empty Map.empty Map.empty (Sep 0) Map.empty
    in 
    addNode r rootVertex factorValue evidenceValue t

-- | Reset all evidences to 1 in the network
resetEvidences :: Factor f => JTree c f -> JTree c f 
resetEvidences t = t {nodeValueMap = Map.map resetNodeEvidence (nodeValueMap t)}
 where 
  resetNodeEvidence (NodeValue v f _) = NodeValue v f []

-- | Get the cluster for a separator
separatorCluster :: JTree c a -> Sep -> c 
{-# INLINE separatorCluster #-}
separatorCluster t s = fromJust $! Map.lookup s (separatorClusterMap t)

-- | Leaves of the tree
leaves :: JTree c a -> [c]
{-# INLINE leaves #-}
leaves = Set.toList . leavesSet

-- | All nodes of the tree
treeNodes :: JTree c a -> [c]
{-# INLINE treeNodes #-}
treeNodes = Map.keys . nodeValueMap

treeValues :: JTree c f -> [(c,NodeValue f)]
{-# INLINE treeValues #-}
treeValues = Map.toList . nodeValueMap

-- | Value of a node
nodeValue :: Ord c => JTree c a -> c -> NodeValue a 
{-# INLINE nodeValue #-}
nodeValue t e = fromJust $! Map.lookup e (nodeValueMap t)

-- | Change the value of a node
setNodeValue :: Ord c => c -> NodeValue a -> JTree c a -> JTree c a
{-# INLINE setNodeValue #-}
setNodeValue c v t = t {nodeValueMap = Map.insert c v (nodeValueMap t)} 

-- | Parent of a node
nodeParent :: Ord c => JTree c a -> c -> Maybe Sep 
{-# INLINE nodeParent #-}
nodeParent t e = let r = Map.lookup e (parentMap t) in r `seq` r

-- | Value of a node
separatorValue :: Ord c => JTree c a -> Sep -> SeparatorValue a 
{-# INLINE separatorValue #-}
separatorValue t e = fromJust $! Map.lookup e (separatorValueMap t)

-- | Parent of a separator
separatorParent :: Ord c => JTree c a -> Sep -> c 
{-# INLINE separatorParent #-}
separatorParent t e = fromJust $! Map.lookup e (separatorParentMap t)

-- | UpMessage for a separator node
upMessage :: Ord c => JTree c a -> Sep -> a
upMessage t c = case separatorValue t c of 
                  SeparatorValue up _ -> up 
                  _ -> error "Trying to get an up message on an empty seperator ! Should never occur !"

-- | DownMessage for a separator node
downMessage :: Ord c => JTree c a -> Sep -> Maybe a 
downMessage t c = case separatorValue t c of 
     SeparatorValue _ (Just down) -> Just down 
     SeparatorValue _ Nothing -> Nothing
     _ -> error "Trying to get a down message on an empty separator ! Should never occur !"

-- | Return the separator childrens of a node
nodeChildren :: Ord c => JTree c a -> c -> [Sep]
{-# INLINE nodeChildren #-}
nodeChildren t e = maybe [] id $! Map.lookup e (childrenMap t)

-- | Return the child of a separator
separatorChild :: Ord c => JTree c a -> Sep -> c 
{-# INLINE separatorChild #-}
separatorChild t e = fromJust $! Map.lookup e (separatorChildMap t)

-- | Check if a node is member of the tree
nodeIsMemberOfTree :: Ord c => c -> JTree c a -> Bool 
{-# INLINE nodeIsMemberOfTree #-}
nodeIsMemberOfTree c t = Map.member c (nodeValueMap t)

-- | Add a separator between two nodes.
-- The nodes MUST already be in the tree
addSeparator :: (Ord c) 
             => c -- ^ Origin node 
             -> c -- ^ Separator value
             -> c -- ^ Destination node 
             -> JTree c a -- ^ Current tree 
             -> JTree c a -- ^ Modified tree 
addSeparator node sepCluster dest t = 
    let newSep = (separatorCurrentKey t) + 1 
    in
    t { childrenMap = Map.insertWith' (++) node [newSep] (childrenMap t)
      , separatorChildMap = Map.insert newSep dest (separatorChildMap t)
      , separatorValueMap = Map.insert newSep EmptySeparator (separatorValueMap t)
      , separatorClusterMap = Map.insert newSep sepCluster (separatorClusterMap t)
      , leavesSet = Set.delete node (leavesSet t) 
      , parentMap = Map.insert dest newSep (parentMap t)
      , separatorParentMap = Map.insert newSep node (separatorParentMap t)
      , separatorCurrentKey = newSep
      }

-- | Add a new node
addNode :: (Ord c) 
        => c -- ^ Node
        -> Vertex
        -> [a] -- ^ Factor value 
        -> [a] -- ^ Evidence value
        -> JTree c a 
        -> JTree c a 
addNode node vertex factorValue evidenceValue t = 
   t { nodeValueMap = Map.insert node (NodeValue vertex factorValue evidenceValue) (nodeValueMap t)
     , leavesSet = Set.insert node (leavesSet t) 
     }

-- | Update the up message of a separator
updateUpMessage :: Ord c 
                => Maybe Sep -- ^ Separator node to update (if any : none for root node)
                -> a -- ^ New value
                -> JTree c a -- ^ Old tree
                -> JTree c a
updateUpMessage Nothing _ t = t
updateUpMessage (Just sep) newval t = 
    let newSepValue =  case separatorValue t sep of 
                         EmptySeparator -> SeparatorValue newval Nothing
                         SeparatorValue up down -> SeparatorValue newval down 
    in 
    t {separatorValueMap = Map.insert sep newSepValue (separatorValueMap t)}

-- | Update the down message of a separator
updateDownMessage :: Ord c 
                  => Sep -- ^ Separator node to update
                  -> a -- ^ New value
                  -> JTree c a -- ^ Old tree
                  -> JTree c a
updateDownMessage sep newval t = 
    let newSepValue = case separatorValue t sep of 
                        EmptySeparator -> error "Can't set a down message on an empty separator"
                        SeparatorValue up _ -> SeparatorValue up (Just newval)
    in 
    t {separatorValueMap = Map.insert sep newSepValue (separatorValueMap t)}

{-

Message passing algorithms
    
-}

-- | Functions used to generate new messages
class Message f c | f -> c where
    -- | Generate a new message from the received ones
    newMessage :: [f] -> NodeValue f -> c -> f 


-- | Check that a separator is initialized
separatorInitialized :: SeparatorValue a -> Bool
separatorInitialized EmptySeparator = False 
separatorInitialized _ = True

allSeparatorsHaveReceivedAMessage :: Ord c
                                  => JTree c a -- ^ Tree
                                  -> [Sep] -- ^ Separators
                                  -> Bool 
allSeparatorsHaveReceivedAMessage t seps = 
  all separatorInitialized . map (separatorValue t) $ seps

-- | Update the up separator by sending a message
-- But only if all the down separators have received a message
updateUpSeparator :: (Message a c, Ord c) 
                  => JTree c a 
                  -> c -- ^ Node generating the new upMessage
                  -> JTree c a 
updateUpSeparator t h  = 
    let seps = nodeChildren t h
    in
    case allSeparatorsHaveReceivedAMessage t seps of 
      False -> t 
      True -> let incomingMessages = map (upMessage t) seps
                  currentValue = nodeValue t h
                  destinationNode = nodeParent t h
              in 
              case destinationNode of 
                Nothing -> t -- When root
                Just p -> let  sepC = separatorCluster t p
                               generatedMessage = newMessage incomingMessages currentValue sepC
                          in 
                          updateUpMessage destinationNode generatedMessage t

-- | Update the down separator by sending a message
updateDownSeparator :: (Message a c, Ord c) 
                    => c -- ^ Node generating the message 
                    -> JTree c a 
                    -> Sep -- ^ Child receiving the message
                    -> JTree c a 
updateDownSeparator node t child  = 
    let incomingMessagesFromBelow = map (upMessage t) (nodeChildren t node \\ [child])
        messageFromAbove = downMessage t =<< (nodeParent t node)
        incomingMessages = maybe incomingMessagesFromBelow (\x -> x:incomingMessagesFromBelow) messageFromAbove
        currentValue = nodeValue t node
        childC = separatorCluster t child
        generatedMessage = newMessage incomingMessages currentValue childC
    in 
    updateDownMessage child generatedMessage t

unique :: Ord c => [c] -> [c]
{-# INLINE unique #-}
unique = Set.toList . Set.fromList

-- | Collect message taking into account that the tree depth may be different for different leaves.
collect :: (Ord c, Message a c) 
        => JTree c a 
        -> JTree c a
collect t = _collectNodes (leaves t) t

_collectSeparators :: (Ord c, Message a c) 
                   => [Sep]
                   -> JTree c a -- ^ Tree
                   -> JTree c a -- ^ Modified tree
_collectSeparators l t = _collectNodes (unique . map (separatorParent t) $ l) t

_collectNodes :: (Ord c, Message a c) 
              => [c]
              -> JTree c a -- ^ Tree
              -> JTree c a -- ^ Modified tree 
_collectNodes  [] t = t
_collectNodes  l t = 
    let newTree = foldl' updateUpSeparator t l
    in 
    _collectSeparators (mapMaybe (nodeParent t) l) newTree
    
distribute :: (Ord c, Message a c)
           => JTree c a 
           -> JTree c a
distribute t = _distributeNodes t (root t) 

_distributeSeparators :: (Ord c, Message a c)
                      => JTree c a 
                      -> Sep -- ^ Destination of the distribute
                      -> JTree c a 
_distributeSeparators t node = _distributeNodes t (separatorChild t node)

_distributeNodes :: (Ord c, Message a c)
                 => JTree c a 
                 -> c -- ^ Destination of the distribute
                 -> JTree c a 
_distributeNodes t node = 
    let children = nodeChildren t node
        newTree = foldl' (updateDownSeparator node) t $ children
    in
    foldl' _distributeSeparators newTree children

{-

Factors and evidence modifications

-}

-- | This class is used to check if evidence or a factor is relevant
-- for a cluster
class IsCluster c where 
  -- | Evidence contained in the cluster
  overlappingEvidence :: c -> [DVI] -> [DVI]
  -- | Cluser variables
  clusterVariables :: c -> [DV]
  -- | Intersection of two clusters
  mkSeparator :: c -> c -> c

instance IsCluster [DV] where 
  overlappingEvidence c e = filter (\x -> instantiationVariable x `elem` c) e
  clusterVariables = id
  mkSeparator = intersect

data Action s a = Skip !s 
                | ModifyAndStop !s !a
                | Modify !s !a
                | Stop !s

-- | Traverse a tree and modify it
traverseTree :: Ord c 
             => (s -> c -> NodeValue f -> Action s (NodeValue f)) -- ^ Modification function
             -> s -- ^ Current state
             -> JTree c f -- ^ Input tree
             -> (JTree c f,s)
traverseTree action state t = _traverseTreeNodes action (t,state) (root t)

_traverseTreeSeparators action (t,state) current = _traverseTreeNodes  action (t,state)  (separatorChild t current) 

_traverseTreeNodes action (t,state) current = 
  case action state current (nodeValue t current) of 
     Stop newState -> (t,newState)
     ModifyAndStop _ newValue -> (setNodeValue current newValue t, state) 
     Skip newState -> foldl' (_traverseTreeSeparators action) (t,newState) (nodeChildren t current)
     Modify newState newValue -> 
         let newTree = setNodeValue current newValue t 
         in 
         foldl' (_traverseTreeSeparators action) (newTree,newState) (nodeChildren newTree current)

mapWithCluster :: Ord c 
               => (c -> NodeValue f -> NodeValue f)
               -> JTree c f 
               -> JTree c f        
mapWithCluster f t = t {nodeValueMap = Map.mapWithKey f (nodeValueMap t)}

-- | Set the factors in the tree 
updateTreeValues :: (Factor f, IsCluster c, Ord c, Show c, Show f)
                 => (f -> NodeValue f -> NodeValue f) 
                 -> [f]
                 -> JTree c f 
                 -> JTree c f
updateTreeValues change factors t = 
  let allNodes = treeNodes t
      factorIncludedInCluster f c = all (`elem` clusterVariables c) (factorVariables f)
      coveringClusters f = filter (f `factorIncludedInCluster`) allNodes
      clusterSize a = product . map (fromIntegral . dimension) . clusterVariables $ a :: Integer
      addFactor t newFactor = 
        let minimumCluster = minimumBy (compare `on` clusterSize) (coveringClusters newFactor)
            clusterValue = nodeValue t minimumCluster
        in 
        setNodeValue minimumCluster (change newFactor clusterValue) t
  in 
  foldl' addFactor t factors

-- | Set the factors in the tree 
setFactors :: (Graph g, Factor f, IsCluster c, Ord c, Show c, Show f)
           => BayesianNetwork g f 
           -> JTree c f 
           -> JTree c f
setFactors g t = 
  let factors = allVertexValues g 
      changeFactor f (NodeValue v oldf e) = NodeValue v (f:oldf) e
  in 
  updateTreeValues  changeFactor factors t  

-- | Change evidence in the network
changeEvidence :: (IsCluster c, Ord c, Factor f, Message f c, Show c, Show f)
               => [DVI] -- ^ Evidence
               -> JTree c f 
               -> JTree c f 
changeEvidence e t =  
  let evidences = map factorFromInstantiation e
      changeEvidence newe (NodeValue v f olde) = NodeValue v f (newe:olde)
  in 
  distribute . 
  collect .
  updateTreeValues changeEvidence evidences .
  resetEvidences $
  t { separatorValueMap = Map.map (const EmptySeparator) (separatorValueMap t)}                

-- | Cluster of discrete variables.
-- Discrete variables instead of vertices are needed because the
-- factor are using 'DV' and we need to find
-- which factors must be contained in a given cluster.
newtype Cluster = Cluster (Set.Set DV) deriving(Eq,Ord)

instance IsCluster Cluster where 
  overlappingEvidence c = overlappingEvidence (fromCluster c)
  clusterVariables c = clusterVariables (fromCluster c)
  mkSeparator (Cluster a) (Cluster b) = Cluster (Set.intersection a b)

instance Show Cluster where 
  show (Cluster s) = show . Set.toList $ s

fromCluster (Cluster s) = Set.toList s 


instance (Factor f,IsBucketItem f) => Message f Cluster where 
  newMessage input (NodeValue _ f e) c = 
    let allFactors = f ++ e ++ input 
        variablesToKeep = fromCluster c 
        variablesToRemove = (nub (concatMap factorVariables allFactors)) \\ variablesToKeep
    in 
    marginal allFactors variablesToRemove variablesToKeep []


type JunctionTree f = JTree Cluster f

{-

Implement the show function to see the structure of the tree
(without the values)
    
-}

data NodeKind c = N !c | S !c

label True c a = c ++ "=" ++ show a 
label False c _ = c

-- | Convert the JTree into a tree of string
-- using the cluster.
toTree :: (Ord c, Show c, Show a) 
       => Bool -- ^ True if the data must be displayed
       -> JTree c a 
       -> Tree.Tree String
toTree d t = 
    let r = root t
        v = nodeValue t r
        nodec = nodeChildren t r
    in 
    Tree.Node (label d (show r) v) (_toTreeSeparators d t nodec)


_toTreeNodes :: (Ord c, Show c, Show a) 
             => Bool
             -> JTree c a 
             -> [c] 
             -> [Tree.Tree String]
_toTreeNodes _ _ [] = []
_toTreeNodes d t (h:l) = 
    let nodec = nodeChildren t h -- Node children are separators
        v = nodeValue t h
    in
    Tree.Node (label d (show h) v) (_toTreeSeparators d t nodec):_toTreeNodes d t l

_toTreeSeparators :: (Ord c, Show c, Show a) 
                  => Bool
                  -> JTree c a 
                  -> [Sep] 
                  -> [Tree.Tree String]
_toTreeSeparators _ _ [] = []    
_toTreeSeparators d t (h:l) = 
    let separatorc = [separatorChild t h] -- separator child is a node
        v = separatorValue t h
    in
    Tree.Node (label d ("<" ++ show (separatorCluster t h) ++ ">") v ) (_toTreeNodes d t separatorc):_toTreeSeparators d t l

instance (Ord c, Show c, Show a) => Show (JTree c a) where 
    show = Tree.drawTree . toTree False

displayTree b = Tree.drawTree . toTree b

-- | Display the tree values
displayTreeValues :: (Show f, Show c) => JTree c f -> IO ()
displayTreeValues t = 
  let allValues = treeValues t
      printAValue (c,NodeValue _ f e) = do 
        print c 
        putStrLn "FACTOR"
        print f 
        putStrLn "EVIDENCE"
        print e 
        putStrLn "------"

  in 
  mapM_ printAValue allValues