module Bayes.FactorElimination.JTree(
      IsCluster(..)
    , Cluster(..)
    , JTree(..)
    , JunctionTree(..)
    , Sep
    , setFactors
    , distribute 
    , collect
    , fromCluster
    , changeEvidence
    , nodeIsMemberOfTree
    , singletonTree
    , addNode 
    , addSeparator
    , leaves
    , nodeValue
    , NodeValue(..)
    , SeparatorValue(..)
    , downMessage
    , upMessage 
    , nodeParent 
    , nodeChildren
    , traverseTree
    , separatorChild
    , treeNodes
    , treeValues
    , displayTreeValues
    , Action(..)
    ) where 
import qualified Data.Map as Map
import qualified Data.Tree as Tree
import Data.Maybe(fromJust,mapMaybe)
import qualified Data.Set as Set
import Data.Monoid
import Data.List((\\), intersect,partition, foldl',minimumBy,nub)
import Bayes.PrivateTypes 
import Bayes.Factor
import Bayes
import Data.Function(on)
import Bayes.VariableElimination(marginal)
import Data.Binary
import Bayes.VariableElimination.Buckets(IsBucketItem(..))
-- | Message sent from a child cluster towards its parent during 'collect'.
type UpMessage a = a 
-- | Message sent from a parent cluster towards a child during 'distribute';
-- 'Nothing' until a down message has been stored on the separator.
type DownMessage a = Maybe a
-- | Content of a separator of the junction tree.
--
-- 'EmptySeparator' means no message has been stored yet: 'addSeparator'
-- creates separators empty, and 'changeEvidence' / 'changeFactor' reset every
-- separator to it before recomputing messages. Otherwise the separator holds
-- the up message and, once one has been stored, the down message.
data SeparatorValue a = SeparatorValue !(UpMessage a) !(DownMessage a)
                      | EmptySeparator 
                      deriving(Eq)
-- | Renders an empty separator as the empty string, otherwise as
-- @u(up)@ optionally followed by @ d(down)@.
instance Show a => Show (SeparatorValue a) where
    show sv = case sv of
        EmptySeparator -> ""
        SeparatorValue up mdown ->
            "u(" ++ show up ++ ")" ++ maybe "" (\down -> " d(" ++ show down ++ ")") mdown
-- | Factors attached to a cluster of the tree.
type FactorValues a = [a]
-- | Evidence factors attached to a cluster of the tree.
type EvidenceValues a = [a]
-- | Value stored at a cluster: the associated 'Vertex', the factors assigned
-- to the cluster, and the evidence factors assigned to it.
data NodeValue a = NodeValue !Vertex !(FactorValues a) !(EvidenceValues a) deriving(Eq)
-- | Renders the factors and the evidence of a cluster value; the vertex is
-- not shown.
instance Show a => Show (NodeValue a) where
    show (NodeValue _ factors evidence) =
        concat ["f(", show factors, ") e(", show evidence, ")"]
-- | Key identifying a separator. 'addSeparator' allocates keys by adding 1 to
-- 'separatorCurrentKey', which relies on the derived 'Num' instance.
newtype Sep = Sep Int deriving(Eq,Ord,Show,Num,Binary)
-- | Rooted junction tree whose clusters have type @c@ and whose factors have
-- type @f@. Clusters are linked through separators; each separator has one
-- parent cluster and one child cluster.
data JTree  c f = JTree {  root :: !c
                        -- ^ Root cluster; 'distribute' and 'traverseTree' start here.
                        ,  leavesSet :: !(Set.Set c)
                        -- ^ Clusters without children; 'collect' starts from them.
                        ,  childrenMap :: !(Map.Map c [Sep])
                        -- ^ Separators leading from a cluster to its children.
                        ,  parentMap :: !(Map.Map c Sep)
                        -- ^ Separator leading from a cluster to its parent
                        -- (clusters without a parent, such as the root, have no entry).
                        ,  separatorParentMap :: !(Map.Map Sep c)
                        -- ^ Parent cluster of each separator.
                        ,  separatorChildMap :: !(Map.Map Sep c)
                        -- ^ Child cluster of each separator.
                        ,  nodeValueMap :: !(Map.Map c (NodeValue f))
                        -- ^ Value (vertex, factors, evidence) stored at each cluster.
                        ,  separatorValueMap :: !(Map.Map Sep (SeparatorValue f))
                        -- ^ Messages stored on each separator.
                        ,  separatorCurrentKey :: !Sep
                        -- ^ Last separator key allocated by 'addSeparator'.
                        ,  separatorClusterMap :: !(Map.Map Sep c)
                        -- ^ Cluster carried by each separator; it is passed to
                        -- 'newMessage' when computing messages over that separator.
                        } deriving(Eq)
-- | Changing the factors of a junction tree rewrites the factor list of every
-- cluster (evidence is kept). It then empties all separators and recomputes
-- the messages with a 'collect' pass followed by a 'distribute' pass.
instance FactorContainer (JTree Cluster) where
  changeFactor f t = distribute (collect cleared)
    where
      rewriteNode (NodeValue vertex factors evidence) =
        NodeValue vertex (changeFactor f factors) evidence
      cleared = t { nodeValueMap = Map.map rewriteNode (nodeValueMap t)
                  , separatorValueMap = Map.map (const EmptySeparator) (separatorValueMap t)
                  }
-- | Build a tree made of the single cluster @r@, which is both its root and
-- its only leaf, holding the given vertex, factor values and evidence values.
singletonTree :: Ord c => c -> Vertex -> [a] -> [a] -> JTree c a
singletonTree r rootVertex factorValue evidenceValue =
    addNode r rootVertex factorValue evidenceValue emptyTree
  where
    -- Record syntax instead of a 10-argument positional constructor, so a
    -- field reordering in 'JTree' cannot silently swap two maps. The
    -- separator counter starts at 0, so 'addSeparator' first allocates Sep 1.
    emptyTree = JTree { root = r
                      , leavesSet = Set.empty
                      , childrenMap = Map.empty
                      , parentMap = Map.empty
                      , separatorParentMap = Map.empty
                      , separatorChildMap = Map.empty
                      , nodeValueMap = Map.empty
                      , separatorValueMap = Map.empty
                      , separatorCurrentKey = Sep 0
                      , separatorClusterMap = Map.empty
                      }
-- | Remove the evidence stored at every cluster, keeping vertices and factors.
resetEvidences :: Factor f => JTree c f -> JTree c f 
resetEvidences t = t { nodeValueMap = fmap clearEvidence (nodeValueMap t) }
  where
    clearEvidence (NodeValue vertex factors _) = NodeValue vertex factors []
-- | Cluster carried by a separator (passed to 'newMessage' when computing
-- messages over it).
--
-- Calls 'error', naming the separator, if it is not part of the tree.
separatorCluster :: JTree c a -> Sep -> c 
separatorCluster t s =
  Map.findWithDefault
    (error ("Bayes.FactorElimination.JTree.separatorCluster: unknown separator " ++ show s))
    s
    (separatorClusterMap t)
-- | Clusters without children, in ascending order.
leaves :: JTree c a -> [c]
leaves t = Set.toAscList (leavesSet t)
-- | All clusters of the tree, in ascending key order.
treeNodes :: JTree c a -> [c]
treeNodes t = Map.keys (nodeValueMap t)
-- | Every cluster paired with its value, in ascending key order.
treeValues :: JTree c f -> [(c,NodeValue f)]
treeValues t = Map.toAscList (nodeValueMap t)
-- | Value stored at a cluster of the tree.
--
-- Calls 'error' with a descriptive message if the cluster is not a node of
-- the tree ('nodeIsMemberOfTree' can be used to check beforehand).
nodeValue :: Ord c => JTree c a -> c -> NodeValue a 
nodeValue t e =
  Map.findWithDefault
    (error "Bayes.FactorElimination.JTree.nodeValue: cluster is not a node of the tree")
    e
    (nodeValueMap t)
-- | Store a value at a cluster, replacing any previous value.
setNodeValue :: Ord c => c -> NodeValue a -> JTree c a -> JTree c a
setNodeValue cluster value tree = tree { nodeValueMap = updated }
  where
    updated = Map.insert cluster value (nodeValueMap tree)
-- | Separator linking a cluster to its parent, or 'Nothing' when the cluster
-- has no parent (as for the root).
nodeParent :: Ord c => JTree c a -> c -> Maybe Sep 
-- The former ``let r = ... in r `seq` r`` was a no-op: forcing a value just
-- before returning it adds no strictness.
nodeParent t e = Map.lookup e (parentMap t)
-- | Messages currently stored on a separator.
--
-- Calls 'error', naming the separator, if it is not part of the tree.
separatorValue :: Ord c => JTree c a -> Sep -> SeparatorValue a 
separatorValue t e =
  Map.findWithDefault
    (error ("Bayes.FactorElimination.JTree.separatorValue: unknown separator " ++ show e))
    e
    (separatorValueMap t)
-- | Parent cluster of a separator.
--
-- Calls 'error', naming the separator, if it is not part of the tree.
separatorParent :: Ord c => JTree c a -> Sep -> c 
separatorParent t e =
  Map.findWithDefault
    (error ("Bayes.FactorElimination.JTree.separatorParent: unknown separator " ++ show e))
    e
    (separatorParentMap t)
-- | Up message stored on a separator; calls 'error' if the separator is still
-- empty.
upMessage :: Ord c => JTree c a -> Sep -> a
upMessage t c =
  case separatorValue t c of
    EmptySeparator -> error "Trying to get an up message on an empty seperator ! Should never occur !"
    SeparatorValue up _ -> up
-- | Down message stored on a separator, if any; calls 'error' if the
-- separator is still empty.
downMessage :: Ord c => JTree c a -> Sep -> Maybe a 
downMessage t c =
  case separatorValue t c of
    SeparatorValue _ down -> down
    EmptySeparator -> error "Trying to get a down message on an empty separator ! Should never occur !"
-- | Separators leading to the children of a cluster (empty for a leaf or an
-- unknown cluster).
nodeChildren :: Ord c => JTree c a -> c -> [Sep]
nodeChildren t e = Map.findWithDefault [] e (childrenMap t)
-- | Child cluster of a separator.
--
-- Calls 'error', naming the separator, if it is not part of the tree.
separatorChild :: Ord c => JTree c a -> Sep -> c 
separatorChild t e =
  Map.findWithDefault
    (error ("Bayes.FactorElimination.JTree.separatorChild: unknown separator " ++ show e))
    e
    (separatorChildMap t)
-- | Whether a cluster is a node of the tree.
nodeIsMemberOfTree :: Ord c => c -> JTree c a -> Bool 
nodeIsMemberOfTree c = Map.member c . nodeValueMap
-- | Connect two clusters with a new, empty separator.
--
-- A fresh 'Sep' key is obtained by incrementing 'separatorCurrentKey'. The
-- separator is prepended to the children of @node@, @dest@ becomes its child
-- cluster and gets it as parent separator, @sepCluster@ is stored as the
-- separator's cluster, and @node@ is removed from the leaves.
addSeparator :: (Ord c) 
             => c           -- ^ Parent cluster (origin of the separator)
             -> c           -- ^ Cluster carried by the separator
             -> c           -- ^ Child cluster (destination of the separator)
             -> JTree c a 
             -> JTree c a 
addSeparator node sepCluster dest t = 
    let newSep = separatorCurrentKey t + 1 
    in
    -- 'Map.insertWith'' no longer exists in containers >= 0.6; 'Map.insertWith'
    -- calls the combining function as (++) [newSep] old, so the new separator
    -- is still prepended to the existing list.
    t { childrenMap = Map.insertWith (++) node [newSep] (childrenMap t)
      , separatorChildMap = Map.insert newSep dest (separatorChildMap t)
      , separatorValueMap = Map.insert newSep EmptySeparator (separatorValueMap t)
      , separatorClusterMap = Map.insert newSep sepCluster (separatorClusterMap t)
      , leavesSet = Set.delete node (leavesSet t) 
      , parentMap = Map.insert dest newSep (parentMap t)
      , separatorParentMap = Map.insert newSep node (separatorParentMap t)
      , separatorCurrentKey = newSep
      }
-- | Insert (or replace) a cluster with its vertex, factor values and evidence
-- values. The cluster is recorded as a leaf; 'addSeparator' later removes it
-- from the leaves when it gains a child.
addNode :: (Ord c) 
        => c 
        -> Vertex
        -> [a] 
        -> [a] 
        -> JTree c a 
        -> JTree c a 
addNode node vertex factorValue evidenceValue t =
  t { nodeValueMap = Map.insert node value (nodeValueMap t)
    , leavesSet = Set.insert node (leavesSet t)
    }
  where
    value = NodeValue vertex factorValue evidenceValue
-- | Store a new up message on a separator, keeping its down message if it
-- had one. 'Nothing' as the separator leaves the tree unchanged.
updateUpMessage :: Ord c 
                => Maybe Sep 
                -> a 
                -> JTree c a 
                -> JTree c a
updateUpMessage Nothing _ t = t
updateUpMessage (Just sep) newval t =
    t { separatorValueMap = Map.insert sep (SeparatorValue newval previousDown) (separatorValueMap t) }
  where
    previousDown = case separatorValue t sep of
                     EmptySeparator -> Nothing
                     SeparatorValue _ down -> down
-- | Store a down message on a separator, keeping its up message. Calls
-- 'error' when the separator has no up message yet.
updateDownMessage :: Ord c 
                  => Sep 
                  -> a 
                  -> JTree c a 
                  -> JTree c a
updateDownMessage sep newval t =
    t { separatorValueMap = Map.insert sep withDown (separatorValueMap t) }
  where
    withDown = case separatorValue t sep of
                 SeparatorValue up _ -> SeparatorValue up (Just newval)
                 EmptySeparator -> error "Can't set a down message on an empty separator"
-- | How the message sent across a separator is computed.
class Message f c | f -> c where
    -- | @newMessage incoming value sep@ builds the message sent over a
    -- separator whose cluster is @sep@. The inputs are the messages the sending
    -- cluster has received (@incoming@) and that cluster's 'NodeValue'.
    newMessage :: [f] -> NodeValue f -> c -> f 
-- | Whether a separator holds at least an up message.
separatorInitialized :: SeparatorValue a -> Bool
separatorInitialized sv =
  case sv of
    SeparatorValue _ _ -> True
    EmptySeparator -> False
-- | Whether every separator in the list already holds an up message
-- (vacuously 'True' for an empty list).
allSeparatorsHaveReceivedAMessage :: Ord c
                                  => JTree c a 
                                  -> [Sep] 
                                  -> Bool 
allSeparatorsHaveReceivedAMessage t seps = all initialized seps
  where
    initialized = separatorInitialized . separatorValue t
-- | Compute the up message of cluster @h@ towards its parent, but only once
-- all the separators below @h@ hold an up message. The tree is returned
-- unchanged when a child message is still missing or when @h@ has no parent.
updateUpSeparator :: (Message a c, Ord c) 
                  => JTree c a 
                  -> c 
                  -> JTree c a 
updateUpSeparator t h
  | not (allSeparatorsHaveReceivedAMessage t seps) = t
  | otherwise = maybe t sendUp (nodeParent t h)
  where
    seps = nodeChildren t h
    sendUp p =
      let message = newMessage (map (upMessage t) seps) (nodeValue t h) (separatorCluster t p)
      in updateUpMessage (Just p) message t
-- | Compute the down message sent from @node@ over its child separator
-- @child@. It combines the up messages of the other children of @node@, the
-- down message from @node@'s parent (when there is one), and @node@'s value.
updateDownSeparator :: (Message a c, Ord c) 
                    => c 
                    -> JTree c a 
                    -> Sep 
                    -> JTree c a 
updateDownSeparator node t child = updateDownMessage child message t
  where
    fromBelow = [upMessage t s | s <- nodeChildren t node \\ [child]]
    fromAbove = nodeParent t node >>= downMessage t
    message = newMessage (maybe id (:) fromAbove fromBelow) (nodeValue t node) (separatorCluster t child)
-- | Remove duplicates; the result is sorted in ascending order.
unique :: Ord c => [c] -> [c]
unique xs = Set.toAscList (Set.fromList xs)
-- | Upward pass: starting from the leaves, compute the up message of every
-- separator while moving level by level towards the root.
collect :: (Ord c, Message a c) 
        => JTree c a 
        -> JTree c a
collect t = _collectNodes startNodes t
  where
    startNodes = leaves t
-- | Continue the upward pass with the (deduplicated) parent clusters of the
-- given separators.
_collectSeparators :: (Ord c, Message a c) 
                   => [Sep]
                   -> JTree c a 
                   -> JTree c a 
_collectSeparators seps t = _collectNodes parents t
  where
    parents = unique [separatorParent t s | s <- seps]
-- | One step of the upward pass: try to send an up message from each cluster
-- in the list, then move on to their parent separators. Stops when no cluster
-- is left (the root has no parent separator).
_collectNodes :: (Ord c, Message a c) 
              => [c]
              -> JTree c a 
              -> JTree c a 
_collectNodes [] t = t
_collectNodes nodes t =
    _collectSeparators (mapMaybe (nodeParent t) nodes) (foldl' updateUpSeparator t nodes)
-- | Downward pass: starting from the root, compute the down message of every
-- separator. Up messages must already be stored on the separators (e.g. by
-- 'collect'), otherwise 'upMessage' / 'updateDownMessage' call 'error'.
distribute :: (Ord c, Message a c)
           => JTree c a 
           -> JTree c a
distribute t = _distributeNodes t $ root t
-- | Continue the downward pass in the child cluster of a separator.
_distributeSeparators :: (Ord c, Message a c)
                      => JTree c a 
                      -> Sep 
                      -> JTree c a 
_distributeSeparators t sep = _distributeNodes t $ separatorChild t sep
-- | Send down messages from a cluster to all its child separators, then
-- recurse into each child cluster.
_distributeNodes :: (Ord c, Message a c)
                 => JTree c a 
                 -> c 
                 -> JTree c a 
_distributeNodes t node = foldl' _distributeSeparators withDownMessages children
  where
    children = nodeChildren t node
    withDownMessages = foldl' (updateDownSeparator node) t children
-- | Operations a cluster type must provide to be used in a 'JTree'.
class IsCluster c where 
  -- | Keep the instantiations relevant to the cluster (the @[DV]@ instance
  -- keeps those whose variable belongs to the cluster).
  overlappingEvidence :: c -> [DVI] -> [DVI]
  -- | Variables of the cluster.
  clusterVariables :: c -> [DV]
  -- | Separator built from two clusters (both instances in this module use
  -- the intersection of the two clusters).
  mkSeparator :: c -> c -> c
-- | A plain list of variables used directly as a cluster.
instance IsCluster [DV] where
  overlappingEvidence c = filter ((`elem` c) . instantiationVariable)
  clusterVariables c = c
  mkSeparator a b = a `intersect` b
-- | Answer of the callback given to 'traverseTree' for a visited cluster.
-- Every constructor carries a new state; two of them also carry a new value.
data Action s a = Skip !s 
                -- ^ Keep the cluster's value and visit its children.
                | ModifyAndStop !s !a
                -- ^ Store the new value and do not visit the children.
                | Modify !s !a
                -- ^ Store the new value and visit the children.
                | Stop !s
                -- ^ Keep the value and do not visit the children.
-- | Pre-order, depth-first traversal of the tree from its root, threading a
-- user state through the visit.
--
-- For every visited cluster the callback receives the current state, the
-- cluster and its value, and answers with an 'Action':
--
-- * @'Skip' s@: keep the value, visit the children with state @s@.
-- * @'Modify' s v@: store @v@, visit the children with state @s@.
-- * @'Stop' s@: keep the value, skip the children, continue with @s@.
-- * @'ModifyAndStop' s v@: store @v@, skip the children, continue with @s@.
--
-- Stopping only prunes the subtree below the current cluster; its siblings
-- are still visited, starting from the state it returned. The result is the
-- updated tree and the final state.
traverseTree :: Ord c 
             => (s -> c -> NodeValue f -> Action s (NodeValue f)) 
             -> s 
             -> JTree c f 
             -> (JTree c f,s)
traverseTree action state t = _traverseTreeNodes action (t,state) (root t)

-- | Continue the traversal in the child cluster of a separator.
_traverseTreeSeparators :: Ord c
                        => (s -> c -> NodeValue f -> Action s (NodeValue f))
                        -> (JTree c f, s)
                        -> Sep
                        -> (JTree c f, s)
_traverseTreeSeparators action (t,state) current = _traverseTreeNodes action (t,state) (separatorChild t current)

-- | Apply the callback to one cluster, then recurse into its children unless
-- the callback asked to stop.
_traverseTreeNodes :: Ord c
                   => (s -> c -> NodeValue f -> Action s (NodeValue f))
                   -> (JTree c f, s)
                   -> c
                   -> (JTree c f, s)
_traverseTreeNodes action (t,state) current = 
  case action state current (nodeValue t current) of
     Stop newState -> (t,newState)
     -- Like 'Stop', hand back the state carried by the action rather than
     -- silently reverting to the incoming one.
     ModifyAndStop newState newValue -> (setNodeValue current newValue t, newState)
     Skip newState -> foldl' (_traverseTreeSeparators action) (t,newState) (nodeChildren t current)
     Modify newState newValue -> 
         let newTree = setNodeValue current newValue t 
         in 
         foldl' (_traverseTreeSeparators action) (newTree,newState) (nodeChildren newTree current)
-- | Apply a function to the value of every cluster; the function also
-- receives the cluster itself.
mapWithCluster :: Ord c 
               => (c -> NodeValue f -> NodeValue f)
               -> JTree c f 
               -> JTree c f        
mapWithCluster f t = t { nodeValueMap = updated }
  where
    updated = Map.mapWithKey f (nodeValueMap t)
-- | Attach each factor of @factors@ to one cluster of the tree.
--
-- A factor goes to the smallest cluster, measured as the product of its
-- variable dimensions, among the clusters containing all of the factor's
-- variables. Ties are resolved by 'minimumBy' over clusters in 'treeNodes'
-- order. That cluster's value is replaced by @change factor oldValue@.
--
-- Calls 'error', naming the factor, when no cluster contains all of its
-- variables (previously this surfaced as an opaque 'minimumBy' failure).
updateTreeValues :: (Factor f, IsCluster c, Ord c, Show c, Show f)
                 => (f -> NodeValue f -> NodeValue f) 
                 -> [f]
                 -> JTree c f 
                 -> JTree c f
updateTreeValues change factors t = foldl' addFactor t factors
  where
    allNodes = treeNodes t
    factorIncludedInCluster f c = all (`elem` clusterVariables c) (factorVariables f)
    coveringClusters f = filter (f `factorIncludedInCluster`) allNodes
    clusterSize a = product . map (fromIntegral . dimension) . clusterVariables $ a :: Integer
    -- Attach one factor; 'tree' is the fold accumulator (named differently
    -- from the outer 't' to avoid shadowing).
    addFactor tree newFactor =
      case coveringClusters newFactor of
        [] -> error ("Bayes.FactorElimination.JTree.updateTreeValues: no cluster contains all the variables of the factor " ++ show newFactor)
        candidates ->
          let minimumCluster = minimumBy (compare `on` clusterSize) candidates
          in setNodeValue minimumCluster (change newFactor (nodeValue tree minimumCluster)) tree
-- | Distribute the factors of a Bayesian network over the clusters of the
-- tree. Each vertex value of the network is prepended to the factor list of
-- the cluster chosen by 'updateTreeValues'.
setFactors :: (Graph g, Factor f, IsCluster c, Ord c, Show c, Show f)
           => BayesianNetwork g f 
           -> JTree c f 
           -> JTree c f
setFactors g t = updateTreeValues prependFactor (allVertexValues g) t
  where
    -- Deliberately not called 'changeFactor', which would shadow the
    -- 'FactorContainer' method of the same name.
    prependFactor f (NodeValue v oldf e) = NodeValue v (f:oldf) e
-- | Replace the evidence of the tree and propagate it.
--
-- All separators are emptied and the previous evidence is removed from every
-- cluster. Each instantiation is turned into a factor with
-- 'factorFromInstantiation' and attached to a cluster by 'updateTreeValues'.
-- Finally a 'collect' pass followed by a 'distribute' pass recomputes all
-- messages.
changeEvidence :: (IsCluster c, Ord c, Factor f, Message f c, Show c, Show f)
               => [DVI] 
               -> JTree c f 
               -> JTree c f 
changeEvidence e t =  
  let evidences = map factorFromInstantiation e
      -- Not named 'changeEvidence': that would shadow this very function.
      addEvidence newe (NodeValue v f olde) = NodeValue v f (newe:olde)
  in 
  distribute . 
  collect .
  updateTreeValues addEvidence evidences .
  resetEvidences $
  t { separatorValueMap = Map.map (const EmptySeparator) (separatorValueMap t)}
-- | A cluster of the junction tree: a set of discrete variables.
newtype Cluster = Cluster (Set.Set DV) deriving(Eq,Ord)
-- | A 'Cluster' reuses the @[DV]@ instance on its sorted variable list, except
-- for 'mkSeparator', which intersects the underlying sets directly.
instance IsCluster Cluster where
  overlappingEvidence = overlappingEvidence . fromCluster
  clusterVariables = clusterVariables . fromCluster
  mkSeparator (Cluster a) (Cluster b) = Cluster (a `Set.intersection` b)
-- | Shows a cluster as the list of its variables in ascending order.
instance Show Cluster where
  show (Cluster s) = show (Set.toAscList s)
-- | Variables of a 'Cluster', in ascending order.
fromCluster :: Cluster -> [DV]
fromCluster (Cluster s) = Set.toList s 
-- | The message sent over a separator marginalizes all factors known to the
-- sending cluster onto the separator's variables. Those factors are its own
-- factors, its evidence and the incoming messages. Every other variable they
-- mention is eliminated.
instance (Factor f,IsBucketItem f) => Message f Cluster where
  newMessage input (NodeValue _ factors evidence) c =
      marginal allFactors variablesToRemove variablesToKeep []
    where
      allFactors = concat [factors, evidence, input]
      variablesToKeep = fromCluster c
      variablesToRemove = nub (concatMap factorVariables allFactors) \\ variablesToKeep
-- | Junction tree whose clusters are 'Cluster' values and whose factors have
-- type @f@.
type JunctionTree f = JTree Cluster f
-- | Presumably tags a value as coming from a node (@N@) or a separator (@S@).
-- NOTE(review): not referenced anywhere in the visible part of this module;
-- confirm whether it is still needed.
data NodeKind c = N !c | S !c
-- | Label of an entry in the drawn tree: the name alone, or @name=value@ when
-- the flag is 'True'.
label :: Show a => Bool -> String -> a -> String
label True c a = c ++ "=" ++ show a 
label False c _ = c
-- | Convert the junction tree into a 'Tree.Tree' of labels rooted at its root
-- cluster, with clusters and separators alternating. When the flag is 'True',
-- labels include the stored values.
toTree :: (Ord c, Show c, Show a) 
       => Bool 
       -> JTree c a 
       -> Tree.Tree String
toTree d t = Tree.Node rootLabel (_toTreeSeparators d t (nodeChildren t r))
  where
    r = root t
    rootLabel = label d (show r) (nodeValue t r)
-- | Render each cluster of the list together with the subtree below it.
_toTreeNodes :: (Ord c, Show c, Show a) 
             => Bool
             -> JTree c a 
             -> [c] 
             -> [Tree.Tree String]
_toTreeNodes d t = map toNode
  where
    toNode h = Tree.Node (label d (show h) (nodeValue t h)) (_toTreeSeparators d t (nodeChildren t h))
-- | Render each separator of the list, labelled @<cluster>@, with its child
-- cluster below it.
_toTreeSeparators :: (Ord c, Show c, Show a) 
                  => Bool
                  -> JTree c a 
                  -> [Sep] 
                  -> [Tree.Tree String]
_toTreeSeparators d t = map toSepNode
  where
    toSepNode h =
      let sepLabel = label d ("<" ++ show (separatorCluster t h) ++ ">") (separatorValue t h)
      in Tree.Node sepLabel (_toTreeNodes d t [separatorChild t h])
-- | Draws the tree structure (clusters and separators) without values.
instance (Ord c, Show c, Show a) => Show (JTree c a) where
    show t = Tree.drawTree (toTree False t)
-- | Draw the tree as text; when the flag is 'True', cluster and separator
-- labels also show their stored values.
displayTree :: (Ord c, Show c, Show a) => Bool -> JTree c a -> String
displayTree b = Tree.drawTree . toTree b
-- | Print every cluster with its factors and evidence, one separated section
-- per cluster, in ascending cluster order.
displayTreeValues :: (Show f, Show c) => JTree c f -> IO ()
displayTreeValues t = mapM_ printAValue (treeValues t)
  where
    printAValue (c, NodeValue _ f e) =
      mapM_ putStrLn [show c, "FACTOR", show f, "EVIDENCE", show e, "------"]