module Statistics.EM.GMM
( emFix
, emStarts
) where
import Control.Monad.Fix (fix)
import Data.List (sort,maximumBy,tails,inits,genericLength)
import Data.Ord
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as VU
import Statistics.Distribution
import Statistics.Distribution.Normal
import Data.Tuple.Select (sel2)
-- | Observed univariate samples.
type Data = VU.Vector Double

-- | Mixture parameters: one @(weight, mean, scale)@ triple per component.
-- The scale is the value handed to 'normalDistr' as its second argument.
type Theta = VU.Vector (Double, Double, Double)

-- | List form of 'Theta'.
-- NOTE(review): not referenced by the code visible here; confirm it is still needed.
type ThetaL = [(Double, Double, Double)]
-- | One EM iteration for a univariate Gaussian mixture.
--
-- Each component is @(weight, mean, standard deviation)@: the third field is
-- passed to 'normalDistr' as its second argument, which is a standard
-- deviation, so the M-step stores the square root of the weighted variance to
-- keep successive iterations consistent.
--
-- A component whose total responsibility underflows to zero yields NaN mean
-- and scale; callers such as 'emFix' are expected to detect that.
emStep :: Data -> Theta -> Theta
emStep xs ts = VU.zip3 ws mus sds
  where
    k = VU.length ts
    n = fromIntegral (VU.length xs) :: Double

    -- weight * density of one component at a point
    weighted (w, mu, s) x = w * density (normalDistr mu s) x

    -- E-step denominator: mixture density at every point, computed once
    -- instead of once per (component, point) pair.
    totals = VU.map (\x -> VU.sum (VU.map (`weighted` x) ts)) xs

    -- resp i ! j is the responsibility of component i for point j
    resps = V.generate k $ \i ->
      let t = ts `VU.unsafeIndex` i
       in VU.zipWith (\x tot -> weighted t x / tot) xs totals
    resp i = resps `V.unsafeIndex` i

    -- effective number of points assigned to each component
    ns = VU.generate k (VU.sum . resp)

    -- M-step
    ws = VU.map (/ n) ns
    mus = VU.generate k $ \i ->
      VU.sum (VU.zipWith (*) (resp i) xs) / (ns `VU.unsafeIndex` i)
    sds = VU.generate k $ \i ->
      let mu = mus `VU.unsafeIndex` i
          sq x = (x - mu) * (x - mu)
       in sqrt
            ( VU.sum (VU.zipWith (\r x -> r * sq x) (resp i) xs)
                / (ns `VU.unsafeIndex` i)
            )
-- | The infinite stream of EM iterates, beginning with (and including) the
-- starting parameters.
emIter :: Data -> Theta -> [Theta]
emIter xs = iterate (emStep xs)
-- | Run EM from the given starting parameters until the average
-- log-likelihood changes by less than 1e-10 between successive iterates, and
-- return the later iterate of the converged pair.
--
-- If an iterate becomes degenerate (NaN weight or mean, or a scale that is
-- not strictly positive, which 'normalDistr' rejects), the last well-formed
-- iterate is returned instead.  Degeneracy is checked before any density is
-- evaluated on the new parameters.
--
-- NOTE(review): there is no iteration cap; if the likelihood never settles
-- this does not terminate.
emFix :: Data -> Theta -> Theta
emFix xs ts = settle (emIter xs ts)
  where
    settle (cur : rest@(next : _))
      | degenerate next = cur
      | converged cur next = next
      | otherwise = settle rest
    settle _ = error "emFix: impossible, emIter yields an infinite list"

    converged a b = abs (logLikelihood a xs - logLikelihood b xs) < epsilon
    epsilon = 10 ^^ (-10 :: Int)

    -- NaN comparisons are False, so `not (s > 0)` also catches a NaN scale.
    degenerate = VU.any (\(w, mu, s) -> isNaN w || isNaN mu || not (s > 0))
-- | Average per-point log-likelihood of the data under the mixture:
--
-- > (1/n) * sum_x log (sum_k w_k * density (normalDistr mu_k s_k) x)
--
-- The previous version averaged the raw mixture densities, with no @log@.
-- That is not the quantity EM increases, and it distorted both the
-- convergence test and the choice of the best start.
-- Empty data gives NaN (0/0).
logLikelihood :: Theta -> Data -> Double
logLikelihood ts xs =
  VU.sum (VU.map pointLL xs) / fromIntegral (VU.length xs)
  where
    pointLL x = log (VU.sum (VU.map (componentDensity x) ts))
    componentDensity x (w, mu, s) = w * density (normalDistr mu s) x
-- | Fit a @k@-component mixture by running 'emFix' from many starts and
-- keeping the result with the highest 'logLikelihood'.
--
-- The starts use every non-decreasing @k@-tuple of (sorted) data points as
-- initial means; repetition is allowed.  Each start has equal weights
-- @1/k@ and the sample standard deviation as every component's scale.
-- The scale is the second argument of 'normalDistr', a standard deviation.
--
-- Errors if @k < 1@ or the data are empty.
emStarts :: Int -> Data -> Theta
emStarts k xs'
  | k < 1 = error "emStarts called with k<1"
  | VU.null xs' = error "emStarts called with empty data"
  | otherwise =
      maximumBy (comparing (`logLikelihood` xs)) (map (emFix xs) starts)
  where
    xs = VU.fromList . sort . VU.toList $ xs'
    n = fromIntegral (VU.length xs) :: Double

    sampleMu = VU.sum xs / n
    sampleSd =
      sqrt (VU.sum (VU.map (\x -> (x - sampleMu) * (x - sampleMu)) xs) / n)

    w = 1 / fromIntegral k
    mkT mu = (w, mu, sampleSd)

    starts = map VU.fromList (meanChoices k (VU.toList xs))

    -- All non-decreasing l-tuples drawn from the sorted list zs, each
    -- element turned into an initial component.
    meanChoices l zs
      | l <= 1 = map (\z -> [mkT z]) zs
      | otherwise =
          [ mkT y : rest
          | y <- zs
          , rest <- meanChoices (l - 1) (dropWhile (< y) zs)
          ]