module Bayes.BayesianNetwork(
  
    BNMonad
  , runBN 
  , evalBN
  , execBN
  , Distribution(..)
  
  , variable
  , unamedVariable
  , variableWithSize
  , tdv
  , t
  
  , cpt
  , proba
  , (~~)
  , softEvidence
  , se
  
  , logical 
  , (.==.)
  , (.!.)
  , (.|.)
  , (.&.)
  
  , noisyOR
  ) where
import Bayes
import Bayes.PrivateTypes
import Control.Monad.State.Strict
import Bayes.Factor
import Data.Maybe(fromJust)
import qualified Data.List as L(find)
import Data.List(sort,intercalate,nub)
import Bayes.Tools(minBoundForEnum,maxBoundForEnum,intValue)
import Bayes.Network 
-- | A type witness: a placeholder standing for "some value of type @a@".
-- Within this module it is only used to fix a type, as in
-- @unamedVariable (t :: Bool)@.
--
-- It must never be evaluated; if it is, the error names this binding
-- instead of producing an anonymous 'undefined' failure.
t :: a
t = error "Bayes.BayesianNetwork.t: type witness must not be evaluated"
-- | Monad used to describe a Bayesian network. It is 'NetworkMonad' over a
-- graph type @g@ and a factor type @f@, producing a result of type @a@, with
-- its second parameter instantiated to @()@.
-- NOTE(review): the meaning of that second parameter is defined in
-- "Bayes.Network"; confirm there.
type BNMonad g f a = NetworkMonad g () f a
-- | Attach a distribution to the variable produced by a network action.
--
-- Runs @mv@ to get the variable and builds its table from @l@ with
-- 'getCpt'. It then looks up the current node with 'getBayesianNode' and,
-- when both lookups succeed, stores the table with
-- 'initializeNodeWithValue'.
--
-- If 'getCpt' or 'getBayesianNode' yields 'Nothing', nothing is changed and
-- the call is a silent no-op.
--
-- Typical use: @cpt x [y] ~~ values@.
(~~) :: (DirectedGraph g, Factor f, Distribution d, BayesianDiscreteVariable v)
     => BNMonad g f v  -- ^ Action producing the variable to initialise
     -> d              -- ^ Values of its (conditional) probability table
     -> BNMonad g f ()
mv ~~ l = do
  DV vertex _ <- fmap dv mv
  newValue <- getCpt vertex l
  node <- getBayesianNode vertex
  -- Combine both optional lookups in the Maybe monad. An update action
  -- exists only when both are present; otherwise do nothing.
  maybe (return ()) id $ do
    current <- node
    value <- newValue
    return (initializeNodeWithValue vertex current value)
-- | Declare the conditioning variables of a node and return the node, so it
-- can be passed to '~~'.
--
-- For each condition @c@, taken from the LAST to the FIRST, this runs
-- @dv node '<--' dv c@, presumably adding the edge @c -> node@.
-- NOTE(review): the reversed order presumably matters for how parents are
-- laid out in the CPT; confirm against "Bayes.Network".
cpt :: (DirectedGraph g , BayesianDiscreteVariable v,BayesianDiscreteVariable vb) => v -> [vb] -> BNMonad g f v
cpt node conditions = do
  let child = dv node
  mapM_ (\parent -> child <-- dv parent) (reverse conditions)
  return node
-- | Declare a node with no conditioning variables and return it, ready to
-- receive a prior distribution through '~~'.
proba :: (DirectedGraph g, BayesianDiscreteVariable v) => v -> BNMonad g f v
proba node = cpt node noParents
  where
    noParents :: [DV]
    noParents = []
-- | Add a soft-evidence node for a boolean variable and return it.
--
-- Creates a fresh unnamed variable with the same dimension as @d@, makes it
-- conditioned on @d@, and gives it the placeholder table
-- @[1.0,0.0,1.0,0.0]@.
-- NOTE(review): the placeholder is presumably replaced later by a factor
-- built with 'se'; confirm with callers.
softEvidence :: (NamedGraph g, DirectedGraph g, Factor f)
             => TDV Bool               -- ^ Variable receiving the soft evidence
             -> BNMonad g f (TDV Bool) -- ^ The new soft-evidence node
softEvidence d = do
  evidence <- unNamedVariableWithSize (dimension d)
  -- Placeholder values; the table has 4 entries because both variables are
  -- boolean.
  cpt evidence [dv d] ~~ [1.0, 0.0, 1.0, 0.0]
  return (tdv evidence)
-- | Build the factor expressing soft evidence of strength @p@.
--
-- The factor is over @[s, orgNode]@ with values @[p, 1-p, 1-p, p]@: weight
-- @p@ on the two "diagonal" entries (first and last) and @1-p@ on the two
-- others. The four fixed values mean both variables are expected to be
-- two-valued.
--
-- Returns whatever 'factorWithVariables' returns ('Nothing' if it rejects
-- the variables/values).
se :: Factor f
   => TDV s    -- ^ Soft-evidence node (presumably the one from 'softEvidence')
   -> TDV s    -- ^ Node the evidence is about
   -> Double   -- ^ Strength @p@ of the evidence
   -> Maybe f
se s orgNode p = factorWithVariables [dv s, dv orgNode] [p, 1 - p, 1 - p, p]
-- | Boolean test expression over variable instantiations. 'logical' turns
-- such an expression into a deterministic CPT; 'functionFromLE' evaluates it.
data LE = LETest DVI -- ^ Variable/value test (built with '.==.')
        | LEAnd LE LE -- ^ Conjunction (built with '.&.')
        | LEOr LE LE -- ^ Disjunction (built with '.|.')
        | LENot LE -- ^ Negation (built with '.!.')
        deriving(Eq)
-- | Distinct variables mentioned in a test expression, in order of first
-- appearance when the expression is read left to right. This order is the
-- parent order used by 'logical'.
varsFromLE :: LE -> [DV]
varsFromLE = nub . collect
  where
    collect :: LE -> [DV]
    collect expr = case expr of
      LETest inst -> [dv inst]
      LEAnd l r   -> collect l ++ collect r
      LEOr l r    -> collect l ++ collect r
      LENot e     -> collect e
-- | Collapse an optional boolean: only @Just True@ counts as 'True'.
-- 'Nothing' and @Just False@ are both 'False'.
boolValue :: Maybe Bool -> Bool
boolValue = (== Just True)
-- | Compile a test expression into a predicate over an instantiation (a list
-- of 'DVI's).
--
-- An 'LETest' holds when the first element of the list that is '==' to the
-- tested 'DVI' has the same 'instantiationValue'. If no such element exists,
-- the test is 'False'.
--
-- 'LENot', 'LEAnd' and 'LEOr' map to 'not', '&&' and '||', so the left
-- operand is evaluated first and evaluation short-circuits.
functionFromLE :: LE -> ([DVI] -> Bool)
functionFromLE expr inst = eval expr
  where
    eval (LETest target) =
      boolValue
        (fmap (\found -> instantiationValue target == instantiationValue found)
              (L.find (== target) inst))
    eval (LENot e) = not (eval e)
    eval (LEAnd l r) = eval l && eval r
    eval (LEOr l r) = eval l || eval r
-- | Things that can be compared against a value to form an 'LETest'
-- expression.
class Testable d v where 
  
  
  -- | @x .==. value@ builds the test expression "@x@ takes @value@".
  (.==.) :: d -> v -> LE 
-- | Catch-all instance: any 'Instantiable' pair can be tested. The test wraps
-- the instantiation @a =: b@ in 'LETest'.
--
-- NOTE(review): an instance head made of bare type variables, with a context
-- of the same size, needs FlexibleInstances and UndecidableInstances (and
-- MultiParamTypeClasses for the class). These are presumably enabled outside
-- this view.
instance Instantiable d v => Testable d v where 
  (.==.) a b = LETest (a =: b)
-- Precedence: '.==.' binds tighter than '.&.', which binds tighter than
-- '.|.'. So @a .==. x .&. b .==. y .|. c .==. z@ parses as
-- @((a .==. x) .&. (b .==. y)) .|. (c .==. z)@.
-- '.!.' has no declaration and keeps the default fixity.
infixl 8 .==.
infixl 6 .&.
infixl 5 .|.

-- | Disjunction of two test expressions.
(.|.) :: LE -> LE -> LE
lhs .|. rhs = LEOr lhs rhs

-- | Conjunction of two test expressions.
(.&.) :: LE -> LE -> LE
lhs .&. rhs = LEAnd lhs rhs

-- | Negation of a test expression.
(.!.) :: LE -> LE
(.!.) e = LENot e
-- | Make a boolean node deterministic: its value is the truth value of a test
-- expression over other variables.
--
-- The node's conditioning variables are the distinct variables of the
-- expression, as returned by 'varsFromLE'. For every joint instantiation of
-- them (from 'forAllInstantiations'), the table puts 1 on the outcome equal
-- to the expression's value and 0 on the other.
--
-- The value list is the whole @False@ column followed by the whole @True@
-- column.
logical :: (Factor f, DirectedGraph g) => TDV Bool -> LE -> BNMonad g f ()
logical node expr = cpt node parents ~~ (map (indicator . not) truth ++ map indicator truth)
  where
    parents = varsFromLE expr
    -- Truth value of the expression for each parent instantiation, computed
    -- once and reused for both columns.
    truth = map (functionFromLE expr) (forAllInstantiations (DVSet parents))
    indicator :: Bool -> Double
    indicator b = if b then 1 else 0
-- | Create a fresh boolean node whose only conditioning variable is @a@,
-- linked to @a@ through a noisy channel of parameter @p@.
--
-- The table is @[1-p, p, p, 1-p]@: the new node agrees with @a@ with weight
-- @1-p@ and disagrees with weight @p@. The table is symmetric, so this holds
-- whichever variable varies fastest.
--
-- Helper for 'noisyOR'.
noisyAND :: (DirectedGraph g, Factor f, NamedGraph g) => TDV Bool -> Double -> BNMonad g f (TDV Bool) 
noisyAND a p = do 
    na <- unamedVariable (t::Bool)
    cpt na [dv a] ~~ [1 - p, p, p, 1 - p]
    return na 
-- | Create a fresh boolean node that is the deterministic OR of @a@ and @b@:
-- it is true exactly when at least one of them is true (see 'logical').
-- Helper for 'noisyOR'.
orG :: (DirectedGraph g, Factor f, NamedGraph g) => TDV Bool -> TDV Bool -> BNMonad g f (TDV Bool)
orG a b = do
    orNode <- unamedVariable (t :: Bool)
    logical orNode (isTrue a .|. isTrue b)
    return orNode
  where
    isTrue :: TDV Bool -> LE
    isTrue v = v .==. True
-- | Noisy-OR over a list of boolean causes.
--
-- Each @(cause, p)@ pair is first passed through 'noisyAND'. The resulting
-- nodes are then combined left to right with deterministic OR nodes
-- ('orG'). The final node is returned; with a single cause, that is the
-- 'noisyAND' node itself.
--
-- Calls 'error' with a descriptive message when the list is empty. The
-- previous code failed there inside a partial 'head'/'tail'.
noisyOR :: (DirectedGraph g, Factor f, NamedGraph g)
        => [(TDV Bool,Double)]    -- ^ Causes and their noise parameters
        -> BNMonad g f (TDV Bool) -- ^ Node representing the noisy OR
noisyOR causes = do
    noisy <- mapM (uncurry noisyAND) causes
    case noisy of
      (first : rest) -> foldM orG first rest
      [] -> error "Bayes.BayesianNetwork.noisyOR: the list of causes must not be empty"
-- | Run a network description, returning both its result and the resulting
-- 'SBN'.
runBN :: BNMonad DirectedSG f a -> (a,SBN f)
runBN network = runNetwork network

-- | Run a network description and keep only the resulting 'SBN'.
execBN :: BNMonad DirectedSG f a -> SBN f
execBN network = execNetwork network

-- | Run a network description and keep only its result.
evalBN :: BNMonad DirectedSG f a -> a
evalBN network = evalNetwork network