module Bayes.InfluenceDiagram(
  
    InfluenceDiagram
  , DecisionFactor
  , Instantiable(..)
  , DEV 
  , UV
  , DV
  , TDV
  
  , t 
  , (~~)
  , chance 
  , decisionNode 
  , utilityNode 
  , proba 
  , decision 
  , utility
  , cpt
  , d 
  , p
  , noDependencies
  
  , decisionsOrder
  , solveInfluenceDiagram
  , runID
  , policyNetwork
  , decisionToInstantiation
  
  , DVISet 
  , DVI
	) where 
import Bayes
import Bayes.PrivateTypes
import Bayes.Network 
import Data.Monoid
import Bayes.Factor
import Bayes.Factor.PrivateCPT
import Bayes.Factor.CPT
import Bayes.Factor.MaxCPT
import Data.Maybe(fromJust,mapMaybe)
import Control.Applicative((<$>))
import Bayes.VariableElimination.Buckets
import Bayes.Factor.PrivateCPT(DecisionFactor(..),decisionFactor,convertToMaxFactor,convertToNormalFactor,privateFactorValue,_factorVariables)
import Data.List(foldl1',foldl')
import Control.Monad.State.Strict(gets)
import qualified Data.Vector as NV
import qualified Data.Map as M 
import qualified Data.IntMap as IM 
-- | Store a policy CPT as the value of its decision node and rewire the node's
-- incoming edges. All edges currently entering the decision vertex are
-- removed. One 'NormalLink' edge is then added from each conditioning
-- variable of the CPT (every variable except its main one).
-- Crashes (via 'fromJust') if the decision vertex is not in the diagram.
replaceDecisionNodeWithPolicy :: InfluenceDiagram -> CPT -> InfluenceDiagram
replaceDecisionNodeWithPolicy g f = foldr linkParent withoutParentEdges conditioning
  where
    decisionVar = factorMainVariable f
    decisionVertex = vertex decisionVar
    conditioning = tail (factorVariables f)
    -- The decision node now carries the policy table.
    relabelled = fromJust (changeVertexValue decisionVertex (DecisionNode f) g)
    -- Forget the previous parents of the decision node.
    withoutParentEdges = foldr removeEdge relabelled (fromJust (ingoing relabelled decisionVertex))
    linkParent parentVar = addEdge (edge (vertex parentVar) decisionVertex) NormalLink
-- | Convert a solved decision into a deterministic policy CPT.
--
-- The CPT ranges over the decision variable followed by the decision's
-- parents. An entry is @1.0@ when the decision variable takes the value
-- recorded in the 'DecisionFactor' for that parent instantiation, and @0.0@
-- otherwise. Both constructors share one construction, so the two cases
-- cannot drift apart.
policyFactor :: DecisionFactor -> CPT
policyFactor df =
  fromJust . factorWithVariables vars $
    [ if instantiationValue (chosenFor inst) == instantiationValue (head inst)
        then 1.0
        else 0.0
    | inst <- forAllInstantiations (DVSet vars)
    ]
  where
    -- vars: decision variable first, then its parents.
    -- chosenFor inst: the recorded choice for the full instantiation inst,
    -- whose head is the decision variable's own instantiation.
    (vars, chosenFor) = case df of
      -- No parents: one recorded choice, whatever the instantiation.
      Scalar chosen -> ([dv chosen], const chosen)
      -- The decision variable is recovered from the first stored choice; the
      -- table is looked up with the parent part of the instantiation.
      Table parents _ choices ->
        (dv (NV.head choices) : parents, privateFactorValue df . tail)
  
-- | Enumerate a solved decision as instantiations.
--
-- Each entry is the recorded choice for the decision variable followed by the
-- parent instantiation it was chosen for. A decision without parents yields
-- just its single recorded choice.
decisionToInstantiation :: DecisionFactor -> [DVISet]
decisionToInstantiation df = case df of
  Scalar chosen -> [[chosen]]
  Table parents _ _ ->
    [ privateFactorValue df parentInst : parentInst
    | parentInst <- forAllInstantiations (DVSet parents)
    ]
-- | Build the Bayesian network obtained by fixing every decision to a policy.
--
-- Each decision node is replaced by the deterministic CPT from
-- 'policyFactor', and its incoming edges are rewired to match. All utility
-- nodes are then deleted. The remaining chance and decision CPTs become the
-- network's nodes, and edge labels are dropped.
policyNetwork :: [DecisionFactor] -> InfluenceDiagram -> SBN CPT
policyNetwork decisions idg = SP (M.map (const ()) edgeMap) (IM.map asBayesNode vertexMap) rest
  where
    withPolicies = foldl' replaceDecisionNodeWithPolicy idg (map policyFactor decisions)
    utilityVertices = filter (isUtilityNode idg) (allVertices idg)
    SP edgeMap vertexMap rest = foldr removeVertex withPolicies utilityVertices
    asBayesNode (label, value) = case value of
      ChanceNode f -> (label, f)
      DecisionNode f -> (label, f)
      UtilityNode _ _ -> error "No utilities nodes should remain to create the policy network"
-- | A parameterless decision shows its single choice. A table with no parent
-- variables is reported as empty. Any other table is rendered with
-- 'displayFactorBody'.
instance Show DecisionFactor where
    show df = case df of
      Scalar chosen -> "\nScalar Factor:\n" ++ show chosen
      Table [] _ _ -> "\nEmpty DecisionFactor:\n"
      _ -> displayFactorBody df
-- | Table view of a 'DecisionFactor': each cell shows the recorded choice for
-- the given instantiation.
instance MultiDimTable DecisionFactor where
    tableVariables = _factorVariables
    elementStringValue df = show . privateFactorValue df
-- | The pair of potentials handled together during elimination. The first
-- component is the probability part and the second is the utility part.
data JoinSum = JS !CPT !CPT
             deriving (Eq)

instance Show JoinSum where
  show (JS prob util) = concat ["CPT\n", show prob, "\nUTILITY\n", show util, "\n"]
-- | Wrap a chance node's CPT. It contributes its probabilities and a zero
-- utility, which is neutral for the utility sum in 'jsProduct'.
chanceFactor :: CPT -> JoinSum
chanceFactor f = JS f (factorFromScalar 0.0)

-- | Wrap a utility node's table. It contributes a unit probability, which is
-- neutral for the probability product in 'jsProduct', and its utilities.
utilityFactor :: CPT -> JoinSum
utilityFactor f = JS (factorFromScalar 1.0) f
-- | Combine two pairs: the probability parts are multiplied and the utility
-- parts are added.
jsProduct :: JoinSum -> JoinSum -> JoinSum
jsProduct (JS probA utilA) (JS probB utilB) =
  let prob = itemProduct [probA, probB]
      util = cptSum [utilA, utilB]
  in JS prob util
-- | Maximise a decision variable out of a bucket.
--
-- The bucket's pairs are first combined into a single pair (p, u). The new
-- probability part is p max-projected over the decision variable. The new
-- utility part is max(p*u) / max(p), with both maxima taken over the decision
-- variable. The returned 'DecisionFactor' is built by 'decisionFactor' from
-- the max-projected p*u potential.
-- NOTE(review): 'itemProduct' on an empty bucket uses foldl1' and will crash.
maximalize :: DV -> [JoinSum] -> (JoinSum,DecisionFactor)
maximalize decisionVar bucketItems =
  (JS maxProb (cptDivide maxExpected maxProb), decisionFactor maxExpectedRaw)
  where
    JS prob util = itemProduct bucketItems
    maxOut f = itemProjectOut decisionVar (convertToMaxFactor f)
    maxProb = convertToNormalFactor (maxOut prob)
    maxExpectedRaw = maxOut (itemProduct [prob, util])
    maxExpected = convertToNormalFactor maxExpectedRaw
-- | Bucket-elimination operations on probability/utility pairs.
-- Summing out a variable keeps sum(p) as the probability part and
-- sum(p*u) / sum(p) as the utility part.
instance IsBucketItem JoinSum where
    scalarItem (JS prob util) = all isScalarFactor [prob, util]
    itemProduct = foldl1' jsProduct
    itemProjectOut v (JS prob util) = JS summedProb (cptDivide summedExpected summedProb)
      where
        summedProb = itemProjectOut v prob
        summedExpected = itemProjectOut v (itemProduct [prob, util])
    itemContainsVariable (JS prob util) v = any (`containsVariable` v) [prob, util]
-- | A placeholder value of any type; evaluating it is an error.
--
-- NOTE(review): presumably meant to be passed only for its type, e.g. as the
-- value-type argument of 'chance' or 'decisionNode' (@chance "x" (t :: MyEnum)@)
-- — confirm against callers.
t :: a
t = error "Bayes.InfluenceDiagram.t: placeholder value must not be evaluated"
-- | Label carried by every edge of an influence diagram. There is only one
-- kind of link. Whether an edge is an information link is decided from the
-- kinds of its end nodes (see 'isInformationLink').
data EdgeKind
  = NormalLink
  deriving (Eq, Show)
-- | An information link goes from a chance or decision node into a decision
-- node.
isInformationLink :: InfluenceDiagram -> Edge -> Bool
isInformationLink g (Edge from to) = fromChanceOrDecision && isDecisionNode g to
  where
    fromChanceOrDecision = isChanceNode g from || isDecisionNode g from

-- | A chance node with at least one decision node among its children.
isRevealedChanceNode :: InfluenceDiagram -> Vertex -> Bool
isRevealedChanceNode g v
  | isChanceNode g v = any (isDecisionNode g) (childrenNodes g v)
  | otherwise = False
-- | Edge attribute (DOT syntax): information links are drawn dashed.
edgeShape :: InfluenceDiagram -> Edge -> EdgeKind -> Maybe String
edgeShape g e NormalLink =
  if isInformationLink g e then Just "style=dashed" else Nothing

-- | No edge colouring is applied.
edgeColor :: InfluenceDiagram -> Edge -> EdgeKind -> Maybe String
edgeColor _ _ _ = Nothing

-- | Node shape per kind: ellipse for chance, diamond for utility, box for
-- decision nodes.
nodeShape :: InfluenceDiagram -> Vertex -> IDValue -> Maybe String
nodeShape _ _ value = case value of
  ChanceNode _ -> Just "shape=ellipse"
  UtilityNode _ _ -> Just "shape=diamond"
  DecisionNode _ -> Just "shape=box"

-- | Revealed chance nodes are filled gray. Every other node keeps the
-- default style.
nodeColor :: InfluenceDiagram -> Vertex -> IDValue -> Maybe String
nodeColor g v value = case value of
  ChanceNode _ | isRevealedChanceNode g v -> Just "style=filled,fillcolor=gray"
  _ -> Nothing
-- | Render the diagram with 'displaySimpleGraph', using this module's node
-- and edge attribute helpers.
instance Show InfluenceDiagram where
  show g = displaySimpleGraph shapeOfNode colorOfNode shapeOfEdge colorOfEdge g
    where
      shapeOfNode = nodeShape g
      colorOfNode = nodeColor g
      shapeOfEdge = edgeShape g
      colorOfEdge = edgeColor g
-- | 'EdgeKind' has a single constructor, so combining two kinds always gives
-- 'NormalLink'.
--
-- Since base-4.11 (GHC 8.4), 'Semigroup' is a superclass of 'Monoid', so a
-- 'Semigroup' instance is required for the 'Monoid' instance to compile.
-- 'mappend' is left at its default, which is '(<>)'.
instance Semigroup EdgeKind where
   NormalLink <> NormalLink = NormalLink

instance Monoid EdgeKind where
   mempty = NormalLink
-- | An influence diagram: a directed graph whose edges are labelled with
-- 'EdgeKind' and whose vertices carry an 'IDValue'.
type InfluenceDiagram = DirectedSG EdgeKind IDValue

-- | The network-building monad, specialised to influence-diagram edge labels
-- and vertex values.
type IDMonad g a = NetworkMonad g EdgeKind IDValue a
-- | Payload stored at each vertex of an 'InfluenceDiagram'.
data IDValue
  = ChanceNode !CPT          -- ^ chance variable with its probability table
  | UtilityNode !DV !CPT     -- ^ utility node: its variable and its utility table
  | DecisionNode !CPT        -- ^ decision variable with its initial or policy table
  deriving (Eq)
-- | The variable a vertex value stands for. For chance and decision nodes
-- this is the main variable of the CPT; for utility nodes it is the stored
-- variable.
dvFromIDValue :: IDValue -> DV
dvFromIDValue (ChanceNode f) = factorMainVariable f
dvFromIDValue (UtilityNode v _) = v
dvFromIDValue (DecisionNode f) = factorMainVariable f

-- | All variables of the table stored in a vertex value.
factorVariablesFromIDValue :: IDValue -> [DV]
factorVariablesFromIDValue (ChanceNode f) = factorVariables f
factorVariablesFromIDValue (UtilityNode _ f) = factorVariables f
factorVariablesFromIDValue (DecisionNode f) = factorVariables f

-- | The probability/utility pair a chance or utility node contributes to the
-- elimination. Decision nodes contribute no pair, and asking for one is an
-- error.
jsFromIDValue :: IDValue -> JoinSum
jsFromIDValue (ChanceNode f) = chanceFactor f
jsFromIDValue (UtilityNode _ f) = utilityFactor f
jsFromIDValue (DecisionNode _) = error "You don't need to get the factor for a decision node"
-- | Chance and utility nodes show their table under a header. Decision nodes
-- render as the empty string.
instance Show IDValue where
   show value = case value of
     ChanceNode f -> "CHANCE:\n" ++ show f
     UtilityNode _ f -> "UTILITY:\n" ++ show f
     DecisionNode _ -> ""
-- | Handle on a utility node: its vertex and the size of its variable.
data UV = UV !Vertex !Int
        deriving (Eq)
-- | Handle on a decision variable: its vertex and its number of values.
data DEV = DEV !Vertex !Int
         deriving (Eq, Ord)

instance Show DEV where
    show (DEV v n) = concat ["D", show v, "(", show n, ")"]

instance BayesianDiscreteVariable DEV where
  dimension (DEV _ n) = n
  dv (DEV v n) = DV v n
  vertex (DEV v _) = v

-- | Instantiate a decision variable with an integer value.
instance Instantiable DEV Int where
  dev@(DEV _ _) =: value = DVI (dv dev) value
-- | Either a chance variable (built with 'p') or a decision variable (built
-- with 'd'), so that both kinds can appear in a single dependency list.
data PorD = P DV
          | D DEV
          deriving (Eq)
-- | Variables that can be used as chance variables.
class ChanceVariable m where
  toDV :: m -> DV

instance ChanceVariable DV where
  toDV x = dv x

instance ChanceVariable (TDV s) where
  toDV x = dv x
-- | Every variable operation is forwarded to the wrapped chance or decision
-- variable.
instance BayesianDiscreteVariable PorD where
    dimension x = case x of
      D dec -> dimension dec
      P var -> dimension var
    dv x = case x of
      D dec -> dv dec
      P var -> dv var
    vertex x = case x of
      D dec -> vertex dec
      P var -> vertex var
-- | Wrap a chance variable as a dependency.
p :: ChanceVariable c => c -> PorD
p c = P (toDV c)

-- | Wrap a decision variable as a dependency.
d :: DEV -> PorD
d dec = D dec
-- | Declare a chance variable with the given name. The second argument is
-- passed unchanged to 'variable'.
chance :: (Bounded a, Enum a, NamedGraph g) => String -> a -> IDMonad g (TDV a)
chance name witness = variable name witness
-- | Declare a utility node with the given name. It is backed by a variable of
-- size 1.
utilityNode :: (NamedGraph g) => String -> IDMonad g UV
utilityNode name =
  variableWithSize name 1 >>= \(DV v n) -> return (UV v n)
-- | Declare a decision variable with the given name. The second argument is
-- passed unchanged to 'variable'. The returned handle keeps the variable's
-- vertex and dimension.
decisionNode :: (Bounded a, Enum a, NamedGraph g) => String -> a -> IDMonad g DEV
decisionNode name witness =
  variable name witness >>= \declared -> case dv declared of
    DV v n -> return (DEV v n)
-- | Build the table of the node at vertex @v@ from @l@, using the variables of
-- every vertex with an edge into @v@. Crashes (via 'fromJust') if @v@, a start
-- vertex, or a parent variable cannot be found.
utilityCpt :: (DirectedGraph g, Distribution d, Factor f)
           => Vertex
           -> d
           -> NetworkMonad g e a (Maybe f)
utilityCpt v l = do
  graph <- gets snd
  let parentVertices = map (fromJust . startVertex graph) (fromJust (ingoing graph v))
  parentVars <- mapM factorVariable parentVertices
  return (createFactor (map fromJust parentVars) l)
-- | Node declarations that can be given a table. Implementations run the
-- declaration @mv@ and then initialise the declared node with a table built
-- from @l@.
class Initializable v where
  (~~) :: (DirectedGraph g, Distribution d) => IDMonad g v -> d -> IDMonad g ()
-- | Shared body of the 'Initializable' instances.
--
-- Runs @computeNewValue@, then reads the node currently stored at vertex
-- @v@. When both are present, the node is initialised with the new table
-- wrapped by @wrap@; otherwise nothing changes.
initializeIDNodeWith :: DirectedGraph g
                     => (CPT -> IDValue)
                     -> Vertex
                     -> IDMonad g (Maybe CPT)
                     -> IDMonad g ()
initializeIDNodeWith wrap v computeNewValue = do
  maybeNewValue <- computeNewValue
  currentValue <- getBayesianNode v
  case (currentValue, maybeNewValue) of
    (Just c, Just n) -> initializeNodeWithValue v c (wrap n)
    _ -> return ()

-- | Chance variable: the table comes from 'getCpt'.
instance Initializable DV where
  (~~) mv l = do
     DV v _ <- dv <$> mv
     initializeIDNodeWith ChanceNode v (getCpt v l)

-- | Typed chance variable: same as the 'DV' instance.
instance Initializable (TDV s) where
  (~~) mv l = do
     DV v _ <- dv <$> mv
     initializeIDNodeWith ChanceNode v (getCpt v l)

-- | Utility node: the table comes from 'utilityCpt' and is stored with the
-- node's variable.
instance Initializable UV where
  (~~) mv l = do
     UV v n <- mv
     initializeIDNodeWith (UtilityNode (DV v n)) v (utilityCpt v l)

-- | Decision variable: the table comes from 'getCpt'.
instance Initializable DEV where
  (~~) mv l = do
     DV v _ <- dv <$> mv
     initializeIDNodeWith DecisionNode v (getCpt v l)
-- | Link each condition to @node@ with '<--', last condition first, and
-- return @node@ unchanged.
_cpt :: (DirectedGraph g , BayesianDiscreteVariable v,BayesianDiscreteVariable vb) => v -> [vb] -> IDMonad g v
_cpt node conditions = do
  let target = dv node
  mapM_ (target <--) (map dv (reverse conditions))
  return node
-- | Declare the conditioning variables of a chance variable. Each condition is
-- linked to @node@ with '<--', last condition first, and @node@ is returned
-- unchanged.
cpt :: (DirectedGraph g ,BayesianDiscreteVariable vb, ChanceVariable c) => c -> [vb] -> IDMonad g c
cpt node conditions = do
  let target = toDV node
  mapM_ (target <--) (map dv (reverse conditions))
  return node
-- | Declare a chance variable that has no conditioning variables.
proba :: (ChanceVariable c, DirectedGraph g) => c -> IDMonad g c
proba node = cpt node noParents
  where
    noParents = [] :: [DV]
-- | Declare the dependencies of a utility node and return the node.
utility :: (DirectedGraph g , BayesianDiscreteVariable dv) => UV -> [dv] -> IDMonad g UV
utility (UV v n) deps =
  _cpt (DV v n) deps >>= \(DV v' n') -> return (UV v' n')
-- | An empty dependency list with a fixed element type, so it needs no type
-- annotation at call sites.
noDependencies :: [DV]
noDependencies = mempty
-- | Declare the parents of a decision. The decision node gets a table filled
-- with 1.0, sized over the decision variable and all of @l@.
decision :: (DirectedGraph g, BayesianDiscreteVariable dv) => DEV -> [dv] -> IDMonad g DEV
decision dec l = do
  let tableSize = product (dimension (dv dec) : map (dimension . dv) l)
  _cpt dec l ~~ replicate tableSize 1.0
  return dec
-- | Run an influence-diagram construction. Returns its result together with
-- the diagram that was built.
runID :: IDMonad DirectedSG a -> (a,InfluenceDiagram)
runID construction = runNetwork construction
-- | 'Just' the element of a singleton list; 'Nothing' for any other length.
maybeOnlyResult :: [a] -> Maybe a
maybeOnlyResult xs = case xs of
  [x] -> Just x
  _ -> Nothing
-- | Does vertex @v@ exist and hold a decision node?
isDecisionNode :: InfluenceDiagram -> Vertex -> Bool
isDecisionNode g v = case vertexValue g v of
  Just (DecisionNode _) -> True
  _ -> False

-- | Does vertex @v@ exist and hold a utility node?
isUtilityNode :: InfluenceDiagram -> Vertex -> Bool
isUtilityNode g v = case vertexValue g v of
  Just (UtilityNode _ _) -> True
  _ -> False

-- | Does vertex @v@ exist and hold a chance node?
isChanceNode :: InfluenceDiagram -> Vertex -> Bool
isChanceNode g v = case vertexValue g v of
  Just (ChanceNode _) -> True
  _ -> False
-- | A decision node with no incoming edges.
isRootDecision :: InfluenceDiagram -> Vertex -> Bool
isRootDecision g v = isDecisionNode g v && maybe False null (ingoing g v)
-- | Collect the chance-node parents of decision @dev@. Returns the diagram
-- with the decision and those parents removed, together with the parents'
-- variables.
chanceParents :: DEV -> InfluenceDiagram -> (InfluenceDiagram,[DV])
chanceParents dev currentG = (prunedG, map (vertexToDV currentG) chanceVertices)
  where
    decisionVertex = vertex dev
    chanceVertices = filter (isChanceNode currentG) (parentNodes currentG decisionVertex)
    prunedG = foldr removeVertex currentG (decisionVertex : chanceVertices)
-- | Chance variables still present in the diagram.
remainingChanceNodes :: InfluenceDiagram -> [DV]
remainingChanceNodes g = chanceNodes g

-- | Handles for every utility vertex of the diagram.
utilityNodes :: InfluenceDiagram -> [UV]
utilityNodes g = [vertexToUV g v | v <- allVertices g, isUtilityNode g v]

-- | Variables of every chance vertex of the diagram.
chanceNodes :: InfluenceDiagram -> [DV]
chanceNodes g = [vertexToDV g v | v <- allVertices g, isChanceNode g v]
-- | The probability/utility pairs of all non-decision vertices.
chanceAndUtilityFactors :: InfluenceDiagram -> [JoinSum]
chanceAndUtilityFactors g =
  [ jsFromIDValue (fromJust (vertexValue g v))
  | v <- allVertices g
  , not (isDecisionNode g v)
  ]
-- | Variable associated with vertex @v@. Crashes (via 'fromJust') if @v@ has
-- no value in @g@.
vertexToDV :: InfluenceDiagram -> Vertex -> DV
vertexToDV g v = dvFromIDValue (fromJust (vertexValue g v))

-- | Decision handle for vertex @v@, built from its variable.
vertexToDEV :: InfluenceDiagram -> Vertex -> DEV
vertexToDEV g v = case vertexToDV g v of
  DV v' n -> DEV v' n

-- | Utility handle for vertex @v@, built from its variable.
vertexToUV :: InfluenceDiagram -> Vertex -> UV
vertexToUV g v = case vertexToDV g v of
  DV v' n -> UV v' n
-- | Repeatedly take a root of the diagram, discarding it while it is not a
-- decision node. Returns the first decision root found, or 'Nothing' when the
-- diagram runs out of roots first.
rootDecision :: InfluenceDiagram -> Maybe Vertex
rootDecision g = rootNode g >>= pick
  where
    pick r
      | isDecisionNode g r = Just r
      | otherwise = rootDecision (removeVertex r g)
-- | One step of an elimination plan: a group of chance variables, or a single
-- decision.
data ChancesOrDecision
  = C ![DV]
  | DEC !DEV
  deriving (Eq, Ord, Show)
-- | Flatten an elimination plan into the sequence of variables it mentions,
-- preserving order.
dvOrder :: [ChancesOrDecision] -> [DV]
dvOrder = concatMap stepVars
  where
    stepVars (C vars) = vars
    stepVars (DEC dec) = [dv dec]
-- | Repeatedly peel a root decision off the diagram together with its chance
-- parents, prepending @DEC decision : C parents@ to the accumulator each time.
-- When no root decision remains, the remaining chance variables are
-- prepended as a final group. Because steps are prepended, the
-- last decision found ends up right after that final group.
removeAndRecordRootDecision :: [ChancesOrDecision] -> InfluenceDiagram -> [ChancesOrDecision]
removeAndRecordRootDecision acc g =
  maybe finish step (vertexToDEV g <$> rootDecision g)
  where
    finish = C (remainingChanceNodes g) : acc
    step dec =
      let (g', parents) = chanceParents dec g
      in removeAndRecordRootDecision (DEC dec : C parents : acc) g'
-- | Elimination plan for solving the diagram, built from an empty
-- accumulator.
decisionsOrder :: InfluenceDiagram -> [ChancesOrDecision]
decisionsOrder = removeAndRecordRootDecision []
-- | Maximise variable @v@ out of its bucket. The resulting pair is stored
-- back with 'updateBucket', and the recorded decision is returned alongside.
maximalizeOneVariable :: Buckets JoinSum -> DV -> (Buckets JoinSum,DecisionFactor)
maximalizeOneVariable buckets v = (updateBucket v reduced buckets, decisionTable)
  where
    (reduced, decisionTable) = maximalize v (getBucket v buckets)
-- | Apply an elimination plan to the buckets. Chance groups are summed out
-- one variable at a time and decisions are maximised out. Decision factors
-- are prepended to the accumulator, so the first decision processed ends up
-- last.
marginalizeID :: [ChancesOrDecision]  -> Buckets JoinSum -> [DecisionFactor] -> (Buckets JoinSum,[DecisionFactor])
marginalizeID plan buckets decided = case plan of
  [] -> (buckets, decided)
  C vars : rest ->
    marginalizeID rest (foldl' marginalizeOneVariable buckets vars) decided
  DEC dec : rest ->
    let (buckets', decisionTable) = maximalizeOneVariable buckets (dv dec)
    in marginalizeID rest buckets' (decisionTable : decided)
-- | Solve an influence diagram. Returns one 'DecisionFactor' per decision in
-- the plan produced by 'decisionsOrder'.
solveInfluenceDiagram :: InfluenceDiagram -> [DecisionFactor]
solveInfluenceDiagram g = snd (marginalizeID plan buckets [])
  where
    plan = decisionsOrder g
    buckets = createBuckets (chanceAndUtilityFactors g) (dvOrder plan) []