{-# OPTIONS_GHC -fno-warn-orphans #-}
module Language.Passage.Graph where

import qualified Data.IntMap as IM
import qualified Data.Map as M
import qualified Data.Map.Strict as MS
import qualified Data.IntSet as IS
import Data.List(foldl')

-- import Debug.Trace

import Language.Passage.AST
import Language.Passage.Term
import Language.Passage.Utils
import Language.Passage.Lang.LaTeX(LaTeX(..))
import qualified Language.Passage.Lang.LaTeX as LaTeX

-- | Posterior log-likelihood term of a stochastic variable.
--
-- Every entry @(k, v)@ of the variable's 'stoPostDistLL' map contributes
-- the product @v * k@; the contributions are added together with 'sum'.
stoPostLL :: StoVar -> Term NodeIdx
stoPostLL sv = sum [ b * a | (a, b) <- M.toList (stoPostDistLL sv) ]

-- | A graph with no stochastic nodes and no arrays.
emptyBayesianGraph :: BayesianGraph
emptyBayesianGraph = BayesianGraph
  { stoNodes  = IM.empty
  , stoArryas = IM.empty
  }

-- | Record the term @t@ in the 'stoPostDistLL' map of node @ix@.
--
-- The term is first split by 'factorVar' into a pair @(x, c)@; @c@ is then
-- stored under key @x@, being merged with any value already present there
-- via 'plus'.  When @ix@ is not a key of 'stoNodes' the graph is returned
-- unchanged.
addToStoLL :: NodeIdx -> Term NodeIdx -> BayesianGraph -> BayesianGraph
addToStoLL ix t bg = bg { stoNodes = IM.alter addLL ix (stoNodes bg) }
  where
    (x, c) = factorVar (fvsArray bg) ix t

    -- addLL sv = sv { stoPostDistLL = M.insertWith plus x c (stoPostDistLL sv) }

    -- The strict insert stands in for @M.insertWith'@, which is no longer
    -- exported by "Data.Map" in current versions of containers.  Lazy and
    -- strict maps share one 'M.Map' type, so the field type is unaffected.
    addLL (Just sv) =
      Just $! sv { stoPostDistLL = MS.insertWith plus x c (stoPostDistLL sv) }
    addLL Nothing = Nothing

-- | Add two terms, preferring the result of 'sAdd' when it yields one and
-- falling back to plain @a + b@ when it returns 'Nothing'.
plus :: (PP a, Eq a, Show a) => Term a -> Term a -> Term a
plus a b =
  {- trace ("plus: " ++
            "\n a: " ++ show (pp a) ++
            "\n a: " ++ show a ++
            "\n b: " ++ show (pp b) ++
            "\n b: " ++ show b ++
            "\n a+b: " ++ show (pp result)) -}
  result
  where
    result = maybe (a + b) id (sAdd a b)

--------------------------------------------------------------------------------

-- | Extract the graph from a network description with 'extractNetwork' and
-- run 'computeLL' over it.  The network's own result value is returned
-- alongside the processed graph.
buildBayesianGraph :: BayesianNetwork a -> (a, BayesianGraph)
buildBayesianGraph nw = (a, computeLL g)
  where
    (a, g) = extractNetwork nw

-- | Compute the log-likelihood for a stochastic variable.
computeLL :: BayesianGraph -> BayesianGraph
computeLL bg = foldl' addDef bg (IM.elems (stoNodes bg))
  where
    -- Walk the summands of one node's prior log-likelihood ('priLL').
    addDef m sv = foldl' addSum m (summands (priLL (stoVarPrior sv)))

    -- Register the summand @t@ with every node index reported by
    -- 'leavesOfTerm'.  'IS.foldr' replaces the legacy alias 'IS.fold';
    -- both are the same right fold over the set.
    addSum m t =
      IS.foldr (\i acc -> addToStoLL i t acc) m (leavesOfTerm (fvsArray bg) t)

--------------------------------------------------------------------------------
-- Pretty printing
--------------------------------------------------------------------------------

-- | Printable variable name: either a plain name, or an array name together
-- with a list of integer indices.
data PPVar = PPName String | PPArr String [Int]
             deriving Show

-- | Printable name of a stochastic variable.
--
-- Unnamed variables are rendered as @v@ followed by their number, named
-- ones keep their name, and array elements use the name of the array found
-- in 'stoArryas'.  An array id missing from the graph produces a
-- @bug_unknown_array_@ placeholder instead of failing.
nameToPPName :: BayesianGraph -> StoVar -> PPVar
nameToPPName bg sv =
  case stoVarName sv of
    Unnamed y   -> PPName ("v" ++ show y)
    Named y     -> PPName y
    InArray a b ->
      case IM.lookup a (stoArryas bg) of
        Just ai -> PPArr (arrayName ai) b
        Nothing -> PPArr ("bug_unknown_array_" ++ show a) b

-- | Printable name for a node index.
--
-- The index is looked up among the stochastic nodes first and among the
-- arrays second; if neither knows it, a @bug_unknown_variable_@ placeholder
-- is returned.
varName :: BayesianGraph -> NodeIdx -> PPVar
varName bg x =
  case IM.lookup x (stoNodes bg) of
    Just sv -> nameToPPName bg sv
    Nothing ->
      case IM.lookup x (stoArryas bg) of
        Just ai -> PPName (arrayName ai)
        Nothing -> PPName ("bug_unknown_variable_" ++ show x)

-- | Replace every node index in a term by its printable name ('varName').
namedTerm :: BayesianGraph -> Term NodeIdx -> Term PPVar
namedTerm bg = fmap (varName bg)

-- | Plain rendering: the name, followed by each index passed through
-- 'brackets'.
instance PP PPVar where
  pp (PPName x)   = text x
  pp (PPArr x ys) = text x <> hcat (map (brackets . int) ys)

-- | LaTeX rendering: the name via 'LaTeX.var'; array indices follow an
-- underscore as a comma-separated list wrapped with 'braces'.
instance LaTeX PPVar where
  latex (PPName x)   = LaTeX.var x
  latex (PPArr x ys) =
    LaTeX.var x <> char '_' <> braces (hcat (punctuate comma (map int ys)))

-- | One line per stochastic node, of the form
-- @name ~~ prior params : posterior-log-likelihood@.
instance PP BayesianGraph where
  pp bg = vcat (map ppSto (IM.elems (stoNodes bg)))
    where
      ppT t    = pp (namedTerm bg t)
      ppSto sv = pp (nameToPPName bg sv) <+> text "~~" <+> ppPri (stoVarPrior sv)
                   <+> text ":" <+> ppT (stoPostLL sv)
      ppPri i  = text (priName i)
                   <+> commaSep (map (pp . namedTerm bg) (priParams i))

-- | A @tabular@ environment holding a nested two-column table: a header
-- row, then one row per stochastic node with its prior in the first cell
-- and its posterior log-likelihood ('stoPostLL') in the second.
instance LaTeX BayesianGraph where
  latex bg = LaTeX.env "tabular" [text "l"] $ vcat $ map (\x -> LaTeX.row [x])
      [ LaTeX.env "tabular" [text "l l"]
          (LaTeX.row [ text "Prior distribution"
                     , text "Posterior log-likelihood"
                     ]
           $$ vcat (map ppSto (IM.elems (stoNodes bg))))
      ]
    where
      ppT t     = latex (namedTerm bg t)
      -- First cell is @x ~ y@ (joined by 'LaTeX.sim'), second cell is @z@;
      -- both are passed through 'LaTeX.math'.
      row x y z = LaTeX.row (map LaTeX.math [ x <+> LaTeX.sim <+> y, z ])
      ppSto sv  = row (latex (nameToPPName bg sv))
                      (ppPri (stoVarPrior sv))
                      (ppT (stoPostLL sv))
      ppPri i   = latexDist (priName i) (map (namedTerm bg) (priParams i))