-- | Gibbs sampling with a Naive Bayes model
-- Details and notation based on:
-- http://www.cs.umd.edu/~hardisty/papers/gsfu.pdf

import Control.Applicative
import Control.Monad
import Control.Monad.Primitive
import System.Random.MWC hiding (initialize)
import Statistics.Distribution
import Statistics.Distribution.Beta
import Statistics.Distribution.Gamma
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector as V
import qualified Data.Map as M

type HyperParams = U.Vector Int
type Theta       = U.Vector Double   -- Distribution over words in a document
type Doc         = [String]
type Label       = Bool
type WordCounts  = M.Map String Int

data Point = Point { observe :: Doc, label :: Label } deriving (Show)

-- A point "augmented" with word counts information
data AugPoint = PnC { point :: Point, counts :: WordCounts } deriving (Show)

type Corpus = V.Vector AugPoint

data Info = Info { dataSet :: Corpus
                 , thetas  :: M.Map Label Theta
                 , wcInfo  :: M.Map Label WordCounts
                 , numDocs :: M.Map Label Int
                 }

type Sample = V.Vector Label

-- Beta(1,1), i.e., uniform
beta :: Gen RealWorld -> IO Double
beta = genContVar (betaDistr 1 1)

-- Dirichlet as a normalized vector of samples from Gamma
dirichlet :: HyperParams -> Gen RealWorld -> IO Theta
dirichlet hps gen = do
  let gamma_draw hp = genContVar (gammaDistr (fromIntegral hp) 1) gen
  ys <- U.mapM gamma_draw hps
  let y_sum = U.sum ys
  return $ U.map (/ y_sum) ys

-- True with probability p (the original (> p) returned True with probability 1-p)
bernoulli :: Double -> Gen RealWorld -> IO Bool
bernoulli p gen = uniform gen >>= return . (< p)

tokens :: V.Vector String
tokens = V.fromList . words $ "I went to the square and saw a foolish \
                              \politician who would not stop blabbering"

capV :: Int
capV = V.length tokens

capN :: Int
capN = 20   -- Number of documents

-- Docsize: each bag contains between 1 and kmax words
kmax :: Int
kmax = 36

-- Generate a "bag of words"
gen_bag :: Gen RealWorld -> IO Doc
gen_bag g =
  let bagger b 0 = return b
      bagger b n = do
        i <- uniformR (0, capV - 1) g
        bagger ((tokens V.! i) : b) $ n - 1
  in uniformR (1, kmax) g >>= bagger []

-- Per-document word counts; every token starts at a count of zero
doc_counts :: Point -> WordCounts
doc_counts (Point doc _) = foldr (M.adjust (+1)) zeroes doc
  where ts     = V.toList tokens
        zeroes = M.fromList $ zip ts (repeat 0)

-- Assume an ordering of [a], in this case [true-value, false-value]
label_map :: [a] -> M.Map Label a
label_map vl = M.fromList $ zip [True, False] vl

initialize :: Gen RealWorld -> IO Info
initialize g = do
  p <- beta g
  let gen_point = Point <$> gen_bag g <*> bernoulli p g
  points <- V.fromList <$> replicateM capN gen_point
  let corpus            = V.zipWith PnC points $ V.map doc_counts points
      (trues, falses)   = V.unstablePartition (label . point) corpus
      collect_counts ps = V.foldl1 (M.unionWith (+)) $ V.map counts ps
      wcMap             = label_map $ map collect_counts [trues, falses]
      nums              = label_map [V.length trues, V.length falses]
  thets <- label_map <$> (replicateM 2 $ dirichlet (U.replicate capV 1) g)
  return $ Info corpus thets wcMap nums

cond_prob :: Int -> Theta -> WordCounts -> Double
cond_prob c_x theta_x wc_j =
  let prod i t = (*) $ (^) t (wc_j M.! (tokens V.! i))
      p        = U.ifoldr prod 1.0 theta_x
      c        = (/) (fromIntegral c_x) (fromIntegral $ capN + 1)
  in c * p

sample_label :: Int -> Info -> Gen RealWorld -> IO Label
sample_label j (Info dat thetas _ nums) gen = do
  let wc_j   = counts $ dat V.! j
      pTrue  = cond_prob (nums M.! True)  (thetas M.! True)  wc_j
      pFalse = cond_prob (nums M.! False) (thetas M.! False) wc_j
      pNorm  = (/) pTrue $ pTrue + pFalse
  bernoulli pNorm gen
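-- The product in cond_prob can underflow to zero for long documents, which
-- would make pNorm in sample_label NaN. Below is a minimal log-space sketch of
-- the same computation; cond_prob_log and normalize_log are illustrative
-- helpers added here (not used by the sampler), assuming the same tokens and
-- capN globals as above.
cond_prob_log :: Int -> Theta -> WordCounts -> Double
cond_prob_log c_x theta_x wc_j =
  let term i t = (+) $ fromIntegral (wc_j M.! (tokens V.! i)) * log t
      logP     = U.ifoldr term 0.0 theta_x
      logC     = log (fromIntegral c_x) - log (fromIntegral $ capN + 1)
  in logC + logP

-- Normalize in log space: pTrue / (pTrue + pFalse) == 1 / (1 + exp (logF - logT))
normalize_log :: Double -> Double -> Double
normalize_log logT logF = 1 / (1 + exp (logF - logT))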
assign_label :: Int -> Label -> Info -> Info
assign_label j lab (Info d t w n) =
  let ap     = d V.! j
      p      = point ap
      new_ap = PnC (Point (observe p) lab) (counts ap)
      new_d  = (V.//) d [(j, new_ap)]
  in Info new_d t w n

type WCUpdate = WordCounts -> WordCounts

update_wc :: WCUpdate -> Label -> Info -> Info
update_wc fun lab (Info d t wc n) = Info d t (M.adjust fun lab wc) n

type NumUpdate = Int -> Int

update_num :: NumUpdate -> Label -> Info -> Info
update_num fun lab (Info d t w nums) = Info d t w (M.adjust fun lab nums)

sampler :: Gen RealWorld -> Info -> IO Info
sampler gen info = do
  let ind_ds = V.indexed $ dataSet info
      f acc (j, ap) = do
        let lab = (label . point) ap
            -- Remove document j's word counts from the totals for its current label
            sub_fun = M.differenceWith (\doc_c tot_c -> Just (tot_c - doc_c)) (counts ap)
            pre_sample_info = update_num (flip (-) 1) lab $ update_wc sub_fun lab acc
        new_lab <- sample_label j pre_sample_info gen
        let post_sample_info = assign_label j new_lab pre_sample_info
        -- Add the counts back under the (possibly new) label
        return $ update_wc (M.unionWith (+) (counts ap)) new_lab $
                 update_num (+1) new_lab post_sample_info
  -- V.foldM is a monadic left fold, so documents are resampled in index order
  V.foldM f info ind_ds

new_thetas :: Gen RealWorld -> Info -> IO Info
new_thetas gen (Info d _ w n) = do
  -- New Dirichlet hyperparameters: per-label word counts plus one pseudo-count
  let f wc i      = (+1) $ wc M.! (tokens V.! i)
      hyperparams = M.map (U.generate capV) $ M.map f w
  thetaT <- dirichlet (hyperparams M.! True)  gen
  thetaF <- dirichlet (hyperparams M.! False) gen
  return $ Info d (label_map [thetaT, thetaF]) w n

capT :: Int
capT = 1000

gibbs :: Gen RealWorld -> IO Sample
gibbs g = do
  let loop 0 info = return info
      loop t info = sampler g info >>= new_thetas g >>= loop (t - 1)
  info <- initialize g >>= loop capT
  return $ V.map (label . point) (dataSet info)

main :: IO ()
main = testGibbs

-- Tests

testPRNG :: IO ()
testPRNG = do
  gen <- createSystemRandom
  beta gen >>= print
  beta gen >>= print

testTheta :: IO ()
testTheta = createSystemRandom >>= dirichlet (U.replicate capV 1) >>= print

testBag :: IO ()
testBag = createSystemRandom >>= gen_bag >>= print

testInit :: IO ()
testInit = createSystemRandom >>= initialize >>= print . V.map counts . dataSet

testCondProb :: IO ()
testCondProb = do
  info <- createSystemRandom >>= initialize
  let pTrue = cond_prob (numDocs info M.! True) (thetas info M.! True) $
                counts $ dataSet info V.! 0
  putStrLn $ "P(L_0=True | Initial) = " ++ show pTrue

testGibbs :: IO ()
testGibbs = createSystemRandom >>= gibbs >>= print
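-- A small usage sketch (testLabelCounts is a hypothetical extra test, not part
-- of the suite above): run the chain once and report how many documents end up
-- with each label in the final sample.
testLabelCounts :: IO ()
testLabelCounts = do
  sample <- createSystemRandom >>= gibbs
  let trues = V.length $ V.filter id sample
  putStrLn $ "True labels:  " ++ show trues
  putStrLn $ "False labels: " ++ show (V.length sample - trues)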