-- | This is a simplified version of the Expectation-Maximization algorithm for
-- a two-component Gaussian mixture model. Cf. Hastie et al., The Elements of
-- Statistical Learning, Springer, Chapter 8.5.1.

module Statistics.EM.TwoGaussian
  ( emFix
  , emStarts
  ) where

import qualified Data.Vector.Unboxed as VU
import Statistics.Distribution
import Statistics.Distribution.Normal
import Control.Monad.Fix (fix)
import Data.List (sort,maximumBy,tails)
import Data.Ord

-- | Mixing weight of the second component (the share of the data that
-- component 2 is responsible for).
type Weight = Double

-- | A Gaussian component as @(mean, variance)@. The M-step in 'emStep'
-- re-estimates the second field as a weighted variance.
type Normal = (Double,Double)

-- | Perform one EM step.
--
-- The E-step computes, for every observation, the responsibility of the
-- second component. The M-step re-estimates the mixing weight and the mean
-- and variance of both components from those responsibilities.
emStep :: VU.Vector Double -> (Weight,Normal,Normal) -> (Weight,Normal,Normal)
emStep xs (w,(mu1,v1),(mu2,v2)) = (w',(mu1',v1'),(mu2',v2'))
  where
    -- Expectation step.
    -- 'normalDistr' takes a standard deviation (statistics >= 0.10; before
    -- 0.10 it took a variance), while 'Normal' stores a variance, hence the
    -- 'sqrt'.
    n1 = normalDistr mu1 (sqrt v1)
    n2 = normalDistr mu2 (sqrt v2)
    -- NOTE(review): if both densities underflow to 0 for an outlier, this is
    -- 0/0 = NaN.
    responsibility y =
      let n1y = density n1 y
          n2y = density n2 y
      in  w * n2y / ((1-w) * n1y + w * n2y)
    ys = VU.map responsibility xs
    -- Maximization step.
    -- NOTE(review): an empty 'xs' makes every new parameter 0/0 = NaN.
    w' = VU.sum ys / fromIntegral (VU.length ys)
    -- New Gaussian 1, each observation weighted by (1 - responsibility).
    div1 = VU.sum . VU.map (\y -> 1-y) $ ys
    mu1' = VU.sum (VU.zipWith (\x y -> (1-y) * x) xs ys) / div1
    v1' = VU.sum (VU.zipWith (\x y -> (1-y) * (x-mu1')^2) xs ys) / div1
    -- New Gaussian 2, each observation weighted by its responsibility.
    div2 = VU.sum ys
    mu2' = VU.sum (VU.zipWith (\x y -> y * x) xs ys) / div2
    v2' = VU.sum (VU.zipWith (\x y -> y * (x-mu2')^2) xs ys) / div2

-- | Performs an infinite number of EM steps, iterating towards converging
-- parameters.
emIter :: VU.Vector Double -> (Weight,Normal,Normal) -> [(Weight,Normal,Normal)]
emIter xs theta = iterate (emStep xs) theta

-- | Finds the fix-point of the EM step iterations.
--
-- Iterates 'emStep' from @theta@ until no parameter changes by more than
-- @1e-10@ (absolute) between two consecutive steps, and returns the settled
-- iterate. A parameter that becomes NaN counts as unchanged. Does not
-- terminate if the parameters never settle.
emFix :: VU.Vector Double -> (Weight,Normal,Normal) -> (Weight,Normal,Normal)
emFix xs theta =
  case dropWhile (uncurry hasMoved) (zip zs (drop 1 zs)) of
    (_, settled) : _ -> settled
    []               -> error "Statistics.EM.TwoGaussian.emFix: impossible, emIter is infinite"
  where
    zs = emIter xs theta
    -- Keep iterating while ANY parameter still moves. Stopping once a single
    -- parameter settles (e.g. a weight pinned by symmetric data) would return
    -- an estimate far from the fix-point.
    hasMoved (w1,(mu11,s11),(mu12,s12)) (w2,(mu21,s21),(mu22,s22)) =
      or [w1 =/= w2, mu11 =/= mu21, mu12 =/= mu22, s11 =/= s21, s12 =/= s22]
    a =/= b = abs (a-b) > epsilon
    epsilon = 1e-10

-- | Finds the best fix-point with all elements 'xs' as starting points for the
-- means. It holds that mu_1 < mu_2.
--
-- NOTE(review): part of this definition was lost when this file was
-- extracted. The text between @mu1@ and @(x-sampleMu)^2)@ is missing; it looks
-- like an HTML-stripped span running from a @<@ to a @>@. The surviving text
-- is kept verbatim below so that nothing is guessed. Restore the missing part
-- from version control before building; it includes the definitions of @f@,
-- @loglikelihood@ and @sampleMu@. When restoring, check that any
-- 'normalDistr' call there passes a standard deviation (the 'sqrt' of the
-- stored variance), as 'emStep' does.
emStarts :: VU.Vector Double -> (Weight,Normal,Normal)
emStarts xs = maximumBy (comparing loglikelihood) . map (emFix xs) $ [f xs mu1 mu2 | mu1 <- VU.toList xs, mu2 <- VU.toList xs, mu1 (x-sampleMu)^2) $ xs) / (fromIntegral $ VU.length xs)