module Bayes.NodeTypes
( virtualFactor, mkNoisyOr, mkNoisyAdder, mkNoisyAdder2
) where
import Bayes.Factor
import Bayes.Probability
import Data.List
import Data.Semigroup
-- | @virtualFactor s ps@ builds a factor over two variables:
--
-- * @s@, which gets one state per element of @ps@;
-- * a binary companion variable named @'#' : s@.
--
-- For each element @p@ of @ps@ (in order) the factor receives the two
-- entries @p@ and @1 - p@, converted with 'toDouble'. This assumes
-- 'makeFactor' lays entries out with the last listed variable varying
-- fastest (TODO confirm). Presumably this factor is used to attach
-- virtual (soft) evidence to @s@; verify against callers.
virtualFactor :: String -> [Probability] -> Factor
virtualFactor s ps = makeFactor vars values
  where
    vars = [(s, length ps), ('#' : s, 2)]
    values = [ toDouble q | p <- ps, q <- [p, 1 - p] ]
-- | Build a noisy-OR style factor for @target@ given its parents @vs@.
--
-- The joint model multiplies together three kinds of factor:
--
-- * one factor per parent linking it to an auxiliary variable @'X' : name@
--   that has the target's cardinality (see 'mkVars');
-- * a leak variable @"L"@ that takes the parameters left over after the
--   parents;
-- * the deterministic 'orGate'.
--
-- The auxiliary variables and @"L"@ are then summed out. @str@ and @cs@ are
-- consumed parent by parent exactly as 'mkVars' describes.
--
-- NOTE(review): the names @"L"@ and @'X'@-prefixed parent names are used
-- internally, so a parent actually named like that would collide.
mkNoisyOr :: (String, Int) -> [(String, Int)] -> [Int] -> [Probability] -> Factor
mkNoisyOr target vs str cs = sumouts hidden joint
  where
    joint = mconcat (mkVars target vs str cs ++ [orGate target vs])
    -- auxiliary per-parent variables in reverse parent order, then the leak
    hidden = [ 'X' : s | (s, _) <- reverse vs ] ++ ["L"]
-- | Build the per-parent factors used by 'mkNoisyOr', followed by the leak
-- factor.
--
-- Let @tn = snd target@. For each parent @v = (name, n)@, in order:
--
-- * take the next @tn * n@ probabilities from the parameter list;
-- * take the next @n@ indices from the ordering list;
-- * reorder the probabilities' groups of @tn@ by those indices
--   ('shuffleGroups');
-- * build a factor over @v@ and @'X' : name@ ('mkVar').
--
-- Whatever parameters remain after the last parent form a factor over the
-- leak variable @"L"@ with @tn@ states.
mkVars :: (a, Int) -> [(String, Int)] -> [Int] -> [Probability] -> [Factor]
mkVars target = go
  where
    tn = snd target
    go [] _ rest = [makeFactor [("L", tn)] (map toDouble rest)]
    go (v : more) order rest =
      mkVar v tn (shuffleGroups tn myOrder mine) : go more laterOrder others
      where
        (mine, others) = splitAt (tn * snd v) rest
        (myOrder, laterOrder) = splitAt (snd v) order
-- | @mkVar (s, n) nx ps@ builds a factor over @s@ (cardinality @n@) and its
-- auxiliary variable @'X' : s@ (cardinality @nx@), with table entries @ps@.
mkVar :: (String, Int) -> Int -> [Probability] -> Factor
mkVar (s, n) nx ps = makeFactor [(s, n), ('X' : s, nx)] (map toDouble ps)
-- | Deterministic gate used by 'mkNoisyOr'.
--
-- The factor's variables, in order, are:
--
-- * the auxiliary variables @'X' : name@, one per parent in @vs@, each with
--   the target's cardinality;
-- * the leak variable @"L"@;
-- * @target@ itself.
--
-- An entry is 1 exactly when the target's state equals the minimum state
-- among the auxiliary variables and @"L"@, and 0 otherwise.
orGate :: (String, Int) -> [(String, Int)] -> Factor
orGate target vs = makeFactor (auxVars ++ [("L", k), target]) table
  where
    k = snd target
    auxVars = [ ('X' : s, k) | (s, _) <- vs ]
    states = [0 .. k - 1] :: [Int]
    -- every joint assignment of the auxiliary variables and L
    inputs = combine (replicate (length vs + 1) states)
    table = cartesian (\m t -> if m == t then 1 else 0) (map minimum inputs) states
-- | All ways of picking one element from each list, with the first list
-- varying slowest. @combine []@ is @[[]]@, and any empty input list makes
-- the result empty. This is the list instance of 'sequence'.
combine :: [[a]] -> [[a]]
combine = sequence
-- | Apply @f@ to every pair drawn from the two lists, with the first list
-- varying slowest.
cartesian :: (a -> b -> c) -> [a] -> [b] -> [c]
cartesian f xs ys = concatMap (\x -> map (f x) ys) xs
-- | Cut the list into consecutive groups of @n@ elements ('groups'), then
-- concatenate the groups selected by the indices @is@, in the order given.
-- An index past the last group makes '!!' fail.
shuffleGroups :: Int -> [Int] -> [a] -> [a]
shuffleGroups n is xs = concatMap (chunks !!) is
  where
    chunks = groups n xs
-- | Select elements of the list by position, in the order given by the
-- indices. Indices may repeat. An out-of-range index makes '!!' fail.
shuffle :: [Int] -> [a] -> [a]
shuffle is xs = [ xs !! i | i <- is ]
-- | Split a list into consecutive chunks of @n@ elements. The last chunk may
-- be shorter, and an empty list yields no chunks.
--
-- NOTE: for @n <= 0@ and a non-empty list the result is an infinite list of
-- empty chunks, because each step consumes nothing.
groups :: Int -> [a] -> [[a]]
groups n = unfoldr step
  where
    step [] = Nothing
    step ys = Just (splitAt n ys)
-- | Build a noisy-adder factor for @target@ from flat parameter lists, then
-- delegate to 'mkNoisyAdder2'.
--
-- * @vs@: the parents, as @(name, cardinality)@.
-- * @dst@: distinguished states. Only @drop 1 dst@ is read here, one entry
--   per parent. The first entry is ignored; presumably it belongs to the
--   target (TODO confirm).
-- * @ws@: one weight per parent, followed by the leak weight.
-- * @params@: for each parent, @cardinality * snd target@ probabilities,
--   followed by the leak's probabilities.
--
-- The per-parent tuples are sorted (by parent name first) before the factor
-- is built.
--
-- Calls 'error' with a descriptive message in two cases:
--
-- * @ws@ does not hold exactly one weight per parent plus the leak weight;
-- * @dst@ is too short to give every parent a distinguished state.
--
-- This replaces an opaque failure inside a partial Prelude function, or a
-- silently truncated parent list (which 'zip5' produces on mismatched
-- lengths).
mkNoisyAdder :: (String, Int) -> [(String, Int)] -> [Int] -> [Double] -> [Probability] -> Factor
mkNoisyAdder target vs dst ws params =
  case splitAt nParents ws of
    (parentWs, [leakW])
      | length parentDst >= nParents ->
          mkNoisyAdder2 target (sort (zip5 ss ns parentDst parentWs parentPs)) (leakW, leakPs)
      | otherwise ->
          error $ "mkNoisyAdder: expected at least " ++ show (nParents + 1)
            ++ " distinguished states, got " ++ show (length dst)
    _ ->
      error $ "mkNoisyAdder: expected " ++ show (nParents + 1)
        ++ " weights (one per parent plus the leak), got " ++ show (length ws)
  where
    nParents = length vs
    (ss, ns) = unzip vs
    -- the first entry of dst is not assigned to any parent
    parentDst = drop 1 dst
    -- Carve one chunk of (parent cardinality * target cardinality)
    -- parameters per parent; whatever remains belongs to the leak.
    (leakPs, parentPs) = mapAccumL takeChunk params ns
    takeChunk rest card =
      let (chunk, rest') = splitAt (card * snd target) rest
      in (rest', chunk)
-- | Build the factor of a noisy-adder node from per-parent tuples.
--
-- Each tuple is @(name, n, d, w, ps)@:
--
-- * the parent's name and cardinality @n@;
-- * its distinguished state @d@;
-- * its weight @w@;
-- * @n * tn@ probabilities, one row of @tn@ values per parent state, where
--   @tn@ is the target's cardinality.
--
-- The leak is given as @(weight, row of tn probabilities)@.
--
-- Each table entry is a weighted average ('Avg') of the leak row and, for
-- every parent, the row selected by that parent's state. A parent in its
-- distinguished state takes part with weight 0. The factor's variables are
-- the parents in tuple order followed by the target. The table is laid out
-- with the target varying fastest, then the last parent, and so on.
mkNoisyAdder2 :: (String, Int) -> [(String, Int, Int, Double, [Probability])] -> (Double, [Probability]) -> Factor
mkNoisyAdder2 target@(_, tn) tups (lw, lps) =
  makeFactor (parents ++ [target]) (map fromAvg table)
  where
    parents = [ (s, n) | (s, n, _, _, _) <- tups ]
    -- fold right-to-left so that the last parent ends up varying fastest
    (_, table) = foldr step (1, leakRow) tups
    leakRow = zipWith avg lps (replicate tn lw)
    step :: (String, Int, Int, Double, [Probability]) -> (Int, [Avg]) -> (Int, [Avg])
    step (_, n, d, w, fc) (inner, acc) = (inner * n, zipWith (<>) spread tiled)
      where
        weights = [ if x == d then 0 else w | x <- [0 .. n - 1], _ <- [1 .. tn] ]
        contrib = zipWith avg fc weights
        -- each parent-state row of tn entries is repeated once per
        -- combination of the variables already folded in
        spread = concat [ row | row <- groups tn contrib, _ <- [1 .. inner] ]
        -- the table built so far is tiled once per state of this parent
        tiled = [ a | _ <- [1 .. n], a <- acc ]
-- | A running weighted average, kept as its two strict accumulators.
data Avg
  = Avg
      !Double -- ^ accumulated sum of @value * weight@
      !Double -- ^ accumulated sum of weights
-- | Shows the averaged value (total divided by weight), not the two fields.
instance Show Avg where
  show (Avg total weight) = show (total / weight)
-- | Adds the weighted totals and the weights separately, so combining two
-- averages gives their weighted mean.
instance Semigroup Avg where
  (<>) (Avg s1 v1) (Avg s2 v2) = Avg (s1 + s2) (v1 + v2)
-- | The empty average has zero total and zero weight. 'mappend' is kept
-- explicit, presumably for compatibility with older @base@ versions (the
-- module imports "Data.Semigroup").
instance Monoid Avg where
  mempty = Avg 0.0 0.0
  mappend x y = x <> y
-- | A single weighted observation: @avg p w@ carries total
-- @toDouble p * w@ and weight @w@.
avg :: Probability -> Double -> Avg
avg p w = Avg (toDouble p * w) w
-- | The weighted mean: accumulated total divided by accumulated weight.
--
-- With zero total weight (e.g. 'mempty') the division follows IEEE rules.
-- It yields NaN when the total is also zero, and ±Infinity otherwise.
fromAvg :: Avg -> Double
fromAvg (Avg total weight) = total / weight