Safe Haskell | None |
---|---|
Language | Haskell2010 |
Sampler for Gaussian Mixture Model
Here is the code in the Hakaru language for generating the data used in this example:
```haskell
p <- unconditioned (beta 2 2)
[m1,m2] <- replicateM 2 $ unconditioned (normal 100 30)
[s1,s2] <- replicateM 2 $ unconditioned (uniform 0 2)
let makePoint = do
      b <- unconditioned (bern p)
      unconditioned (ifThenElse b (normal m1 s1) (normal m2 s2))
replicateM nPoints makePoint
```
Documentation
data GaussianMixtureState
Target
Focus combinators
-- | Lift a target over @(bernParam, labels)@ to a target over the whole
-- sampler state; every other field of the state is ignored.
focusLabels :: Target (Double, [Bool]) -> Target GaussianMixtureState
focusLabels t = makeTarget labelsDensity
  where
    labelsDensity (GMM { labels = ls, bernParam = w }) = density t (w, ls)

-- | Lift a target over the two Gaussian parameter pairs to a target over
-- the whole sampler state.
focusGaussParams
  :: Target ((Double, Double), (Double, Double))
  -> Target GaussianMixtureState
focusGaussParams t = makeTarget (\st -> density t (gaussParams st))

-- | Lift a target over the Bernoulli weight to a target over the whole
-- sampler state.
focusBernParam :: Target Double -> Target GaussianMixtureState
focusBernParam t = makeTarget (\st -> density t (bernParam st))

-- | Lift a target over @(labels, gaussParams, obs)@ to a target over the
-- whole sampler state; the Bernoulli weight is ignored.
focusObs
  :: Target ([Bool], ((Double, Double), (Double, Double)), [Double])
  -> Target GaussianMixtureState
focusObs t = makeTarget obsDensity
  where
    obsDensity (GMM { labels = ls, gaussParams = gps, obs = os }) =
      density t (ls, gps, os)
Record field targets
-- | Density of a label vector given the Bernoulli weight @p@: the product,
-- over all labels, of the density of @bern p@ at that label.
labelsTarget :: Target (Double, [Bool])
labelsTarget = makeTarget $ \(p, ls) -> product $ map (density $ bern p) ls

-- | Density over the two @(mean, spread)@ pairs: independent @normal 100 900@
-- factors on both means and @uniform 0 200@ factors on both spreads.
--
-- NOTE(review): presumably 'normal' takes a variance here ('obsTarget'
-- squares the spread before passing it), so 900 = 30^2 would match the
-- header's @normal 100 30@. The header draws the spreads from
-- @uniform 0 2@, whereas this prior uses @uniform 0 200@; confirm this is
-- intended.
gaussParamsTarget :: Target ((Double, Double), (Double, Double))
gaussParamsTarget = makeTarget dens
  where
    dens ((m1, c1), (m2, c2)) = mdens m1 * mdens m2 * cdens c1 * cdens c2
    mdens m = density (normal 100 900) m
    cdens c = density (uniform 0 200) c

-- | Target over the Bernoulli weight, taken directly from @beta 2 2@.
bernParamTarget :: Target Double
bernParamTarget = fromProposal (beta 2 2)

-- | Likelihood of the observations given labels and component parameters.
-- A point labelled 'True' is scored under @normal m1 (c1*c1)@ and one
-- labelled 'False' under @normal m2 (c2*c2)@; the result is the product
-- over points.
--
-- Observations and labels are paired positionally; if the two lists differ
-- in length, the surplus elements are ignored (as 'zipWith' truncates).
obsTarget :: Target ([Bool], ((Double, Double), (Double, Double)), [Double])
obsTarget = makeTarget dens
  where
    dens (ls, ((m1, c1), (m2, c2)), os) =
      product $ zipWith (\o l -> density (gauss l) o) os ls
      where
        -- The spread is squared before being handed to 'normal'.
        gauss l = if l then normal m1 (c1 * c1) else normal m2 (c2 * c2)
Target density factors
-- | Label factor: 'labelsTarget' applied to the state's weight and labels.
labelsFactor :: Target GaussianMixtureState
labelsFactor = focusLabels labelsTarget

-- | Component-parameter factor: 'gaussParamsTarget' applied to the state's
-- Gaussian parameters.
gaussParamsFactor :: Target GaussianMixtureState
gaussParamsFactor = focusGaussParams gaussParamsTarget

-- | Weight factor: 'bernParamTarget' applied to the state's Bernoulli weight.
bernParamFactor :: Target GaussianMixtureState
bernParamFactor = focusBernParam bernParamTarget

-- | Observation factor: 'obsTarget' applied to the state's labels,
-- Gaussian parameters and observations.
obsFactor :: Target GaussianMixtureState
obsFactor = focusObs obsTarget
Target density
-- | Target over the full sampler state, built by 'productDensity' from the
-- four per-field factors (presumably their product, per the name).
gmmTarget :: Target GaussianMixtureState
gmmTarget = makeTarget (productDensity factors)
  where
    factors = [labelsFactor, gaussParamsFactor, bernParamFactor, obsFactor]
Proposal
Proposal update boilerplate
-- | Lift a proposal on one field of the state to a proposal on the whole
-- state, leaving every other field as it is in @x@.
--
-- The field proposal is built once from @x@'s field value. The density of a
-- candidate state @y@ is that proposal's density at @y@'s field. Sampling
-- draws a new field value and writes it back into @x@ with the setter.
updateStateField
  :: (GaussianMixtureState -> a)                          -- ^ field getter
  -> (GaussianMixtureState -> a -> GaussianMixtureState)  -- ^ field setter
  -> (a -> Proposal a)                                    -- ^ field proposal
  -> GaussianMixtureState                                 -- ^ current state
  -> Proposal GaussianMixtureState
updateStateField getField setField f x = makeProposal dens sf
  where
    fieldProposal = f (getField x)
    dens y = density fieldProposal (getField y)
    sf g = do
      newValue <- sampleFrom fieldProposal g
      return (setField x newValue)

-- | Propose a new state by changing only its labels.
updateLabels
  :: ([Bool] -> Proposal [Bool])
  -> GaussianMixtureState
  -> Proposal GaussianMixtureState
updateLabels = updateStateField labels (\s v -> s { labels = v })

-- | Propose a new state by changing only its Gaussian parameters.
updateGaussParams
  :: (((Double, Double), (Double, Double)) -> Proposal ((Double, Double), (Double, Double)))
  -> GaussianMixtureState
  -> Proposal GaussianMixtureState
updateGaussParams = updateStateField gaussParams (\s v -> s { gaussParams = v })

-- | Propose a new state by changing only its Bernoulli weight.
updateBernParam
  :: (Double -> Proposal Double)
  -> GaussianMixtureState
  -> Proposal GaussianMixtureState
updateBernParam = updateStateField bernParam (\s v -> s { bernParam = v })
Field proposals
-- Field proposals.
--
-- labelsProposal: uses chooseProposal over nPoints indices and, for the
-- chosen index n, applies flipBool to that label via updateNth. flipBool
-- maps True to `bern 0` and False to `bern 1`, i.e. presumably a
-- deterministic negation of the chosen label.
-- NOTE(review): this reads the global nPoints rather than `length ls`; the
-- two must agree (gmmStart has 6 labels). Whether chooseProposal's indices
-- are 0- or 1-based relative to updateNth cannot be checked from here.
--
-- gaussParamsProposal: an equal-weight (1 each) mixProposals over four
-- proposals. Each draws `normal c 1` around the current value of one of
-- m1, c1, m2 or c2, selected through updateFirst/updateSecond. Spreads can
-- be proposed below 0; presumably gaussParamsTarget's `uniform 0 200` factor
-- rejects those.
--
-- bernParamProposal: a `uniform` proposal whose bounds depend on p.
-- NOTE(review): the bounds read `(p2) (1-p2)` in this copy, but `p2` is not
-- bound anywhere; an operator (e.g. `*` or `/`) appears to have been lost
-- in extraction. Confirm the intended expression against the source.
labelsProposal :: [Bool] -> Proposal [Bool] labelsProposal ls =chooseProposal
nPoints (n ->updateNth
n flipBool ls) where flipBool bn = if bn thenbern
0 elsebern
1 gaussParamsProposal :: ((Double, Double), (Double, Double)) -> Proposal ((Double, Double), (Double, Double)) gaussParamsProposal params =mixProposals
$ zip [m1p, c1p, m2p, c2p] (repeat 1) where condProp c =normal
c 1 m1p =updateFirst
(updateFirst
condProp) params c1p =updateFirst
(updateSecond
condProp) params m2p =updateSecond
(updateFirst
condProp) params c2p =updateSecond
(updateSecond
condProp) params bernParamProposal :: Double -> Proposal Double bernParamProposal p =uniform
(p2) (1-p2)
Field updaters
-- | Whole-state proposal that resamples only the labels with 'labelsProposal'.
labelsUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
labelsUpdater st = updateLabels labelsProposal st

-- | Whole-state proposal that resamples only the Gaussian parameters with
-- 'gaussParamsProposal'.
gaussParamsUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
gaussParamsUpdater st = updateGaussParams gaussParamsProposal st

-- | Whole-state proposal that resamples only the Bernoulli weight with
-- 'bernParamProposal'.
bernParamUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
bernParamUpdater st = updateBernParam bernParamProposal st
The combined proposal
-- | Combined proposal: the three field updaters mixed by 'mixCondProposals'
-- with weights 10 (labels), 1 (Gaussian parameters) and 2 (Bernoulli weight).
gmmProposal :: GaussianMixtureState -> Proposal GaussianMixtureState
gmmProposal = mixCondProposals weighted
  where
    weighted =
      [ (labelsUpdater, 10)
      , (gaussParamsUpdater, 1)
      , (bernParamUpdater, 2)
      ]
Running the sampler
Transition kernel
-- | Metropolis-Hastings transition kernel that targets 'gmmTarget' and
-- draws candidates from 'gmmProposal'.
gmmMH :: Step GaussianMixtureState
gmmMH = metropolisHastings gmmTarget gmmProposal
Visualization methods
-- | Count how many times each distinct element occurs in a list.
--
-- 'Map.fromListWith' builds the map in one pass, avoiding the lazy
-- accumulator that a hand-written 'foldl' over 'Map.insertWith' creates.
histogram :: Ord a => [a] -> Map.Map a Int
histogram ls = Map.fromListWith (+) [(e, 1) | e <- ls]

-- | Project each state onto its labels, Gaussian parameters and Bernoulli
-- weight.
printFields
  :: PrintF GaussianMixtureState ([Bool], ((Double, Double), (Double, Double)), Double)
printFields = map (\s -> (labels s, gaussParams s, bernParam s))

-- | Project each state onto its @n@-th label. @n@ is 1-based; like '!!',
-- this errors if @n@ is out of range.
printLabelN :: Int -> PrintF GaussianMixtureState Bool
printLabelN n = map (\s -> labels s !! (n - 1))

-- | Project each state onto the pair of its @n@-th and @m@-th labels
-- (1-based; errors if either index is out of range).
compareLabels :: Int -> Int -> PrintF GaussianMixtureState (Bool, Bool)
compareLabels n m = map (\s -> (labels s !! (n - 1), labels s !! (m - 1)))

-- | Print a histogram of the projected batch contents; does nothing for an
-- empty batch.
printHist :: (Ord s, Show s) => PrintF x s -> Batch x -> IO ()
printHist f (ls, _) = unless (null ls) $ print (histogram (f ls))

-- | Batch action that prints a histogram of each batch of @n@ states, and
-- is also passed 'printHist' as the first argument of 'pack'.
batchHist :: (Ord s, Show s) => PrintF x s -> Int -> BatchAction x IO ()
batchHist f n = pack (printHist f) (inBatches (printHist f) n)
Main
-- | Number of data points. It is derived from 'sampleData' so the two
-- cannot drift apart.
nPoints :: Int
nPoints = length sampleData

-- | Observed data points.
sampleData :: [Double]
sampleData =
  [ 63.13941114139962, 132.02763712240528
  , 62.59642260289356, 132.2616834236893
  , 64.10610391933461, 62.143820541377934
  ]

-- | Initial sampler state: the first three points labelled 'True' and the
-- last three 'False', Gaussian parameters @((63, 100), (132, 100))@,
-- Bernoulli weight 0.5, and 'sampleData' as the observations. The label
-- list must have one entry per data point.
gmmStart :: GaussianMixtureState
gmmStart = GMM
  { labels = [True, True, True, False, False, False]
  , gaussParams = ((63, 100), (132, 100))
  , bernParam = 0.5
  , obs = sampleData
  }

-- | Run 'walk' for 10^6 iterations of 'gmmMH' from 'gmmStart', using a fresh
-- system RNG and the action @every 50 collect@. Then print "Done" and the
-- labels of the first 20 states returned.
gmmTest :: IO ()
gmmTest = do
  g <- MWC.createSystemRandom
  -- To print label-comparison histograms instead, something like
  -- @every 50 (batchHist (compareLabels 5 6) 50)@ can presumably be passed
  -- to 'walk' in place of @c@ (confirm the action types line up).
  let c = every 50 collect
  ls <- walk gmmMH gmmStart (10 ^ (6 :: Int)) g c
  putStrLn "Done"
  print $ take 20 (map labels ls)