--------------------------------------------------------------------------------
-- |
-- Module      :  $Header$
-- Copyright   :  (c) 2015 Kai Zhang
-- License     :  MIT
-- Maintainer  :  kai@kzhang.org
-- Stability   :  experimental
-- Portability :  portable
--
-- K-means clustering
--------------------------------------------------------------------------------
{-# LANGUAGE FlexibleContexts  #-}

module AI.Clustering.KMeans
    ( kmeans
    , kmeansWith
    , forgyMethod
    ) where

import Control.Monad (forM_)
import Control.Monad.Primitive (PrimMonad, PrimState)
import qualified Data.Matrix.Generic as M
import qualified Data.Matrix.Generic.Mutable as MM
import Data.Ord (comparing)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as GM
import Data.List (minimumBy, nub)
import System.Random.MWC (uniformR, Gen)

-- | Lloyd's algorithm, also known as the K-means algorithm.
--
-- The initial centroids are chosen with 'forgyMethod' (k distinct rows of the
-- data picked at random), then refined with 'kmeansWith'.
kmeans :: (G.Vector v Double, G.Vector v Int, Eq (v Int), Eq (v Double), PrimMonad m)
       => Gen (PrimState m)
       -> Int                           -- ^ number of clusters
       -> M.Matrix v Double             -- ^ each row represents a point
       -> m (v Int, M.Matrix v Double)  -- ^ membership vector and final centroids
kmeans g k mat = do
    initial <- forgyMethod g k mat
    -- 'kmeansWith' takes the centroids first and the data matrix second.
    return $ kmeansWith initial mat
{-# INLINE kmeans #-}

-- | Lloyd's algorithm, also known as the K-means algorithm, started from a
-- caller-supplied set of centroids.
--
-- The algorithm alternates two steps until the membership vector stops
-- changing:
--
-- * assignment: each point goes to its nearest centroid by squared Euclidean
--   distance; ties go to the lowest-indexed centroid;
--
-- * update: each centroid becomes the mean of the points assigned to it.
--
-- A centroid whose cluster receives no points keeps its previous position.
-- Dividing by a zero count would otherwise turn it into NaN coordinates.
--
-- Calls 'error' if the two matrices have different numbers of columns or if
-- there are more centroids than points.
kmeansWith :: (G.Vector v Double, G.Vector v Int, Eq (v Int))
           => M.Matrix v Double           -- ^ initial set of k centroids
           -> M.Matrix v Double           -- ^ each row represents a point
           -> (v Int, M.Matrix v Double)  -- ^ membership vector and centroids
kmeansWith initial mat | d /= M.cols initial || k > n = error "check input"
                       | otherwise = loop initial G.empty
  where
    -- Iterate until assignments are stable. The empty starting membership
    -- forces at least one assignment step whenever n > 0.
    loop means membership
        | membership' == membership = (membership, means)
        | otherwise = loop (update means membership') membership'
      where
        membership' = assign means

    -- Assignment step: index of the nearest centroid for every point.
    -- 'minimumBy' keeps the first minimum, so ties favour lower indices.
    assign means = G.generate n $ \i ->
        let x = M.takeRow mat i
        in fst $ minimumBy (comparing snd) $ zip [0..k-1] $ map (dist x) $ M.toRows means

    -- Update step: sum the coordinates of each cluster's points, then divide
    -- by the cluster size. Empty clusters keep the centroid from 'means'.
    update means membership = MM.create $ do
        m <- MM.replicate (k,d) 0.0
        count <- UM.replicate k (0 :: Int)
        forM_ [0..n-1] $ \i -> do
            let x = membership G.! i
            GM.unsafeRead count x >>= GM.unsafeWrite count x . (+1)
            forM_ [0..d-1] $ \j ->
                MM.unsafeRead m (x,j) >>= MM.unsafeWrite m (x,j) . (+ mat M.! (i,j))
        -- normalize (or carry over the old centroid for an empty cluster)
        forM_ [0..k-1] $ \i -> do
            c <- GM.unsafeRead count i
            if c == 0
                then forM_ [0..d-1] $ \j ->
                    MM.unsafeWrite m (i,j) (means M.! (i,j))
                else forM_ [0..d-1] $ \j ->
                    MM.unsafeRead m (i,j) >>= MM.unsafeWrite m (i,j) . (/ fromIntegral c)
        return m

    -- Squared Euclidean distance. No square root is taken because only the
    -- ordering of distances matters for the assignment step.
    dist :: G.Vector v Double => v Double -> v Double -> Double
    dist xs = G.sum . G.zipWith (\x y -> let e = x - y in e * e) xs

    n = M.rows mat
    k = M.rows initial
    d = M.cols mat
{-# INLINE kmeansWith #-}

-- * Initialization methods

-- | The Forgy method randomly chooses k unique observations from the data set and uses
-- these as the initial means.
--
-- k distinct row indices are drawn at random. The draw is repeated until the
-- selected rows are pairwise distinct.
--
-- Calls 'error' in two cases:
--
-- * k exceeds the number of rows;
--
-- * the data contains fewer than k distinct rows, so retrying could never
--   succeed.
forgyMethod :: (PrimMonad m, G.Vector v a, Eq (v a))
            => Gen (PrimState m)
            -> Int                 -- ^ number of clusters
            -> M.Matrix v a        -- ^ data, one observation per row
            -> m (M.Matrix v a)
forgyMethod g k mat | k > n = error "k is larger than sample size"
                    | otherwise = iter
  where
    iter = do
        vec <- sample g k . U.enumFromN 0 $ n
        let xs = map (M.takeRow mat) . G.toList $ vec
        if length (nub xs) == length xs
           then return . M.fromRows $ xs
           else if enoughDistinct
                   then iter
                   else error "forgyMethod: fewer than k distinct observations in data"
    -- Only evaluated after a draw with duplicate rows, and then shared across
    -- retries. 'take k' stops once k distinct rows have been produced.
    enoughDistinct = length (take k . nub . M.toRows $ mat) >= k
    n = M.rows mat
{-# INLINE forgyMethod #-}

-- | Randomly select @k@ elements from a population without replacement, using
-- a partial Fisher-Yates shuffle driven by the given generator.
--
-- If @k@ exceeds the population size, only as many swaps as there are
-- elements are performed, and the whole shuffled population is returned.
-- This keeps the unchecked swaps in bounds. A non-positive @k@ yields an empty
-- vector.
sample :: (PrimMonad m, G.Vector v a) => Gen (PrimState m) -> Int -> v a -> m (v a)
sample g k xs = do
    v <- G.thaw xs
    forM_ [0 .. k' - 1] $ \i -> do
        -- pick a random element from the not-yet-selected suffix [i .. lst]
        j <- uniformR (i, lst) g
        GM.unsafeSwap v i j
    G.unsafeFreeze . GM.take k' $ v
  where
    lst = G.length xs - 1
    -- Never swap past the end of the thawed copy.
    k' = min k (G.length xs)
{-# INLINE sample #-}