-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-- Inference on Bayesian networks by variable elimination: the nodes of a
-- network are turned into factors, evidence is conditioned in, and
-- variables are eliminated in the order proposed by an
-- 'EliminationOrdering'.
--
-----------------------------------------------------------------------------

module Bayes.Inference
   ( getFactors, inferNetwork, pruneNetwork, infer, simulate, Query
   , inferEvidence
   , trimFor
   , toEvidence
   , toNetwork
   ) where

import Data.List
import Data.Maybe
import Bayes.EliminationOrdering
import Bayes.Evidence
import Bayes.Factor hiding (size)
import qualified Bayes.Factor as F
import Bayes.Network
import Bayes.NodeTypes
import Bayes.Probability
import qualified Data.Map as M
import qualified Data.Set as S

-- | Set of query variables (node identifiers).
type Query = S.Set String

-- | Marginal distributions for every query variable in @qs@, computed from
-- the factor list @list@ while eliminating variables in the order proposed
-- by @order@.
--
-- Whenever the next variable to eliminate is a query variable, its
-- distribution is first computed with 'query' on the current factors.
-- The recursion stops as soon as no query variables remain: any further
-- elimination could only reach the branch that inserts nothing, so
-- continuing would just waste work.
posteriors :: Query -> [Factor] -> EliminationOrdering -> M.Map String [Probability]
posteriors qs list order
   | S.null qs = M.empty
   | otherwise =
        case nextVariable order list [] of
           Nothing -> M.empty
           Just (v, rest) ->
              let results = posteriors (S.delete v qs) (eliminate list v) rest
              in if v `S.member` qs
                 then M.insert v (query v list rest) results
                 else results

-- | Distribution of the single variable @q@.
--
-- The function eliminates the variables proposed by the ordering. When none
-- are left, it combines the remaining factors with 'mconcat', normalizes
-- them, and converts the values with 'fromDouble'.
-- NOTE(review): the @[q]@ argument presumably stops 'nextVariable' from
-- proposing @q@; the @q == v@ check keeps @q@ un-eliminated regardless.
query :: String -> [Factor] -> EliminationOrdering -> [Probability]
query q list order =
   case nextVariable order list [q] of
      Nothing        -> map fromDouble (values (normalize (mconcat list)))
      Just (v, rest) -> query q (if q == v then list else eliminate list v) rest

-- | Restrict a network to the given variables plus all their ancestors,
-- found transitively through 'parentIds'. Identifiers that are not nodes
-- of the network are ignored.
pruneNetwork :: Query -> Network a -> Network a
pruneNetwork qs nw = filterNodes ((`S.member` keep) . nodeId) nw
 where
   keep = collect S.empty (S.toList qs)

   -- depth-first walk over parents; 'acc' holds the nodes already kept
   collect acc [] = acc
   collect acc (x:xs)
      | x `S.member` acc = collect acc xs
      | otherwise =
           case findNode nw x of
              Just n  -> collect (S.insert x acc) (parentIds n ++ xs)
              Nothing -> collect acc xs

-- | One factor per node of the network (see 'nodeToFactor').
getFactors :: Network a -> [Factor]
getFactors nw = [ nodeToFactor nw n | n <- nodes nw ]

-- | Factor for a flat conditional probability table. The dimensions are the
-- given nodes, as @(nodeId, size)@ pairs in the given order. The table
-- entries are converted with 'toDouble'.
cptFactor :: [Node a] -> [Probability] -> Factor
cptFactor ns = makeFactor (map (\n -> (nodeId n, size n)) ns) . map toDouble

-- | Factor for a single node, depending on how the node is defined:
--
-- * 'CPT': the table over the node's parents followed by the node itself;
--
-- * 'NoisyMax': built with 'mkNoisyOr' from the node and its parents;
--
-- * 'NoisyAdder': built with 'mkNoisyAdder' from the node and its parents.
nodeToFactor :: Network a -> Node a -> Factor
nodeToFactor nw n =
   case definition n of
      CPT xs                -> cptFactor (parents nw n ++ [n]) xs
      NoisyMax str xs       -> mkNoisyOr (nodeId n, size n) ps str xs
      NoisyAdder dst ws xs  -> mkNoisyAdder (nodeId n, size n) ps dst ws xs
 where
   ps = map (\x -> (nodeId x, size x)) (parents nw n)

-- | Posterior distributions of the query variables, given evidence.
--
-- Steps:
--
-- * The network is pruned to the query variables, the evidence keys and all
--   their ancestors.
--
-- * Every factor, including the extra factors for virtual evidence, is
--   conditioned on the evidence index map with 'conditions'.
--
-- * Query variables that carry hard ('Index') evidence are not inferred.
--   Each of them gets a point distribution on its observed state instead.
infer :: Network () -> Evidence -> EliminationOrdering -> Query -> M.Map String [Probability]
infer nw0 ev vs qs0 = posteriors qs list vs `M.union` givens
 where
   -- 'cs' is the index-map from the evidence (virtual evidence is keyed
   -- with a '#' prefix, see 'indexMap')
   cs     = indexMap ev
   evKeys = S.fromList (map fst cs)
   nw     = pruneNetwork (qs0 `S.union` evKeys) nw0
   qs     = qs0 S.\\ evKeys
   list   = map (conditions cs) (getFactors nw ++ virtualFactors ev)
   givens = M.fromList (concatMap given (fromEvidenceTp ev))

   -- observed query variable: probability 1 on the observed state index,
   -- 0 elsewhere; a node missing from the network yields an empty list
   given (s, Index i)
      | s `S.member` qs0 =
           let n = maybe 0 size (findNode nw s)
           in [(s, [ if a == i then 1 else 0 | a <- [0 .. n - 1] ])]
   given _ = []

-- | One extra factor per virtual-evidence item. The factor is built with
-- 'virtualFactor' from the second components of the item's entries.
virtualFactors :: Evidence -> [Factor]
virtualFactors = concatMap f . fromEvidenceTp
 where
   f (s, Virtual ps) = [virtualFactor s (map snd ps)]
   f (_, Index _)    = []

-- | Fill in the probabilities into a network.
--
-- Each node's state names are paired, in order, with the probabilities
-- found under its identifier. Extra states or probabilities are dropped.
-- A node with no entry in the map ends up with an empty state list.
toNetwork :: Network () -> M.Map String [Probability] -> Network Probability
toNetwork nw result = mapNodes fill nw
 where
   fill n =
      let ps = fromMaybe [] (M.lookup (nodeId n) result)
      in n { states = zip (map fst (states n)) ps }

-- | Fill in the probabilities into Evidence: the network is filled with
-- 'toNetwork', and each node's state probabilities are read back as
-- virtual evidence.
toEvidence :: Network () -> M.Map String [Probability] -> Evidence
toEvidence nw = getVirtuals . toNetwork nw

-- | Index map used for conditioning. Hard evidence becomes
-- @(node, stateIndex)@. Virtual evidence becomes @('#' : node, 0)@.
-- NOTE(review): the '#'-prefixed key presumably names the factor made by
-- 'virtualFactor'; confirm in "Bayes.Factor".
indexMap :: Evidence -> [(String, Int)]
indexMap = map f . fromEvidenceTp
 where
   f (s, Index i)   = (s, i)
   f (s, Virtual _) = ('#':s, 0)

-- | Run 'infer' and fill the results into the network.
inferNetwork :: Network () -> Evidence -> EliminationOrdering -> Query -> Network Probability
inferNetwork nw ev vs q = toNetwork nw $ infer nw ev vs q

-- | Run 'infer' and return the results as virtual evidence.
inferEvidence :: Network () -> Evidence -> EliminationOrdering -> Query -> Evidence
inferEvidence nw ev vs q = toEvidence nw $ infer nw ev vs q

-- | Debugging aid that traces the elimination for query variable @qv@.
--
-- Before each step it prints the total factor size, the number of factors,
-- the number of distinct variables and the individual factor sizes. It then
-- prints the variable being eliminated. At the end it prints the remaining
-- factors, normalized.
simulate :: String -> EliminationOrdering -> [Factor] -> IO ()
simulate qv eo fs = do
   printFactors
   case nextVariable eo fs [qv] of
      Nothing -> print (map normalize fs)
      Just (x, eo') -> do
         putStrLn $ " => " ++ x
         simulate qv eo' (eliminate fs x)
 where
   ns = map F.size fs
   ss = nub $ concatMap vars fs
   printFactors = do
      putStrLn $ "total size: " ++ show (sum ns)
      putStrLn $ "#factors: " ++ show (length ns)
      putStrLn $ "#vars: " ++ show (length ss)
      print ns

-- | Remove all 'Evidence' that cannot be fed to the given 'Network', i.e.
-- evidence whose key is not the identifier of a node in the network.
trimFor :: Network a -> Evidence -> Evidence
trimFor nw = filterEvidence (`S.member` ids)
 where
   -- build the identifier set once instead of scanning a list per item
   ids = S.fromList (map nodeId (nodes nw))

-- | One 'virtual' evidence item per node, made from that node's state
-- probabilities, combined with 'mconcat'.
getVirtuals :: Network Probability -> Evidence
getVirtuals nw = mconcat $ map f (nodes nw)
 where
   f n = virtual n (map snd (states n))