-- | Sampler for Gaussian Mixture Model
--
-- Here is the code in the Hakaru language for generating
-- the data used in this example:
--
-- > p <- unconditioned (beta 2 2)
-- > [m1,m2] <- replicateM 2 $ unconditioned (normal 100 30)
-- > [s1,s2] <- replicateM 2 $ unconditioned (uniform 0 2)
-- > let makePoint = do
-- >       b <- unconditioned (bern p)
-- >       unconditioned (ifThenElse b (normal m1 s1)
-- >                                   (normal m2 s2))
-- > replicateM nPoints makePoint

module MCMC.Examples.GMM (
    GaussianMixtureState (..)
  -- * Target
  -- ** Focus combinators
  -- $focuscombs
  --
  , focusLabels, focusGaussParams, focusBernParam, focusObs
  -- ** Record field targets
  -- $fieldtargets
  --
  , labelsTarget, gaussParamsTarget, bernParamTarget, obsTarget
  -- ** Target density factors
  -- $targetfactors
  --
  , labelsFactor, gaussParamsFactor, bernParamFactor, obsFactor
  -- ** Target density
  -- $tdens
  --
  , gmmTarget
  -- * Proposal
  -- ** Proposal update boilerplate
  -- $proposalfocus
  --
  , updateLabels, updateGaussParams, updateBernParam
  -- ** Field proposals
  -- $fieldproposals
  --
  , labelsProposal, gaussParamsProposal, bernParamProposal
  -- ** Field updaters
  -- $fieldupdaters
  --
  , labelsUpdater, gaussParamsUpdater, bernParamUpdater
  -- ** The combined proposal
  -- $gmmprop
  --
  , gmmProposal
  -- * Running the sampler
  -- ** Transition kernel
  -- $kernel
  --
  , gmmMH
  -- ** Visualization methods
  -- $visual
  --
  , histogram, printFields, printLabelN, compareLabels, printHist, batchHist
  -- ** Main
  -- $main
  --
  , nPoints, sampleData, gmmStart, gmmTest
  ) where

import MCMC.Combinators
import MCMC.Distributions
import MCMC.Kernels
import MCMC.Actions
import MCMC.Types

import qualified System.Random.MWC as MWC
import qualified Data.Map.Strict as Map
import Control.Monad

-- | The full sampler state: latent labels, component parameters,
-- mixture proportion and the observed data.
data GaussianMixtureState = GMM {
    labels :: [Bool]
    -- ^ The list of observation labels
  , gaussParams :: ((Double, Double), (Double, Double))
    -- ^ The parameters of the two Gaussians (mean, covariance)
  , bernParam :: Double
    -- ^ The mixture proportion
  , obs :: [Double]
    -- ^ The observed data
  }

-- Target
----------

-- Focus combinators

-- | Lift a target on @(mixture proportion, labels)@ to a target on the
-- whole state by reading the 'bernParam' and 'labels' fields.
focusLabels :: Target (Double, [Bool]) -> Target GaussianMixtureState
focusLabels t = makeTarget dens
  where dens (GMM l _ p _) = density t (p,l)

-- | Lift a target on the Gaussian parameters to a target on the whole
-- state by reading the 'gaussParams' field.
focusGaussParams :: Target ((Double, Double), (Double, Double)) -> Target GaussianMixtureState
focusGaussParams t = makeTarget (density t . gaussParams)

-- | Lift a target on the mixture proportion to a target on the whole
-- state by reading the 'bernParam' field.
focusBernParam :: Target Double -> Target GaussianMixtureState
focusBernParam t = makeTarget (density t . bernParam)

-- | Lift a target on @(labels, Gaussian parameters, observations)@ to a
-- target on the whole state; the mixture proportion is ignored.
focusObs :: Target ([Bool], ((Double, Double), (Double, Double)), [Double])
         -> Target GaussianMixtureState
focusObs t = makeTarget dens
  where dens (GMM l gps _ o) = density t (l, gps, o)

-- Record field targets

-- | Density of the labels given the mixture proportion: the product,
-- over all labels, of the density of @bern p@ at that label.
labelsTarget :: Target (Double, [Bool])
labelsTarget = makeTarget $ \(p,ls) -> product $ map (density $ bern p) ls

-- | Density of the Gaussian parameters: the first component of each pair
-- is scored under @normal 100 900@ and the second under @uniform 0 200@.
gaussParamsTarget :: Target ((Double, Double), (Double, Double))
gaussParamsTarget = makeTarget dens
  where dens ((m1, c1), (m2, c2)) = mdens m1 * mdens m2 * cdens c1 * cdens c2
        mdens m = density (normal 100 900) m
        cdens c = density (uniform 0 200) c

-- | Density of the mixture proportion, obtained from @beta 2 2@.
bernParamTarget :: Target Double
bernParamTarget = fromProposal (beta 2 2)

-- | Density of the observations given labels and Gaussian parameters.
-- Each observation is paired with its label by 'zip' (so surplus elements
-- of the longer list are dropped); a 'True' label selects the first
-- Gaussian and a 'False' label the second.  The second parameter of each
-- pair is squared before being handed to @normal@.
obsTarget :: Target ([Bool], ((Double, Double), (Double, Double)), [Double])
obsTarget = makeTarget dens
  where
    dens (ls, ((m1, c1), (m2, c2)), os) =
      let ols = zip os ls
          gauss l = if l then normal m1 (c1*c1) else normal m2 (c2*c2)
      in product $ map (\(o,l) -> density (gauss l) o) ols

-- Target density factors

-- | 'labelsTarget' focused on the whole state.
labelsFactor :: Target GaussianMixtureState
labelsFactor = focusLabels labelsTarget

-- | 'gaussParamsTarget' focused on the whole state.
gaussParamsFactor :: Target GaussianMixtureState
gaussParamsFactor = focusGaussParams gaussParamsTarget

-- | 'bernParamTarget' focused on the whole state.
bernParamFactor :: Target GaussianMixtureState
bernParamFactor = focusBernParam bernParamTarget

-- | 'obsTarget' focused on the whole state.
obsFactor :: Target GaussianMixtureState
obsFactor = focusObs obsTarget

-- Target density

-- | The full target density, built from the four factors above with
-- @productDensity@.
gmmTarget :: Target GaussianMixtureState
gmmTarget = makeTarget $ productDensity
            [labelsFactor, gaussParamsFactor, bernParamFactor, obsFactor]

-- | Number of observations in 'sampleData'.
nPoints :: Int
nPoints = 6

-- | The six observed data points used by this example.
sampleData :: [Double]
sampleData = [ 63.13941114139962, 132.02763712240528
             , 62.59642260289356, 132.2616834236893
             , 64.10610391933461, 62.143820541377934 ]

-- | Print the value of 'gmmTarget' for four fixed label assignments,
-- one per line, in the form @labels : density@.
gmmTargetDensityTest :: IO ()
gmmTargetDensityTest =
    forM_ [labels1, labels2, labels3, labels4] $ \ls -> do
      putStr $ show ls ++ " : "
      print $ density gmmTarget $ makeState ls
  where
    sampleParams = ((63, 100), (132, 100))
    b = 0.5
    makeState sampleLabels = GMM sampleLabels sampleParams b sampleData
    labels1 = replicate nPoints False
    labels2 = map not labels1
    labels3 = [True, False, True, False, True, True]
    labels4 = [True, True, True, False, False, False]

-- Proposal
----------

-- Field update combinators

-- | Turn a conditional proposal on the labels into a conditional proposal
-- on the whole state.  Only the 'labels' field of the sampled state
-- differs from @x@, and the density compares only the 'labels' fields.
updateLabels :: ([Bool] -> Proposal [Bool]) -> GaussianMixtureState -> Proposal GaussianMixtureState
updateLabels f x = makeProposal dens sf
  where dens y = density (f $ labels x) (labels y)
        sf g = do newLabels <- sampleFrom (f $ labels x) g
                  return x { labels = newLabels }

-- | Turn a conditional proposal on the Gaussian parameters into a
-- conditional proposal on the whole state.  Only the 'gaussParams' field
-- of the sampled state differs from @x@.
updateGaussParams :: (((Double, Double), (Double, Double)) -> Proposal ((Double, Double), (Double, Double)))
                  -> GaussianMixtureState -> Proposal GaussianMixtureState
updateGaussParams f x = makeProposal dens sf
  where dens y = density (f $ gaussParams x) (gaussParams y)
        sf g = do newParams <- sampleFrom (f $ gaussParams x) g
                  return x { gaussParams = newParams }

-- | Turn a conditional proposal on the mixture proportion into a
-- conditional proposal on the whole state.  Only the 'bernParam' field of
-- the sampled state differs from @x@.
updateBernParam :: (Double -> Proposal Double) -> GaussianMixtureState -> Proposal GaussianMixtureState
updateBernParam f x = makeProposal dens sf
  where dens y = density (f $ bernParam x) (bernParam y)
        sf g = do newParam <- sampleFrom (f $ bernParam x) g
                  return x { bernParam = newParam }

-- Field proposals

-- | Proposal for the labels, built with @chooseProposal@ and @updateNth@.
-- The per-element proposal is @bern 0@ for a 'True' label and @bern 1@
-- for a 'False' one.
--
-- NOTE(review): presumably this picks one position and flips that label;
-- confirm against @chooseProposal@ and @updateNth@.  The number of
-- positions is the constant 'nPoints', not the length of @ls@.
labelsProposal :: [Bool] -> Proposal [Bool]
labelsProposal ls = chooseProposal nPoints (\n -> updateNth n flipBool ls)
  where flipBool bn = if bn then bern 0 else bern 1

-- | Proposal for the Gaussian parameters: a mixture (all weights 1) of
-- four proposals, each of which updates exactly one of the four numbers
-- using @normal c 1@, where @c@ is the current value of that number.
gaussParamsProposal :: ((Double, Double), (Double, Double)) -> Proposal ((Double, Double), (Double, Double))
gaussParamsProposal params = mixProposals $ zip [m1p, c1p, m2p, c2p] (repeat 1)
  where condProp c = normal c 1
        m1p = updateFirst (updateFirst condProp) params
        c1p = updateFirst (updateSecond condProp) params
        m2p = updateSecond (updateFirst condProp) params
        c2p = updateSecond (updateSecond condProp) params

-- | Proposal for the mixture proportion: @uniform@ with bounds equal to
-- half the current value and one minus half the current value.
bernParamProposal :: Double -> Proposal Double
bernParamProposal p = uniform (p/2) (1-p/2)

-- Field updaters

-- | 'labelsProposal' lifted to the whole state.
labelsUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
labelsUpdater = updateLabels labelsProposal

-- | 'gaussParamsProposal' lifted to the whole state.
gaussParamsUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
gaussParamsUpdater = updateGaussParams gaussParamsProposal

-- | 'bernParamProposal' lifted to the whole state.
bernParamUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
bernParamUpdater = updateBernParam bernParamProposal

-- GMM Proposal

-- | The combined proposal: the three field updaters passed to
-- @mixCondProposals@ with weights 10 (labels), 1 (Gaussian parameters)
-- and 2 (mixture proportion).
gmmProposal :: GaussianMixtureState -> Proposal GaussianMixtureState
gmmProposal = mixCondProposals $ zip [labelsUpdater, gaussParamsUpdater, bernParamUpdater] [10,1,2]

-- Histogram and other visualizations

-- | Count how many times each distinct element occurs in the list.
histogram :: Ord a => [a] -> Map.Map a Int
histogram ls = Map.fromListWith (+) [(e, 1) | e <- ls]

-- | Project every state to its labels, Gaussian parameters and mixture
-- proportion (the observations are dropped).
printFields :: PrintF GaussianMixtureState ([Bool], ((Double, Double), (Double, Double)), Double)
printFields = let f s = (labels s, gaussParams s, bernParam s) in map f

-- | Project every state to its @n@-th label, counting from 1.  Uses
-- '!!', so an index outside the label list is a runtime error.
printLabelN :: Int -> PrintF GaussianMixtureState Bool
printLabelN n = let f s = labels s !! (n-1) in map f

-- | Project every state to the pair of its @n@-th and @m@-th labels,
-- counting from 1.  Uses '!!', so an index outside the label list is a
-- runtime error.
compareLabels :: Int -> Int -> PrintF GaussianMixtureState (Bool,Bool)
compareLabels n m = let f s = (labels s !! (n-1) , labels s !! (m-1)) in map f

-- | Print the histogram of the projected samples of a batch.  Nothing is
-- printed when the batch holds no samples; the second component of the
-- batch is ignored.
printHist :: (Ord s, Show s) => PrintF x s -> Batch x -> IO ()
printHist f (ls,_) = unless (null ls) $ print . histogram $ f ls

-- | A batch action built from 'printHist' with @pack@ and @inBatches@.
--
-- NOTE(review): presumably prints one histogram per batch of @n@ samples
-- plus one for the final partial batch; confirm against @inBatches@ and
-- @pack@.
batchHist :: (Ord s, Show s) => PrintF x s -> Int -> BatchAction x IO ()
batchHist f n = pack (printHist f) $ inBatches (printHist f) n

-- Kernel
----------

-- | The transition kernel: @metropolisHastings@ applied to 'gmmTarget'
-- and 'gmmProposal'.
gmmMH :: Step GaussianMixtureState
gmmMH = metropolisHastings gmmTarget gmmProposal

-- | Initial state for the sampler, observing 'sampleData'.
gmmStart :: GaussianMixtureState
gmmStart = GMM { labels = [True, True, True, False, False, False],
                 gaussParams = ((63, 100), (132, 100)),
                 bernParam = 0.5,
                 obs = sampleData }

-- | Run the sampler from 'gmmStart' with a system-seeded generator,
-- passing @10^6@ and the action @c@ to @walk@, then print @Done@ and the
-- labels of the first 20 collected states.
--
-- The bindings @a@ and @e@ are not used by the call to @walk@.
-- NOTE(review): they look like an alternative histogram-printing action
-- kept for swapping in place of @c@; confirm before removing.
gmmTest :: IO ()
gmmTest = do
  g <- MWC.createSystemRandom
  let a = batchHist (compareLabels 5 6) 50
      e = every 50 a
      c = every 50 collect
  ls <- walk gmmMH gmmStart (10^6) g c
  putStrLn "Done"
  print $ take 20 (map labels ls)

-- Older, simpler way of writing target density
----------
--
-- NOTE(review): these definitions do not agree with 'gmmTarget'.
-- 'probObs' passes the second Gaussian parameter to @normal@ unsquared
-- (whereas 'obsTarget' squares it), and 'probGaussParams' uses
-- @uniform 0 2@ (whereas 'gaussParamsTarget' uses @uniform 0 200@).
-- Confirm which is intended before using 'gmmTargetOld'.

-- | Probability of the labels given the mixture proportion: a factor of
-- @p@ for every 'True' label and @1-p@ for every 'False' one.
probLabels :: GaussianMixtureState -> Double
probLabels (GMM l _ p _) = product $ map (\b -> if b then p else 1-p) l

-- | Density of the observations given the labels and Gaussian parameters.
probObs :: GaussianMixtureState -> Double
probObs state = product $ map (\(o,l) -> density (gauss l) o) ols
  where ols = zip (obs state) (labels state)
        ((m1, c1), (m2, c2)) = gaussParams state
        gauss l = if l then normal m1 c1 else normal m2 c2

-- | Density of the Gaussian parameters.
probGaussParams :: GaussianMixtureState -> Double
probGaussParams state = mdens m1 * mdens m2 * cdens c1 * cdens c2
  where ((m1, c1), (m2, c2)) = gaussParams state
        mdens m = density (normal 100 900) m
        cdens c = density (uniform 0 2) c

-- | Density of the mixture proportion under @beta 2 2@.
probBernParam :: GaussianMixtureState -> Double
probBernParam state = density (beta 2 2) (bernParam state)

-- | The older formulation of the target: the product of the four
-- functions above.
gmmTargetOld :: Target GaussianMixtureState
gmmTargetOld = makeTarget dens
  where dens s = probLabels s * probObs s * probGaussParams s * probBernParam s

-----------------
-- Documentation
-----------------

-- $focuscombs
-- @
-- focusLabels :: Target (Double, [Bool]) -> Target GaussianMixtureState
-- focusLabels t = 'makeTarget' dens
--   where dens (GMM l _ p _) = 'density' t (p,l)
--
-- focusGaussParams :: Target ((Double, Double), (Double, Double)) -> Target GaussianMixtureState
-- focusGaussParams t = 'makeTarget' ('density' t . gaussParams)
--
-- focusBernParam :: Target Double -> Target GaussianMixtureState
-- focusBernParam t = 'makeTarget' ('density' t . bernParam)
--
-- focusObs :: Target ([Bool], ((Double, Double), (Double, Double)), [Double])
--          -> Target GaussianMixtureState
-- focusObs t = 'makeTarget' dens
--   where dens (GMM l gps _ o) = 'density' t (l, gps, o)
-- @

-- $fieldtargets
-- @
-- labelsTarget :: Target (Double, [Bool])
-- labelsTarget = 'makeTarget' $ \\(p,ls) -> product $ map ('density' $ 'bern' p) ls
--
-- gaussParamsTarget :: Target ((Double, Double), (Double, Double))
-- gaussParamsTarget = 'makeTarget' dens
--   where dens ((m1, c1), (m2, c2)) = mdens m1 * mdens m2 * cdens c1 * cdens c2
--         mdens m = 'density' ('normal' 100 900) m
--         cdens c = 'density' ('uniform' 0 200) c
--
-- bernParamTarget :: Target Double
-- bernParamTarget = 'fromProposal' ('beta' 2 2)
--
-- obsTarget :: Target ([Bool], ((Double, Double), (Double, Double)), [Double])
-- obsTarget = 'makeTarget' dens
--   where dens (ls, ((m1, c1), (m2, c2)), os)
--           = let ols = zip os ls
--                 gauss l = if l then 'normal' m1 (c1*c1) else 'normal' m2 (c2*c2)
--             in product $ map (\\(o,l) -> 'density' (gauss l) o) ols
-- @

-- $targetfactors
-- @
-- labelsFactor :: Target GaussianMixtureState
-- labelsFactor = focusLabels labelsTarget
--
-- gaussParamsFactor :: Target GaussianMixtureState
-- gaussParamsFactor = focusGaussParams gaussParamsTarget
--
-- bernParamFactor :: Target GaussianMixtureState
-- bernParamFactor = focusBernParam bernParamTarget
--
-- obsFactor :: Target GaussianMixtureState
-- obsFactor = focusObs obsTarget
-- @

-- $tdens
-- @
-- gmmTarget :: Target GaussianMixtureState
-- gmmTarget = 'makeTarget' $ 'productDensity'
--             [labelsFactor, gaussParamsFactor, bernParamFactor, obsFactor]
-- @

-- $proposalfocus
-- @
-- updateLabels :: ([Bool] -> Proposal [Bool]) -> GaussianMixtureState -> Proposal GaussianMixtureState
-- updateLabels f x = 'makeProposal' dens sf
--   where dens y = 'density' (f $ labels x) (labels y)
--         sf g = do newLabels <- 'sampleFrom' (f $ labels x) g
--                   return x { labels = newLabels }
--
-- updateGaussParams :: (((Double, Double), (Double, Double)) -> Proposal ((Double, Double), (Double, Double)))
--                   -> GaussianMixtureState -> Proposal GaussianMixtureState
-- updateGaussParams f x = 'makeProposal' dens sf
--   where dens y = 'density' (f $ gaussParams x) (gaussParams y)
--         sf g = do newParams <- 'sampleFrom' (f $ gaussParams x) g
--                   return x { gaussParams = newParams }
--
-- updateBernParam :: (Double -> Proposal Double) -> GaussianMixtureState -> Proposal GaussianMixtureState
-- updateBernParam f x = 'makeProposal' dens sf
--   where dens y = 'density' (f $ bernParam x) (bernParam y)
--         sf g = do newParam <- 'sampleFrom' (f $ bernParam x) g
--                   return x { bernParam = newParam }
-- @

-- $fieldproposals
-- @
-- labelsProposal :: [Bool] -> Proposal [Bool]
-- labelsProposal ls = 'chooseProposal' nPoints (\\n -> 'updateNth' n flipBool ls)
--   where flipBool bn = if bn then 'bern' 0 else 'bern' 1
--
-- gaussParamsProposal :: ((Double, Double), (Double, Double)) -> Proposal ((Double, Double), (Double, Double))
-- gaussParamsProposal params = 'mixProposals' $ zip [m1p, c1p, m2p, c2p] (repeat 1)
--   where condProp c = 'normal' c 1
--         m1p = 'updateFirst' ('updateFirst' condProp) params
--         c1p = 'updateFirst' ('updateSecond' condProp) params
--         m2p = 'updateSecond' ('updateFirst' condProp) params
--         c2p = 'updateSecond' ('updateSecond' condProp) params
--
-- bernParamProposal :: Double -> Proposal Double
-- bernParamProposal p = 'uniform' (p\/2) (1-p\/2)
-- @

-- $fieldupdaters
-- @
-- labelsUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
-- labelsUpdater = updateLabels labelsProposal
--
-- gaussParamsUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
-- gaussParamsUpdater = updateGaussParams gaussParamsProposal
--
-- bernParamUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
-- bernParamUpdater = updateBernParam bernParamProposal
-- @

-- $gmmprop
-- @
-- gmmProposal :: GaussianMixtureState -> Proposal GaussianMixtureState
-- gmmProposal = 'mixCondProposals' $ zip [labelsUpdater, gaussParamsUpdater, bernParamUpdater] [10,1,2]
-- @

-- $kernel
-- @
-- gmmMH :: Step GaussianMixtureState
-- gmmMH = 'metropolisHastings' gmmTarget gmmProposal
-- @

-- $visual
-- @
-- histogram :: Ord a => [a] -> Map.Map a Int
-- histogram ls = Map.fromListWith (+) [(e, 1) | e <- ls]
--
-- printFields :: PrintF GaussianMixtureState ([Bool], ((Double, Double), (Double, Double)), Double)
-- printFields = let f s = (labels s, gaussParams s, bernParam s) in map f
--
-- printLabelN :: Int -> PrintF GaussianMixtureState Bool
-- printLabelN n = let f s = labels s !! (n-1) in map f
--
-- compareLabels :: Int -> Int -> PrintF GaussianMixtureState (Bool,Bool)
-- compareLabels n m = let f s = (labels s !! (n-1) , labels s !! (m-1)) in map f
--
-- printHist :: (Ord s, Show s) => PrintF x s -> Batch x -> IO ()
-- printHist f (ls,_) = unless (null ls) $ print . histogram $ f ls
--
-- batchHist :: (Ord s, Show s) => PrintF x s -> Int -> BatchAction x IO ()
-- batchHist f n = 'pack' (printHist f) $ 'inBatches' (printHist f) n
-- @

-- $main
-- @
-- nPoints :: Int
-- nPoints = 6
--
-- sampleData :: [Double]
-- sampleData = [ 63.13941114139962, 132.02763712240528
--              , 62.59642260289356, 132.2616834236893
--              , 64.10610391933461, 62.143820541377934 ]
--
-- gmmStart :: GaussianMixtureState
-- gmmStart = GMM { labels = [True, True, True, False, False, False],
--                  gaussParams = ((63, 100), (132, 100)),
--                  bernParam = 0.5,
--                  obs = sampleData }
--
-- gmmTest :: IO ()
-- gmmTest = do
--   g <- MWC.createSystemRandom
--   let a = batchHist (compareLabels 5 6) 50
--       e = 'every' 50 a
--       c = 'every' 50 'collect'
--   ls <- 'walk' gmmMH gmmStart (10^6) g c
--   putStrLn \"Done\"
--   print $ take 20 (map labels ls)
-- @