{-# LANGUAGE MagicHash #-}
-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-- Factors: tables of 'Double' values indexed by named variables with finite
-- domains. Provides construction, the standard factor operations (product,
-- summing out, conditioning, normalisation) and variable elimination.
--
-----------------------------------------------------------------------------

module Bayes.Factor
   ( -- * Factor type and dimensions
     Factor
   , Dimensions, dimensions, fromDimensions, hasVarD, mergesD
   , Dimensional(..)
     -- * Construction
   , makeFactor
     -- * Simple queries
   , values
     -- * Standard operations
   , multiply, sumout, sumouts, condition, conditions, normalize
     -- * Variable elimination on factors
   , eliminate, eliminateSplit, eliminateList
   ) where

import Data.List
import Data.Maybe
import Data.Semigroup
import GHC.Exts
import qualified Data.Set as S

-----------------------------------------------------------------------------
-- Factors

-- | A factor: its 'Dimensions' (variables with their domain sizes, kept in
-- sorted order) and a 'Tree' of values with one level per variable, in the
-- same order as the dimensions.
data Factor = F { dimensions :: Dimensions, getTree :: Tree }

-- | Shows the dimensions as @name[size]@ entries, followed by the nested
-- values.
instance Show Factor where
   show (F d t) =
      intercalate "," (map showDim (fromDimensions d)) ++ ": " ++ show t
    where
      showDim (s, n) = s ++ "[" ++ show n ++ "]"

-- | Combining two factors multiplies them (see 'multiply').
instance Semigroup Factor where
   (<>) = multiply

-- | The unit factor has no variables and the single value 1.
instance Monoid Factor where
   mempty  = makeFactor [] [1]
   mappend = (<>)
   mconcat = multiplyList

-----------------------------------------------------------------------------
-- Construction

-- | Builds a factor from its variables (name and domain size) and its
-- values. The values are listed in row-major order with respect to the given
-- variable order: the last variable varies fastest. Internally the variables
-- are sorted, and the values are rearranged to match.
--
-- Raises an error when a domain size is not positive, or when the number of
-- values differs from the product of the domain sizes. Without these checks
-- such input could make tree construction loop forever (non-positive chunk
-- sizes never shrink the list) or silently build a malformed tree.
makeFactor :: [(String, Int)] -> [Double] -> Factor
makeFactor d bs
   | any ((<= 0) . snd) d =
        error "makeFactor: domain sizes must be positive"
   | length bs /= product (map snd d) =
        error "makeFactor: number of values does not match the dimensions"
   | otherwise =
        F (mkD dims) (reorder dims d (makeTree (map snd d) bs))
 where
   dims = sort d

-- Builds the value tree from the domain sizes (outermost first) and the
-- values in row-major order. Leaves are grouped bottom-up, starting with the
-- innermost (last) size; sizes 2 and 3 get the specialised 'Two' and 'Three'
-- nodes.
makeTree :: [Int] -> [Double] -> Tree
makeTree sizes = level (reverse sizes) . map leaf
 where
   level (2:ns) ts = level ns (groups2 ts)
   level (3:ns) ts = level ns (groups3 ts)
   level (n:ns) ts = level ns (groupsN n ts)
   level [] [t]    = t
   level _ _       = error "makeFactor"

-- Groups consecutive pairs of trees into 'Two' nodes.
{-# INLINE groups2 #-}
groups2 :: [Tree] -> [Tree]
groups2 [] = []
groups2 (x:y:rest) = Two x y : groups2 rest
groups2 _ = error "groups2"

-- Groups consecutive triples of trees into 'Three' nodes.
{-# INLINE groups3 #-}
groups3 :: [Tree] -> [Tree]
groups3 [] = []
groups3 (x:y:z:rest) = Three x y z : groups3 rest
groups3 _ = error "groups3"

-- Groups consecutive chunks of n trees into 'Bin' nodes (splitAt lets the
-- last chunk be shorter; makeFactor checks the total count beforehand).
{-# INLINE groupsN #-}
groupsN :: Int -> [Tree] -> [Tree]
groupsN n = go
 where
   go [] = []
   go xs = Bin chunk : go rest
    where
      (chunk, rest) = splitAt n xs

-----------------------------------------------------------------------------
-- Standard operations

-- | Restricts a factor to one value (an index into the domain) of the given
-- variable, and removes that variable. A variable that does not occur in the
-- factor leaves the factor unchanged.
condition :: String -> Int -> Factor -> Factor
condition s i (F d t) = F (deleteD s d) (set s i d t)

-- | Like 'condition', for several variable/index pairs at once.
conditions :: [(String, Int)] -> Factor -> Factor
conditions xs (F d t) = F (deletesD (map fst xs) d) (choose xs d t)

-- Multiplies a list of factors. With three or more factors, the smallest
-- factor is greedily multiplied with whichever other factor yields the
-- smallest product, and the process repeats on the shorter list.
multiplyList :: [Factor] -> Factor
multiplyList = go
 where
   go []     = mempty
   go [a]    = a
   go [a, b] = multiply a b
   go fs     = go (best : rest)
    where
      m:ms = sortBy cmp fs
      cmp x y = size x `compare` size y
      (best, rest) =
         minimumBy (\x y -> size (fst x) `compare` size (fst y))
                   (map make [0 .. length ms - 1])
      -- pair m with the i-th remaining factor; keep the others unchanged
      make i = (multiply m y, xs ++ ys)
       where
         (xs, y:ys) = splitAt i ms

-- | Pointwise product of two factors over the union of their variables.
multiply :: Factor -> Factor -> Factor
multiply (F d1 t1) (F d2 t2) = F (mergeD d1 d2) (mergeWith (*##) d1 d2 t1 t2)

-- | Sums out (marginalises) the given variables. Variables that do not occur
-- in the factor are ignored.
sumouts :: [String] -> Factor -> Factor
sumouts vs0 (F dt t0) = F (deletesD vs0 dt) (go (dropTail flags) t0)
 where
   ds = map fst (fromDimensions dt)
   -- one flag per dimension (outermost first): True when it is summed out
   flags = mark (filter (`elem` ds) vs0) ds
   mark [] _ = []
   mark vs (x:xs)
      | x `elem` vs = True : mark (delete x vs) xs
      | otherwise   = False : mark vs xs
   mark _ _ = []
   go [] t = t
   go (b:bs) t
      | b         = foldr1 (zipTree (+##)) (subtrees rt)
      | otherwise = rt
    where
      -- process the deeper levels first, then collapse this level if needed
      rt = mapSubtrees (go bs) t
   -- trailing levels that are kept need no traversal at all
   dropTail = reverse . dropWhile not . reverse

-- | Sums out a single variable.
sumout :: String -> Factor -> Factor
sumout s = sumouts [s]

-----------------------------------------------------------------------------
-- Simple queries

-- | All values of the factor, in row-major order of its (sorted) variables.
values :: Factor -> [Double]
values = treeToList . getTree

-- | Divides every value by the sum of all values, so that they sum to
-- (approximately) 1. No check is made for a zero sum.
normalize :: Factor -> Factor
normalize x =
   case sum (values x) of
      D# total -> x { getTree = mapTree (/## total) (getTree x) }

-----------------------------------------------------------------------------
-- Variable elimination

-- | Eliminates a variable from a collection of factors: all factors that
-- mention the variable are multiplied and the variable is summed out. The
-- resulting factor is dropped when its size is at most 1 (for example, when
-- it no longer has any variables).
eliminate :: [Factor] -> String -> [Factor]
eliminate fs = uncurry keep . eliminateSplit fs
 where
   keep x xs = if size x <= 1 then xs else x : xs

-- | Eliminates the variables one after another, from left to right.
eliminateList :: [Factor] -> [String] -> [Factor]
eliminateList = foldl' eliminate

-- | Eliminate variable: returns new factor and remaining (unchanged) factors
eliminateSplit :: [Factor] -> String -> (Factor, [Factor])
eliminateSplit fs s = (sumout s (mconcat fs1), fs2)
 where
   (fs1, fs2) = partition (S.member s . varSet) fs

-----------------------------------------------------------------------------
-- Dimensional values

-- | Things with dimensions. For 'Dimensions' (and 'Factor'), 'size' is the
-- product of the domain sizes; for lists, the sizes are added up.
class Dimensional a where
   size   :: a -> Int
   varSet :: a -> S.Set String
   vars   :: a -> [String]
   -- default
   vars = S.toList . varSet

instance Dimensional Factor where
   size   = size . dimensions
   varSet = varSet . dimensions

instance Dimensional Dimensions where
   size (D n _) = n
   varSet = S.fromList . map fst . fromDimensions

instance Dimensional a => Dimensional [a] where
   size   = sum . map size
   varSet = S.unions . map varSet

-----------------------------------------------------------------------------
-- Dimensions

-- | Variables with their domain sizes, together with the cached product of
-- those sizes. 'hasVarD' and 'mergeD' rely on the list being sorted.
data Dimensions = D Int [(String, Int)]

-- | The variables and their domain sizes.
fromDimensions :: Dimensions -> [(String, Int)]
fromDimensions (D _ xs) = xs

-- Wraps a list of dimensions, caching the product of the domain sizes.
mkD :: [(String, Int)] -> Dimensions
mkD xs = D (product (map snd xs)) xs

deleteD :: String -> Dimensions -> Dimensions
deleteD s = filterD (/= s)

deletesD :: [String] -> Dimensions -> Dimensions
deletesD xs = filterD (`notElem` xs)

-- Keeps the dimensions whose variable name satisfies the predicate.
filterD :: (String -> Bool) -> Dimensions -> Dimensions
filterD p (D _ m) = mkD (filter (p . fst) m)

-- | Tests whether a variable occurs, stopping early thanks to the sorted
-- order.
hasVarD :: String -> Dimensions -> Bool
hasVarD s (D _ xs) = go xs
 where
   go [] = False
   go ((x, _):rest) =
      case compare s x of
         LT -> False
         EQ -> True
         GT -> go rest

-- Merges two sorted dimension lists; entries present in both appear once.
mergeD :: Dimensions -> Dimensions -> Dimensions
mergeD (D _ m1) (D _ m2) = mkD (go m1 m2)
 where
   go lx@(x:xs) ly@(y:ys) =
      case compare x y of
         LT -> x : go xs ly
         EQ -> x : go xs ys
         GT -> y : go lx ys
   go xs ys = xs ++ ys

-- | Merges a list of dimensions. The empty list yields the dimensions with
-- no variables (size 1) instead of failing.
mergesD :: [Dimensions] -> Dimensions
mergesD = foldr mergeD (mkD [])

-----------------------------------------------------------------------------
-- Value trees

-- | Value trees with one level per variable. 'Two' and 'Three' are
-- specialised nodes for domains of size 2 and 3; 'Bin' is used for other
-- sizes. Leaves hold unboxed doubles.
data Tree = Bin [Tree] | Leaf Double# | Two !Tree !Tree | Three !Tree !Tree !Tree

-- Builds a node, choosing the specialised constructor when possible.
bin :: [Tree] -> Tree
bin [x, y]    = Two x y
bin [x, y, z] = Three x y z
bin xs        = Bin xs

leaf :: Double -> Tree
leaf (D# x) = Leaf x

instance Show Tree where
   show (Leaf a) = show (D# a)
   show t = "(" ++ intercalate "," (map show (subtrees t)) ++ ")"

-- Applies a function to every leaf.
{-# INLINE mapTree #-}
mapTree :: (Double# -> Double#) -> Tree -> Tree
mapTree f = go
 where
   go (Leaf a) = Leaf (f a)
   go t        = mapSubtrees go t

-- Flattens a tree into its leaves (left to right), using difference lists.
treeToList :: Tree -> [Double]
treeToList t0 = go t0 []
 where
   go (Bin ts)      = foldr ((.) . go) id ts
   go (Leaf a)      = (D# a :)
   go (Two x y)     = go x . go y
   go (Three x y z) = go x . go y . go z

-- Combines two trees of the same shape leaf by leaf.
{-# INLINE zipTree #-}
zipTree :: (Double# -> Double# -> Double#) -> Tree -> Tree -> Tree
zipTree f = go
 where
   go (Leaf a) (Leaf b) = Leaf (f a b)
   go t1 t2             = zipSubtrees go t1 t2

{-# INLINE subtrees #-}
subtrees :: Tree -> [Tree]
subtrees (Bin xs)      = xs
subtrees (Two x y)     = [x, y]
subtrees (Three x y z) = [x, y, z]
subtrees _             = error "subtrees"

{-# INLINE subtree #-}
subtree :: Int -> Tree -> Tree
subtree i (Bin xs)       = xs !! i
subtree 0 (Two x _)      = x
subtree 1 (Two _ y)      = y
subtree 0 (Three x _ _)  = x
subtree 1 (Three _ y _)  = y
subtree 2 (Three _ _ z)  = z
subtree _ _              = error "subtree"

{-# INLINE mapSubtrees #-}
mapSubtrees :: (Tree -> Tree) -> Tree -> Tree
mapSubtrees f (Bin xs)      = Bin (map f xs)
mapSubtrees f (Two x y)     = Two (f x) (f y)
mapSubtrees f (Three x y z) = Three (f x) (f y) (f z)
mapSubtrees _ _             = error "mapSubtrees"

{-# INLINE zipSubtrees #-}
zipSubtrees :: (Tree -> Tree -> Tree) -> Tree -> Tree -> Tree
zipSubtrees f (Bin xs) (Bin ys) = Bin (zipWith f xs ys)
zipSubtrees f (Two x1 x2) (Two y1 y2) = Two (f x1 y1) (f x2 y2)
zipSubtrees f (Three x1 x2 x3) (Three y1 y2 y3) =
   Three (f x1 y1) (f x2 y2) (f x3 y3)
zipSubtrees _ _ _ = error "zipSubtrees"

-- Combines two trees pointwise with f over the merged dimensions. A level
-- present only in the first tree is mapped over while the second tree is
-- passed down whole (and vice versa); a shared level is zipped.
{-# INLINE mergeWith #-}
mergeWith :: (Double# -> Double# -> Double#) -> Dimensions -> Dimensions -> Tree -> Tree -> Tree
mergeWith f da db = go (merges (fromDimensions da) (fromDimensions db))
 where
   go (TakeLeft:ms)  t1 t2 = mapSubtrees (\x -> go ms x t2) t1
   go (TakeRight:ms) t1 t2 = mapSubtrees (\x -> go ms t1 x) t2
   go (Merge:ms)     t1 t2 = zipSubtrees (go ms) t1 t2
   go []             t1 t2 = zipTree f t1 t2

-- Plans a merge of two sorted dimension lists, one step per merged level.
merges :: [(String, Int)] -> [(String, Int)] -> [Merge]
merges [] [] = []
merges (_:ds1) [] = TakeLeft : merges ds1 []
merges [] (_:ds2) = TakeRight : merges [] ds2
merges l1@(d1:ds1) l2@(d2:ds2) =
   case compare d1 d2 of
      LT -> TakeLeft : merges ds1 l2
      EQ -> Merge : merges ds1 ds2
      GT -> TakeRight : merges l1 ds2

data Merge = TakeLeft | TakeRight | Merge

set :: String -> Int -> Dimensions -> Tree -> Tree
set s i = choose [(s, i)]

-- Selects the given indices for the listed variables, removing those tree
-- levels. Variables that are not in the dimensions are ignored.
choose :: [(String, Int)] -> Dimensions -> Tree -> Tree
choose env d = go (dropTail choices)
 where
   choices  = [ lookup s env | (s, _) <- fromDimensions d ]
   -- trailing levels without a choice need no traversal
   dropTail = reverse . dropWhile isNothing . reverse
   go [] t           = t
   go (Nothing:cs) t = mapSubtrees (go cs) t
   go (Just i:cs) t  = go cs (subtree i t)

-- Selects index i at the given depth (0 = root level), removing that level.
setAtLevel :: Int -> Int -> Tree -> Tree
setAtLevel lev i = go lev
 where
   go 0 = subtree i
   go l = mapSubtrees (go (l - 1))

-- Rearranges a tree whose levels follow the second dimension list so that
-- its levels follow the first list (a permutation of the second).
reorder :: [(String, Int)] -> [(String, Int)] -> Tree -> Tree
reorder new old t
   | new == old = t
reorder (p:rest) (o:old) t
   | p == o = mapSubtrees (reorder rest old) t
reorder ((s, n):rest) old t =
   case findIndex ((== s) . fst) old of
      Just l ->
         -- pull the level of s to the top, then reorder what remains
         bin [ reorder rest (filter ((/= s) . fst) old) (setAtLevel l i t)
             | i <- [0 .. n - 1]
             ]
      Nothing -> error "invalid reordering"
reorder _ _ _ = error "invalid reordering"