-- | Optimized, hand-written sampler for a two-component Gaussian Mixture Model.
--
-- The data used in this example was generated by the following program in
-- the Hakaru language:
--
-- > p <- unconditioned (beta 2 2)
-- > [m1,m2] <- replicateM 2 $ unconditioned (normal 100 30)
-- > [s1,s2] <- replicateM 2 $ unconditioned (uniform 0 2)
-- > let makePoint = do
-- >                 b <- unconditioned (bern p)
-- >                 unconditioned (ifThenElse b (normal m1 s1)
-- >                                            (normal m2 s2))
-- > replicateM nPoints makePoint
module MCMC.Examples.HandwrittenGMM
  ( GaussianMixtureState(..)
    -- * Focused targets
    -- $tar

    -- * Focused proposals
    -- $prop

    -- * Focused steps
    -- *** Each step evaluates only the factors of the density ratio that its proposal changes; all other factors cancel out
    -- $steps

    -- * Optimized sampler
    -- *** A mixture of the focused (optimized) steps
    -- $sampler

    -- * Main
    -- $main
  ) where

import MCMC.Types
import MCMC.Kernels
import MCMC.Distributions
import MCMC.Actions
import MCMC.Combinators
import qualified System.Random.MWC as MWC

-- | Complete state of the sampler.
data GaussianMixtureState = GMM
  { labels      :: [Bool]
    -- ^ One flag per observation: 'True' assigns it to the first Gaussian
    -- of 'gaussParams', 'False' to the second.
  , gaussParams :: ((Double, Double), (Double, Double))
    -- ^ @((m1, c1), (m2, c2))@: the two arguments passed to 'normal' for
    -- each Gaussian (mean first).
  , bernParam   :: Double
    -- ^ Weight that the label prior gives to a 'True' label.
  }

-- | Number of observations. Labels and observations are indexed with
-- @i - 1@ for @i@ in @[1 .. nPoints]@.
nPoints :: Int
nPoints = 6

-- | Metropolis-Hastings step that updates a single label, chosen by
-- 'chooseStep' among the 'nPoints' positions.
stepLabels :: [Double] -> Step GaussianMixtureState
stepLabels obs =
  chooseStep nPoints labelTarget labelsProposal metropolisHastings
  where
    -- Label i only enters its own prior factor and the likelihood of
    -- observation i; every other factor cancels in the acceptance ratio.
    labelTarget i = makeTarget (focusedDensity i)
    focusedDensity i st =
      density (targetLabel i) st * density (targetObs i obs) st

-- stepGaussParams could be optimized further if we knew which labels belong
-- to the Gaussian whose parameter was updated.
-- | Metropolis-Hastings step over the parameters of both Gaussians.
--
-- Only the parameter prior and the likelihood of every observation depend on
-- 'gaussParams', so the label and Bernoulli-weight factors are left out.
stepGaussParams :: [Double] -> Step GaussianMixtureState
stepGaussParams obs = metropolisHastings (makeTarget dens) gaussParamsProposal
  where dens state = density targetGaussParams state *
                     product [density (targetObs i obs) state | i <- [1..nPoints]]

-- | Metropolis-Hastings step over the Bernoulli weight 'bernParam'.
--
-- Only the weight prior and the prior of every label depend on 'bernParam'.
stepBernParam :: Step GaussianMixtureState
stepBernParam = metropolisHastings (makeTarget dens) bernParamProposal
  where dens state = density targetBernParam state *
                     product [density (targetLabel i) state | i <- [1..nPoints]]

-- | The full sampler: the three focused steps combined by 'mixSteps' with
-- weights 1, 1 and 1.
gmmSampler :: [Double] -> Step GaussianMixtureState
gmmSampler obs = mixSteps $
  zip [stepLabels obs, stepGaussParams obs, stepBernParam] [1,1,1]

-- Main

-- | Observations the sampler conditions on. 'densityObs' indexes this list
-- with @i - 1@ for @i@ in @[1 .. nPoints]@.
sampleData :: [Double]
sampleData = [ 63.13941114139962, 132.02763712240528
             , 62.59642260289356, 132.2616834236893
             , 64.10610391933461, 62.143820541377934 ]

-- | Initial sampler state; an alternative labelling is kept commented out.
startState :: GaussianMixtureState
startState = GMM {
  -- labels = [True, True, True, False, False, False],
  labels = [False, False, False, True, True, True],
  -- NOTE(review): the second component of each pair (100) lies outside the
  -- uniform 0 2 prior used by densityGaussParams. If that prior has zero
  -- density there, stepGaussParams starts from a zero-density state;
  -- confirm this starting point is intended.
  gaussParams = ((63, 100), (132, 100)),
  bernParam = 0.5 }

-- | Runs 'gmmSampler' on 'sampleData' from 'startState' for @10^2@
-- iterations with the action @every 1 (display labels)@.
test :: IO ()
test = do
  g <- MWC.createSystemRandom
  let p = every 1 (display labels)
  -- Alternative: a longer run that collects every 50th state.
  -- ls <- walk (gmmSampler sampleData) startState (10^6) g (every 50 collect)
  -- print $ take 20 (map labels ls)
  walk (gmmSampler sampleData) startState (10^2) g p

-- Targets

-- Labels

-- | Prior factor for label @i@ (1-based).
targetLabel :: Int -> Target GaussianMixtureState
targetLabel i = makeTarget (densityLabel i)

-- | 'bernParam' when label @i@ (1-based) is 'True', @1-p@ otherwise.
densityLabel :: Int -> GaussianMixtureState -> Double
densityLabel i (GMM l _ p) = if l !! (i-1) then p else 1-p

-- Gauss params

-- | Prior factor for the parameters of both Gaussians.
targetGaussParams :: Target GaussianMixtureState
targetGaussParams = makeTarget densityGaussParams

-- | Independent priors: @normal 100 900@ on each mean and @uniform 0 2@ on
-- each second parameter.
densityGaussParams :: GaussianMixtureState -> Double
densityGaussParams state = mdens m1 * mdens m2 * cdens c1 * cdens c2
  where ((m1, c1), (m2, c2)) = gaussParams state
        -- NOTE(review): 900 matches the normal 100 30 of the Hakaru model
        -- in the module header only if normal takes a variance; confirm.
        mdens m = density (normal 100 900) m
        cdens c = density (uniform 0 2) c

-- Bern param

-- | Prior factor for the Bernoulli weight.
targetBernParam :: Target GaussianMixtureState
targetBernParam = makeTarget densityBernParam

-- | @beta 2 2@ prior on 'bernParam'.
densityBernParam :: GaussianMixtureState -> Double
densityBernParam state = density (beta 2 2) (bernParam state)

-- Obs / data points

-- | Likelihood factor for observation @i@ (1-based) of @obs@.
targetObs :: Int -> [Double] -> Target GaussianMixtureState
targetObs i obs = makeTarget (densityObs i obs)

-- | Density of observation @i@ under the Gaussian selected by label @i@:
-- 'True' uses the first pair of 'gaussParams', 'False' the second.
densityObs :: Int -> [Double] -> GaussianMixtureState -> Double
densityObs i obs state = if labels state !! (i-1)
                           then density (normal m1 c1) oi
                           else density (normal m2 c2) oi
  where oi = obs !! (i-1)
        ((m1, c1), (m2, c2)) = gaussParams state

-- Proposals

-- Labels

-- | Proposal that applies 'updateLabel' @i@ to the labels and keeps every
-- other field of the state.
labelsProposal :: Int -> GaussianMixtureState -> Proposal GaussianMixtureState
labelsProposal i x = makeProposal dens sf
  where dens y = density (updateLabel i $ labels x) (labels y)
        sf g = do newLabels <- sampleFrom (updateLabel i $ labels x) g
                  return x { labels = newLabels }

-- | Proposal on the label list that redraws only label @i@: a 'True' label
-- from @bern 0@ and a 'False' label from @bern 1@. This is a deterministic
-- flip, assuming @bern q@ yields 'True' with probability @q@.
updateLabel :: Int -> [Bool] -> Proposal [Bool]
updateLabel i ls = updateNth i flipBool ls
  where flipBool bn = if bn then bern 0 else bern 1

-- Gauss params

-- | Proposal that applies 'updateGaussParams' and keeps every other field of
-- the state.
gaussParamsProposal :: GaussianMixtureState -> Proposal GaussianMixtureState
gaussParamsProposal x = makeProposal dens sf
  where dens y = density (updateGaussParams $ gaussParams x) (gaussParams y)
        sf g = do newParams <- sampleFrom (updateGaussParams $ gaussParams x) g
                  return x { gaussParams = newParams }

-- | Equal-weight mixture of four proposals. Each one moves a single
-- parameter, drawing its new value from @normal c 1@ around the current
-- value @c@.
updateGaussParams :: ((Double, Double), (Double, Double)) -> Proposal ((Double, Double), (Double, Double))
updateGaussParams params = mixProposals $ zip [m1p, c1p, m2p, c2p] (repeat 1)
  where condProp c = normal c 1
        m1p = updateFirst (updateFirst condProp) params
        c1p = updateFirst (updateSecond condProp) params
        m2p = updateSecond (updateFirst condProp) params
        c2p = updateSecond (updateSecond condProp) params

-- Bern param

-- | Proposal that redraws 'bernParam' from 'updateBernParam' and keeps every
-- other field of the state.
bernParamProposal :: GaussianMixtureState -> Proposal GaussianMixtureState
bernParamProposal x = makeProposal dens sf
  where dens y = density (updateBernParam $ bernParam x) (bernParam y)
        sf g = do newParam <- sampleFrom (updateBernParam $ bernParam x) g
                  return x { bernParam = newParam }

-- | Draws the new weight uniformly between @p\/2@ and @1-p\/2@.
--
-- The interval is not centred on @p@, and it excludes @p@ once @p@ exceeds
-- two thirds, so the proposal is asymmetric. 'bernParamProposal' hands its
-- density to 'makeProposal', presumably so that the Metropolis-Hastings
-- ratio can correct for this.
updateBernParam :: Double -> Proposal Double
updateBernParam p = uniform (p/2) (1-p/2)

-----------------
-- Documentation
-----------------

-- $tar
-- @
-- targetLabel :: Int -> Target GaussianMixtureState
-- targetLabel i = 'makeTarget' (densityLabel i)
--
-- densityLabel :: Int -> GaussianMixtureState -> Double
-- densityLabel i (GMM l _ p) = if l !! (i-1) then p else 1-p
--
--
-- targetGaussParams :: Target GaussianMixtureState
-- targetGaussParams = 'makeTarget' densityGaussParams
--
-- densityGaussParams :: GaussianMixtureState -> Double
-- densityGaussParams state = mdens m1 * mdens m2 * cdens c1 * cdens c2
--   where ((m1, c1), (m2, c2)) = gaussParams state
--         mdens m = 'density' ('normal' 100 900) m
--         cdens c = 'density' ('uniform' 0 2) c
--
--
-- targetBernParam :: Target GaussianMixtureState
-- targetBernParam = 'makeTarget' densityBernParam
--
-- densityBernParam :: GaussianMixtureState -> Double
-- densityBernParam state = 'density' ('beta' 2 2) (bernParam state)
--
--
-- targetObs :: Int -> [Double] -> Target GaussianMixtureState
-- targetObs i obs = 'makeTarget' (densityObs i obs)
--
-- densityObs :: Int -> [Double] -> GaussianMixtureState -> Double
-- densityObs i obs state = if labels state !! (i-1)
--                            then 'density' ('normal' m1 c1) oi
--                            else 'density' ('normal' m2 c2) oi
--   where oi = obs !! (i-1)
--         ((m1, c1), (m2, c2)) = gaussParams state
-- @

-- $prop
-- @
-- labelsProposal :: Int -> GaussianMixtureState -> Proposal GaussianMixtureState
-- labelsProposal i x = 'makeProposal' dens sf
--   where dens y = 'density' (updateLabel i $ labels x) (labels y)
--         sf g = do newLabels <- 'sampleFrom' (updateLabel i $ labels x) g
--                   return x { labels = newLabels }
--
-- updateLabel :: Int -> [Bool] -> Proposal [Bool]
-- updateLabel i ls = 'updateNth' i flipBool ls
--   where flipBool bn = if bn then 'bern' 0 else 'bern' 1
--
--
-- gaussParamsProposal :: GaussianMixtureState -> Proposal GaussianMixtureState
-- gaussParamsProposal x = 'makeProposal' dens sf
--   where dens y = 'density' (updateGaussParams $ gaussParams x) (gaussParams y)
--         sf g = do newParams <- 'sampleFrom' (updateGaussParams $ gaussParams x) g
--                   return x { gaussParams = newParams }
--
-- updateGaussParams :: ((Double, Double), (Double, Double)) -> Proposal ((Double, Double), (Double, Double))
-- updateGaussParams params = 'mixProposals' $ zip [m1p, c1p, m2p, c2p] (repeat 1)
--   where condProp c = 'normal' c 1
--         m1p = 'updateFirst' ('updateFirst' condProp) params
--         c1p = 'updateFirst' ('updateSecond' condProp) params
--         m2p = 'updateSecond' ('updateFirst' condProp) params
--         c2p = 'updateSecond' ('updateSecond' condProp) params
--
--
-- bernParamProposal :: GaussianMixtureState -> Proposal GaussianMixtureState
-- bernParamProposal x = 'makeProposal' dens sf
--   where dens y = 'density' (updateBernParam $ bernParam x) (bernParam y)
--         sf g = do newParam <- 'sampleFrom' (updateBernParam $ bernParam x) g
--                   return x { bernParam = newParam }
--
-- updateBernParam :: Double -> Proposal Double
-- updateBernParam p = 'uniform' (p\/2) (1-p\/2)
-- @

-- $steps
-- @
-- stepLabels :: [Double] -> Step GaussianMixtureState
-- stepLabels obs = 'chooseStep' nPoints
--                  (\\i -> 'makeTarget' $ dens i) labelsProposal 'metropolisHastings'
--   where dens i state = 'density' (targetLabel i) state *
--                        'density' (targetObs i obs) state
--
-- -- This could be optimized further if we know the label corresponding
-- -- to the gaussian to which the updated param belongs.
-- stepGaussParams :: [Double] -> Step GaussianMixtureState
-- stepGaussParams obs = 'metropolisHastings' ('makeTarget' dens) gaussParamsProposal
--   where dens state = 'density' targetGaussParams state *
--                      product ['density' (targetObs i obs) state | i <- [1..nPoints]]
--
-- stepBernParam :: Step GaussianMixtureState
-- stepBernParam = 'metropolisHastings' ('makeTarget' dens) bernParamProposal
--   where dens state = 'density' targetBernParam state *
--                      product ['density' (targetLabel i) state | i <- [1..nPoints]]
-- @

-- $sampler
-- @
-- gmmSampler :: [Double] -> Step GaussianMixtureState
-- gmmSampler obs = 'mixSteps' $
--   zip [stepLabels obs, stepGaussParams obs, stepBernParam] [1,1,1]
-- @

-- $main
-- @
-- nPoints :: Int
-- nPoints = 6
--
-- sampleData :: [Double]
-- sampleData = [ 63.13941114139962, 132.02763712240528
--              , 62.59642260289356, 132.2616834236893
--              , 64.10610391933461, 62.143820541377934 ]
--
-- startState :: GaussianMixtureState
-- startState = GMM {
--   -- labels = [True, True, True, False, False, False],
--   labels = [False, False, False, True, True, True],
--   gaussParams = ((63, 100), (132, 100)),
--   bernParam = 0.5 }
--
-- test :: IO ()
-- test = do
--   g <- MWC.createSystemRandom
--   let p = 'every' 1 ('display' labels)
--   -- ls <- 'walk' (gmmSampler sampleData) startState (10^6) g ('every' 50 'collect')
--   -- print $ take 20 (map labels ls)
--   'walk' (gmmSampler sampleData) startState (10^2) g p
-- @