| Safe Haskell | None |
|---|---|
| Language | Haskell2010 |
MCMC.Examples.GMM
Contents
Description
Sampler for Gaussian Mixture Model
Here is the code in the Hakaru language for generating the data used in this example:
p <- unconditioned (beta 2 2)
[m1,m2] <- replicateM 2 $ unconditioned (normal 100 30)
[s1,s2] <- replicateM 2 $ unconditioned (uniform 0 2)
let makePoint = do
      b <- unconditioned (bern p)
      unconditioned (ifThenElse b (normal m1 s1)
                                  (normal m2 s2))
replicateM nPoints makePoint
Documentation
data GaussianMixtureState
Target
Focus combinators
-- | Turn a target on the pair (Bernoulli parameter, labels) into a target on
-- the whole state by evaluating it on those two fields.
focusLabels :: Target (Double, [Bool]) -> Target GaussianMixtureState
focusLabels t = makeTarget dens
  where
    dens (GMM { labels = l, bernParam = p }) = density t (p, l)

-- | Turn a target on the two (mean, scale) pairs into a target on the whole
-- state by evaluating it on the 'gaussParams' field.
focusGaussParams :: Target ((Double, Double), (Double, Double)) -> Target GaussianMixtureState
focusGaussParams t = makeTarget (\s -> density t (gaussParams s))

-- | Turn a target on the Bernoulli parameter into a target on the whole
-- state by evaluating it on the 'bernParam' field.
focusBernParam :: Target Double -> Target GaussianMixtureState
focusBernParam t = makeTarget (\s -> density t (bernParam s))

-- | Turn a target on (labels, Gaussian parameters, observations) into a
-- target on the whole state by evaluating it on those three fields.
focusObs :: Target ([Bool], ((Double, Double), (Double, Double)), [Double]) -> Target GaussianMixtureState
focusObs t = makeTarget dens
  where
    dens (GMM { labels = l, gaussParams = gps, obs = o }) = density t (l, gps, o)
Record field targets
-- | Prior on the labels given the Bernoulli parameter @p@: the product of the
-- @bern p@ densities of every label.
labelsTarget :: Target (Double, [Bool])
labelsTarget = makeTarget labelsDens
  where
    labelsDens (p, ls) = product [density (bern p) l | l <- ls]

-- | Prior on the Gaussian parameters: each mean is scored under
-- @normal 100 900@ and each scale under @uniform 0 200@.
--
-- NOTE(review): the data-generating program in the module description draws
-- the scales from @uniform 0 2@; this prior is far wider, presumably so that
-- the start state (scale 100) has non-zero density. Confirm this is intended.
gaussParamsTarget :: Target ((Double, Double), (Double, Double))
gaussParamsTarget = makeTarget priorDens
  where
    priorDens ((m1, c1), (m2, c2)) =
      meanDens m1 * meanDens m2 * scaleDens c1 * scaleDens c2
    meanDens m = density (normal 100 900) m
    scaleDens c = density (uniform 0 200) c

-- | Prior on the Bernoulli parameter: a @beta 2 2@ distribution.
bernParamTarget :: Target Double
bernParamTarget = fromProposal $ beta 2 2

-- | Likelihood of the observations: each observation is scored under the
-- Gaussian selected by its label (True picks the first pair, False the
-- second). The second argument to 'normal' is the squared scale, which
-- suggests 'normal' is parameterised by variance here — TODO confirm.
obsTarget :: Target ([Bool], ((Double, Double), (Double, Double)), [Double])
obsTarget = makeTarget likelihood
  where
    likelihood (ls, ((m1, c1), (m2, c2)), os) =
      product [density (component l) o | (o, l) <- zip os ls]
      where
        component l
          | l = normal m1 (c1 * c1)
          | otherwise = normal m2 (c2 * c2)
Target density factors
-- | Label prior, viewed as a factor of the full-state target.
labelsFactor :: Target GaussianMixtureState
labelsFactor =
  focusLabels labelsTarget

-- | Gaussian-parameter prior, viewed as a factor of the full-state target.
gaussParamsFactor :: Target GaussianMixtureState
gaussParamsFactor =
  focusGaussParams gaussParamsTarget

-- | Bernoulli-parameter prior, viewed as a factor of the full-state target.
bernParamFactor :: Target GaussianMixtureState
bernParamFactor =
  focusBernParam bernParamTarget

-- | Observation likelihood, viewed as a factor of the full-state target.
obsFactor :: Target GaussianMixtureState
obsFactor =
  focusObs obsTarget
Target density
-- | Full target density of the mixture model, built with 'productDensity'
-- from the four per-component factors.
gmmTarget :: Target GaussianMixtureState
gmmTarget = makeTarget (productDensity factors)
  where
    factors = [labelsFactor, gaussParamsFactor, bernParamFactor, obsFactor]
Proposal
Proposal update boilerplate
-- | Lift a proposal on one field of 'GaussianMixtureState' to a proposal on
-- the whole state.
--
-- @getter@ reads the field and @setter@ writes a new value back. The
-- resulting density only inspects that field of the candidate state. The
-- sampler copies every other field unchanged from the current state @x@.
updateField :: (GaussianMixtureState -> a)
            -> (GaussianMixtureState -> a -> GaussianMixtureState)
            -> (a -> Proposal a)
            -> GaussianMixtureState
            -> Proposal GaussianMixtureState
updateField getter setter f x = makeProposal dens sf
  where
    -- Built once and shared by the density and the sampler.
    fieldProp = f (getter x)
    dens y = density fieldProp (getter y)
    sf g = do
      newValue <- sampleFrom fieldProp g
      return (setter x newValue)

-- | Propose a new state by changing only the 'labels' field.
updateLabels :: ([Bool] -> Proposal [Bool]) -> GaussianMixtureState -> Proposal GaussianMixtureState
updateLabels = updateField labels (\s v -> s { labels = v })

-- | Propose a new state by changing only the 'gaussParams' field.
updateGaussParams :: (((Double, Double), (Double, Double)) -> Proposal ((Double, Double), (Double, Double))) -> GaussianMixtureState -> Proposal GaussianMixtureState
updateGaussParams = updateField gaussParams (\s v -> s { gaussParams = v })

-- | Propose a new state by changing only the 'bernParam' field.
updateBernParam :: (Double -> Proposal Double) -> GaussianMixtureState -> Proposal GaussianMixtureState
updateBernParam = updateField bernParam (\s v -> s { bernParam = v })
Field proposals
-- | Proposal on the label vector.
--
-- It applies 'chooseProposal' over one flip proposal per position, using
-- 'updateNth' to flip a single label. The number of positions is taken from
-- the list itself rather than from a global constant, so the proposal stays
-- consistent if the data set changes size.
-- NOTE(review): whether 'chooseProposal'/'updateNth' index from 0 or 1 is not
-- visible here — confirm against their definitions.
labelsProposal :: [Bool] -> Proposal [Bool]
labelsProposal ls = chooseProposal (length ls) (\n -> updateNth n flipBool ls)
  where
    -- bern 0 / bern 1 are presumably point masses on False / True, so this
    -- deterministically flips the label.
    flipBool bn = if bn then bern 0 else bern 1

-- | Random-walk proposal on the Gaussian parameters.
--
-- 'mixProposals' combines four candidates with weight 1 each. Each candidate
-- perturbs exactly one of m1, c1, m2 or c2 with @normal c 1@ around its
-- current value @c@.
gaussParamsProposal :: ((Double, Double), (Double, Double)) -> Proposal ((Double, Double), (Double, Double))
gaussParamsProposal params = mixProposals $ zip [m1p, c1p, m2p, c2p] (repeat 1)
  where
    condProp c = normal c 1
    m1p = updateFirst (updateFirst condProp) params
    c1p = updateFirst (updateSecond condProp) params
    m2p = updateSecond (updateFirst condProp) params
    c2p = updateSecond (updateSecond condProp) params

-- | Proposal on the Bernoulli parameter: a uniform draw whose lower end is
-- half of @p@ and whose upper end is one minus half of @p@. For @p@ in the
-- unit interval, the lower end never exceeds the upper end.
bernParamProposal :: Double -> Proposal Double
bernParamProposal p = uniform (p / 2) (1 - p / 2)
Field updaters
-- | State-level proposal that flips one label.
labelsUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
labelsUpdater =
  updateLabels labelsProposal

-- | State-level proposal that perturbs one Gaussian parameter.
gaussParamsUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
gaussParamsUpdater =
  updateGaussParams gaussParamsProposal

-- | State-level proposal that resamples the Bernoulli parameter.
bernParamUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
bernParamUpdater =
  updateBernParam bernParamProposal
The combined proposal
-- | Combined proposal: a weighted mixture of the three field updaters.
-- The weights are 10 for labels, 1 for Gaussian parameters and 2 for the
-- Bernoulli parameter (presumably relative weights — confirm against
-- 'mixCondProposals').
gmmProposal :: GaussianMixtureState -> Proposal GaussianMixtureState
gmmProposal = mixCondProposals
  [ (labelsUpdater, 10)
  , (gaussParamsUpdater, 1)
  , (bernParamUpdater, 2)
  ]
Running the sampler
Transition kernel
-- | One Metropolis–Hastings transition for the mixture model, targeting
-- 'gmmTarget' with proposals drawn from 'gmmProposal'.
gmmMH :: Step GaussianMixtureState
gmmMH = gmmTarget `metropolisHastings` gmmProposal
Visualization methods
-- | Count how many times each distinct element occurs in the list.
--
-- 'Map.fromListWith' folds over the input strictly. This avoids the chain of
-- unevaluated 'Map.insertWith' thunks that a lazy 'foldl' would build before
-- producing the map.
histogram :: Ord a => [a] -> Map.Map a Int
histogram ls = Map.fromListWith (+) (zip ls (repeat 1))
-- | Project every state onto its labels, Gaussian parameters and Bernoulli
-- parameter (the observations are omitted).
printFields :: PrintF GaussianMixtureState ([Bool], ((Double, Double), (Double, Double)), Double)
printFields = map (\s -> (labels s, gaussParams s, bernParam s))
-- | Project every state onto its @n@-th label, counting from 1.
-- Indexing uses '!!', so it throws if @n@ is out of range.
printLabelN :: Int -> PrintF GaussianMixtureState Bool
printLabelN n = map ((!! (n - 1)) . labels)
-- | Project every state onto the pair of its @n@-th and @m@-th labels,
-- counting from 1. Indexing uses '!!', so it throws if either index is out
-- of range.
compareLabels :: Int -> Int -> PrintF GaussianMixtureState (Bool,Bool)
compareLabels n m = map pairUp
  where
    pairUp s =
      let ls = labels s
      in (ls !! (n - 1), ls !! (m - 1))
-- | Print the histogram of @f@ applied to the first component of a batch.
-- An empty first component prints nothing.
printHist :: (Ord s, Show s) => PrintF x s -> Batch x -> IO ()
printHist f (ls, _)
  | null ls = return ()
  | otherwise = print (histogram (f ls))
-- | Batch action that runs 'printHist' with @f@ on batches of size @n@
-- (via 'inBatches'), combined with the same printer through 'pack'.
batchHist :: (Ord s, Show s) => PrintF x s -> Int -> BatchAction x IO ()
batchHist f n = pack action (inBatches action n)
  where
    action = printHist f
Main
-- | Number of observed data points.
--
-- Derived from 'sampleData' instead of being hard-coded, so the point count
-- used by the sampler cannot drift out of sync with the data.
nPoints :: Int
nPoints = length sampleData
-- | The observed data. Four values lie near 63 and two lie near 132.
sampleData :: [Double]
sampleData =
  [ 63.13941114139962
  , 132.02763712240528
  , 62.59642260289356
  , 132.2616834236893
  , 64.10610391933461
  , 62.143820541377934
  ]
-- | Initial state of the chain.
--
-- The first three points are labelled True and the last three False. This
-- does not match the clustering of 'sampleData': the second point (near 132)
-- starts with the points near 63, and the fifth and sixth (near 63) start
-- with the fourth (near 132). Presumably this is deliberate, so that the
-- sampler has to move labels.
gmmStart :: GaussianMixtureState
gmmStart = GMM
  { labels      = replicate 3 True ++ replicate 3 False
  , gaussParams = ((63, 100), (132, 100))
  , bernParam   = 0.5
  , obs         = sampleData
  }
-- | Run 'gmmMH' from 'gmmStart' with the action @every 50 collect@.
-- Then print "Done" and the labels of the first 20 states returned by 'walk'.
-- The step count passed to 'walk' is one million (presumably the number of
-- transitions — confirm against 'walk').
gmmTest :: IO ()
gmmTest = do
  g <- MWC.createSystemRandom
  -- An alternative action that prints label histograms in batches is
  -- @every 50 (batchHist (compareLabels 5 6) 50)@.
  -- NOTE(review): its result type may differ from 'collect', so confirm
  -- what 'walk' returns before swapping it in.
  let c = every 50 collect
  ls <- walk gmmMH gmmStart (10 ^ (6 :: Int)) g c
  putStrLn "Done"
  print $ take 20 (map labels ls)