-- | This is a simplified version of the Expectation-Maximization algorithm for
-- a two-component Gaussian mixture model. Cf. Hastie et al, The Elements of
-- Statistical Learning, Springer. Chapter 8.5.1.

module Statistics.EM.TwoGaussian
  ( emFix
  , emStarts
  ) where

import qualified Data.Vector.Unboxed as VU
import Statistics.Distribution
import Statistics.Distribution.Normal
import Control.Monad.Fix (fix)
import Data.List (sort,maximumBy,tails)
import Data.Ord

-- | Mixing weight of the second Gaussian component; the first component
-- carries weight @1 - w@ (see the responsibility formula in 'emStep').
type Weight = Double
-- | Parameters of one Gaussian component as @(location, spread)@. The
-- location is the mean handed to 'normalDistr'.
-- NOTE(review): the spread is seeded with the sample variance in 'emStarts'
-- and updated in 'emStep' as a weighted mean of squared deviations, yet it is
-- handed to 'normalDistr' as its second argument -- confirm whether variance
-- or standard deviation is intended.
type Normal = (Double,Double)

-- | Perform one EM step for the mixture
-- @(1-w) * N(mu1,s1) + w * N(mu2,s2)@ on the observations 'xs'.
--
-- The second component of each 'Normal' is a /variance/: it is re-estimated
-- below as a responsibility-weighted mean of squared deviations. Since
-- 'normalDistr' expects a standard deviation, the square root is taken before
-- the distributions are built.
--
-- 'normalDistr' calls 'error' for a non-positive (or NaN) standard deviation,
-- so a collapsed component makes this step fail when its result is forced.

emStep :: VU.Vector Double -> (Weight,Normal,Normal) -> (Weight,Normal,Normal)
emStep xs (w,(mu1,s1),(mu2,s2)) = (w',(mu1',s1'),(mu2',s2')) where
  -- expectation step: responsibility of component 2 for every observation
  ys = VU.map responsibility xs
  -- s1 and s2 are variances, 'normalDistr' wants standard deviations
  n1 = normalDistr mu1 (sqrt s1)
  n2 = normalDistr mu2 (sqrt s2)
  responsibility x = w * n2x / ((1-w) * n1x + w * n2x) where
    n1x = density n1 x
    n2x = density n2 x
  -- maximization step: the new weight is the mean responsibility
  w' = div2 / fromIntegral (VU.length ys)
    -- new Gaussian 1, every observation weighted by (1 - responsibility)
  div1 = VU.sum . VU.map (1 -) $ ys
  mu1' = VU.sum (VU.zipWith (\x y -> (1-y) * x) xs ys) / div1
  s1'  = VU.sum (VU.zipWith (\x y -> (1-y) * (x-mu1')^2) xs ys) / div1
    -- new Gaussian 2, every observation weighted by its responsibility
  div2 = VU.sum ys
  mu2' = VU.sum (VU.zipWith (\x y -> y * x) xs ys) / div2
  s2'  = VU.sum (VU.zipWith (\x y -> y * (x-mu2')^2) xs ys) / div2

-- | The infinite sequence of parameter estimates produced by applying
-- 'emStep' over and over. The first element is the starting estimate itself,
-- the second the result of one step, and so on.

emIter :: VU.Vector Double -> (Weight,Normal,Normal) -> [(Weight,Normal,Normal)]
emIter = iterate . emStep

-- | Iterates 'emStep' from 'theta' until a fix-point is reached, i.e. until
-- one further step changes /none/ of the five parameters by more than
-- @epsilon@, and returns the estimate that step was applied to.
--
-- If 'theta' already is a fix-point it is returned unchanged. If the next
-- estimate is degenerate (a parameter difference is NaN), iteration stops and
-- the last estimate before it is returned.
--
-- There is no cap on the number of iterations: termination relies on the EM
-- sequence settling to within @epsilon@.

emFix :: VU.Vector Double -> (Weight,Normal,Normal) -> (Weight,Normal,Normal)
emFix xs theta = go (emIter xs theta) where
  go (cur:rest@(next:_))
    | stillMoving cur next = go rest
    | otherwise            = cur
  -- 'emIter' yields an infinite list; these equations only make 'go' total
  go [cur] = cur
  go []    = theta
  -- keep iterating while ANY parameter is still changing, but never step
  -- into an estimate that contains NaN
  stillMoving (w1,(mu11,s11),(mu12,s12)) (w2,(mu21,s21),(mu22,s22)) =
    not (any isNaN deltas) && any (> epsilon) deltas
    where
      deltas =
        [ abs (w1-w2)
        , abs (mu11-mu21)
        , abs (mu12-mu22)
        , abs (s11-s21)
        , abs (s12-s22)
        ]
  epsilon = 1e-10

-- | Finds the best fix-point over a grid of starting points: every pair of
-- observations @(mu1, mu2)@ with @mu1 < mu2@ is used as starting means, with
-- starting weight 0.5 and both variances set to the sample variance of 'xs'.
-- 'emFix' is run from each start and the result with the highest
-- log-likelihood is returned.
--
-- The ordering @mu1 < mu2@ holds for the starting means.
--
-- Calls 'error' when no pair of observations satisfies @mu1 < mu2@ (fewer
-- than two distinct values), because then there is no starting point.

emStarts :: VU.Vector Double -> (Weight,Normal,Normal)
emStarts xs
  | null starts = error "emStarts: need at least two distinct observations to choose starting means"
  | otherwise   = maximumBy (comparing loglikelihood) . map (emFix xs) $ starts
  where
    obs = VU.toList xs
    starts = [start mu1 mu2 | mu1 <- obs, mu2 <- obs, mu1 < mu2]
    start mu1 mu2 = (0.5,(mu1,sampleVar),(mu2,sampleVar))
    -- log-likelihood of the whole sample under the fitted mixture
    loglikelihood (w,(mu1,s1),(mu2,s2)) = VU.sum . VU.map ll $ xs where
      ll x = log $ (1-w) * density n1 x + w * density n2 x
      -- s1 and s2 are variances, 'normalDistr' wants standard deviations
      n1 = normalDistr mu1 (sqrt s1)
      n2 = normalDistr mu2 (sqrt s2)
    sampleSize = fromIntegral (VU.length xs)
    sampleMu = VU.sum xs / sampleSize
    sampleVar = VU.sum (VU.map (\x -> (x-sampleMu)^2) xs) / sampleSize