-- | Expectation-maximization (EM) for a mixture of k one-dimensional
-- Gaussians. When more Gaussians are requested than the data supports, this
-- procedure tends to produce NaN parameters. This is rather convenient, as it
-- flags such over-fitted models. ;-)
--
-- TODO cite paper

module Statistics.EM.GMM
  ( emFix
  , emStarts
  ) where

import Control.Monad.Fix (fix)
import Data.List (sort,maximumBy,tails,inits,genericLength)
import Data.Ord
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as VU
import Statistics.Distribution
import Statistics.Distribution.Normal
import Data.Tuple.Select (sel2)



-- | Observed one-dimensional data points.
type Data = VU.Vector Double

-- | Parameters of a single mixture component: (weight, mean, variance).
type Component = (Double, Double, Double)

-- | Parameters of all mixture components, one 'Component' per Gaussian.
type Theta = VU.Vector Component

-- | List form of 'Theta', with the same (weight, mean, variance) layout.
type ThetaL = [Component]

-- | Perform one EM step given the data. In General, emSteps should be iterated
-- until some convergence criterion is met.

emStep :: Data -> Theta -> Theta
emStep xs ts = ts' where
  -- E-step
  -- TODO this could be made easier by using hmatrix with enabled "vector"
  resps = V.map (\i -> calcResV (ts `VU.unsafeIndex` i)) $ V.enumFromN 0 tlen
  calcResV t = VU.map (calcRes t) xs
  calcRes (w,mu,s) x = (w * density (normalDistr mu s) x) / (VU.sum . VU.map (\(w',mu',s') -> w'* density (normalDistr mu' s') x) $ ts)
  ns = VU.map (\i -> VU.sum $ resps `V.unsafeIndex` i) $ VU.enumFromN 0 tlen
  -- M-step
  ws = VU.map (\w -> w / fromIntegral (VU.length xs)) ns
  mus = VU.map (\i ->
                  (1/ ns `VU.unsafeIndex` i) *
                  (VU.sum $ VU.zipWith (*) (resps `V.unsafeIndex` i) xs))
                $ VU.enumFromN 0 tlen
  ss = VU.map (\i ->
                  (1/ ns `VU.unsafeIndex` i) *
                  (VU.sum $ VU.zipWith (*) (resps `V.unsafeIndex` i) (VU.map (\x -> (x - (mus `VU.unsafeIndex` i))^2) xs)))
                $ VU.enumFromN 0 tlen
  ts' = VU.zip3 ws mus ss
  tlen = VU.length ts

-- | Produce the infinite list of successive EM iterates for the data @xs@.
-- The list starts with the given 'Theta' itself. The iterates will (should)
-- converge toward a local optimum.

emIter :: Data -> Theta -> [Theta]
emIter xs = go
  where
    go t = t : go (emStep xs t)

-- | Find an optimal set of parameters 'Theta' by iterating 'emStep' from the
-- initial parameters @ts@. Iteration stops at the first iterate whose
-- successor changes the average log-likelihood by less than @epsilon@, and
-- that iterate is returned.
--
-- Iteration also stops as soon as an iterate carries a NaN weight, which
-- tends to happen when more Gaussians are fitted than the data supports.
-- This makes sure 'emFix' terminates in that case; the returned parameters
-- are then that NaN-carrying iterate.
--
-- If the initial parameters are already converged (or already carry NaN
-- weights), they are returned unchanged. Previously these cases failed with
-- an empty-list error.

emFix :: Data -> Theta -> Theta
emFix xs ts = go [ (t, logLikelihood t xs) | t <- emIter xs ts ]
  where
    -- Each iterate is paired with its (lazily computed) score, so every
    -- log-likelihood is evaluated at most once. NaN checks come before the
    -- score comparison, so NaN-carrying parameters are never scored.
    go ((y, ly) : rest@((z, lz) : _))
      | isnan y                 = y
      | isnan z                 = z
      | abs (ly - lz) < epsilon = y
      | otherwise               = go rest
    -- 'emIter' yields an infinite list, so these two equations are never
    -- reached; they only keep the pattern match total.
    go [(y, _)] = y
    go []       = ts
    epsilon = 10 ^^ (-10 :: Int)
    isnan ns = let (ws, _, _) = VU.unzip3 ns in VU.any isNaN ws

-- | Calculate the average log-likelihood per data point of the data 'Data'
-- under the mixture parameters 'Theta'. With n data points this is
-- @(1/n) * sum_j log (sum_i w_i * N(x_j | mu_i, var_i))@.
--
-- It is used by 'emFix' to decide whether convergence is reached and by
-- 'emStarts' to rank candidate fits. The result is NaN for empty data or
-- NaN parameters. It is -Infinity if the mixture density is zero at some
-- data point.
--
-- TODO could be useful in a more general setting within StatisticalMethods.

logLikelihood :: Theta -> Data -> Double
logLikelihood ts xs =
  VU.sum (VU.map (log . mixture) xs) / fromIntegral (VU.length xs)
  where
    -- Mixture density at x. The log must be taken of this sum; averaging
    -- raw densities is not a log-likelihood.
    mixture x = VU.sum (VU.map (\t -> weighted t x) ts)
    -- Weighted Gaussian density w * N(x | mu, s), with s the VARIANCE stored
    -- in 'Theta'. It is computed directly because 'normalDistr' takes a
    -- standard deviation in current statistics releases and calls 'error'
    -- on invalid (e.g. NaN) parameters.
    weighted (w, mu, s) x =
      let d = x - mu
      in  w * exp (negate (d * d) / (2 * s)) / sqrt (2 * pi * s)

-- | Given a set of 'Data' and a number @k@ of Gaussian peaks, try to find the
-- optimal GMM. This is done by trying data points as the initial means of
-- the Gaussians.
--
-- * Starts are non-decreasing selections of @k@ values from the sorted
--   data; the same value may be used for several components.
-- * Every component starts with weight @1/k@ and the sample variance.
-- * Each start is refined with 'emFix', and the fit with the highest
--   'logLikelihood' is returned.
--
-- Note that this will be rather slow for larger @k@ (larger than, say, 2 or
-- 3). In that case, a random-drawing method should be chosen.
--
-- Fits whose score is NaN are only returned if every start produced one.
-- This keeps the choice independent of the order in which starts are tried.
-- Calls 'error' if @k < 1@ or if the data is empty.
--
-- TODO xs' -> xs sorting makes me cry!

emStarts :: Int -> Data -> Theta
emStarts k xs' = pick [ (logLikelihood t xs, t) | t <- map (emFix xs) ts ]
  where
    -- Each score is computed once. NaN scores are dropped before taking the
    -- maximum, because 'compare' on a NaN Double yields GT and would make
    -- 'maximumBy' order-dependent.
    pick [] = error "emStarts called with empty data"
    pick scored@((_, t0) : _) =
      case filter (not . isNaN . fst) scored of
        []   -> t0
        good -> snd (maximumBy (comparing fst) good)
    ts = map VU.fromList . f k . VU.toList $ xs
    mkT mu = (w, mu, sampleVar)
    -- Choose l means from the sorted list zs, in non-decreasing order.
    f l zs
      | l < 1     = error "emStarts called with k<1"
      | l == 1    = map (\z -> [mkT z]) zs
      | otherwise = [mkT y : ys | y <- zs, ys <- f (l - 1) (dropWhile (< y) zs)]
    sampleMu = VU.sum xs / fromIntegral (VU.length xs)
    sampleVar =
      (VU.sum . VU.map (\x -> (x - sampleMu) * (x - sampleMu)) $ xs)
        / fromIntegral (VU.length xs)
    w = 1 / fromIntegral k
    -- 'f' relies on sorted data for 'dropWhile'.
    xs = VU.fromList . sort . VU.toList $ xs'