-- | This is a simplified version of the Expectation-Maximization algorithm for
-- a two-component Gaussian mixture model. Cf. Hastie et al, The Elements of
-- Statistical Learning, Springer. Chapter 8.5.1.
module Statistics.EM.TwoGaussian
( emFix
, emStarts
) where
import qualified Data.Vector.Unboxed as VU
import Statistics.Distribution
import Statistics.Distribution.Normal
import Control.Monad.Fix (fix)
import Data.List (sort,maximumBy,tails)
import Data.Ord
-- | Mixing weight. 'emStep' uses it as the share of the second component;
-- the first component gets @1 - w@.
type Weight = Double
-- | Parameters of one Gaussian component: the mean, followed by a spread
-- parameter that 'emStep' re-estimates as a weighted mean of squared
-- deviations from the new mean.
type Normal = (Double,Double)
-- | Perform one EM step for the mixture
--
-- > (1-w) * N(mu1, s1)  +  w * N(mu2, s2)
--
-- The second slot of each 'Normal' is treated as a /variance/: that is what
-- the maximization step below produces (a weighted mean of squared
-- deviations). 'normalDistr' expects a standard deviation, so the square root
-- is taken when the component densities are built. Feeding the variance in
-- directly would evaluate the wrong densities from the second step onwards.
--
-- Returns the updated @(weight, component 1, component 2)@ triple.
emStep :: VU.Vector Double -> (Weight,Normal,Normal) -> (Weight,Normal,Normal)
emStep xs (w,(mu1,s1),(mu2,s2)) = (w',(mu1',s1'),(mu2',s2')) where
  -- this is the expectation step
  ys = VU.map responsibility xs
  n1 = normalDistr mu1 (sqrt s1)
  n2 = normalDistr mu2 (sqrt s2)
  -- share of the mixture density at y that is contributed by component 2
  responsibility y =
    let n1y = density n1 y
        n2y = density n2 y
    in  w * n2y / ((1-w) * n1y + w * n2y)
  -- this is the maximization step
  -- new weight: the average responsibility of component 2
  w' = div2 / fromIntegral (VU.length ys)
  -- new Gaussian 1, every observation weighted by (1 - responsibility)
  div1 = VU.sum . VU.map (\y -> 1-y) $ ys
  mu1' = (VU.sum $ VU.zipWith (\x y -> (1-y) * x) xs ys) / div1
  s1' = (VU.sum $ VU.zipWith (\x y -> (1-y) * (x-mu1')^2) xs ys) / div1
  -- new Gaussian 2, every observation weighted by its responsibility
  div2 = VU.sum ys
  mu2' = (VU.sum $ VU.zipWith (\x y -> y * x) xs ys) / div2
  s2' = (VU.sum $ VU.zipWith (\x y -> y * (x-mu2')^2) xs ys) / div2
-- | The lazy, never-ending sequence of parameter estimates obtained by
-- applying 'emStep' over and over. The first element is the starting value
-- itself; consumers decide how far along the sequence to look.
emIter :: VU.Vector Double -> (Weight,Normal,Normal) -> [(Weight,Normal,Normal)]
emIter xs = iterate step where
  step = emStep xs
-- | Iterates the EM step from @theta@ until the parameters stop moving and
-- returns the converged estimate.
--
-- Two consecutive iterates count as converged when /every/ parameter (the
-- weight, both means and both spreads) changed by at most @epsilon@. The
-- newer of the two iterates is returned, so a starting value that already is
-- a fix-point yields a result instead of an empty-list error.
emFix :: VU.Vector Double -> (Weight,Normal,Normal) -> (Weight,Normal,Normal)
emFix xs theta = go (emIter xs theta) where
  go (prev : rest@(next : _))
    | stillMoving prev next = go rest
    | otherwise             = next
  -- 'emIter' is produced by 'iterate' and never ends; this equation only
  -- keeps the pattern match total.
  go _ = theta
  -- True while at least one parameter still changes by more than epsilon.
  -- Comparisons against NaN are False, so an all-NaN estimate ends the
  -- iteration rather than looping forever.
  stillMoving (w1,(mu11,s11),(mu12,s12)) (w2,(mu21,s21),(mu22,s22)) =
    or [ w1   `differs` w2
       , mu11 `differs` mu21
       , mu12 `differs` mu22
       , s11  `differs` s21
       , s12  `differs` s22
       ]
  differs :: Double -> Double -> Bool
  a `differs` b = abs (a-b) > epsilon
  epsilon :: Double
  epsilon = 1e-10
-- | Runs 'emFix' from a family of starting values and keeps the fix-point that
-- scores highest under @loglikelihood@. The candidate means @mu1@ and @mu2@
-- are both drawn from the elements of @xs@; the original note states that
-- mu_1 < mu_2 holds.
--
-- NOTE(review): the definition line below looks truncated (the text between
-- @mu1@ and @(x-sampleMu)^2@ appears to be missing), so the generator guard
-- and the local helpers cannot be confirmed from here -- TODO restore from the
-- original source.
emStarts :: VU.Vector Double -> (Weight,Normal,Normal)
emStarts xs = maximumBy (comparing loglikelihood) . map (emFix xs) $ [f xs mu1 mu2 | mu1 <- VU.toList xs, mu2 <- VU.toList xs, mu1 (x-sampleMu)^2) $ xs) / (fromIntegral $ VU.length xs)