{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE BangPatterns #-}

module AI.Clustering.KMeans
    ( KMeans(..)
    , KMeansOpts(..)
    , defaultKMeansOpts
    , kmeans
    , kmeansBy

    -- * Initialization methods
    , Method(..)

    , decode

    -- * References
    -- $references
    ) where

import Control.Monad (forM_)
import Control.Monad.Primitive (PrimMonad, PrimState)
import qualified Data.Matrix.Unboxed as MU
import Data.Matrix.Class (unsafeTakeRow)
import qualified Data.Matrix.Unboxed.Mutable as MM
import Data.Ord (comparing)
import qualified Data.Vector as V
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Mutable as VM
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import Data.List (minimumBy, foldl')
import System.Random.MWC (Gen, initialize)
import Control.Monad.ST (runST)

import AI.Clustering.KMeans.Types
import AI.Clustering.KMeans.Internal (sumSquares, forgy, kmeansPP)

-- | Cluster the rows of a matrix with the K-means algorithm.
--
-- The starting centroids are produced according to 'kmeansMethod', using a
-- random generator seeded from 'kmeansSeed'. When the method is 'Centers',
-- the supplied matrix is used as-is and the cluster count argument is not
-- consulted here. Calls 'error' if any entry of the input is NaN.
kmeans :: Int                -- ^ The number of clusters
       -> MU.Matrix Double   -- ^ Input data stored as rows in a matrix
       -> KMeansOpts
       -> KMeans (U.Vector Double)
kmeans k mat opts
    | U.any isNaN (MU.flatten mat) = error "Input data contains NaN."
    | otherwise = KMeans labels centroids groups cost
  where
    -- The core loop works on row indices; rows are fetched on demand.
    rowIds = U.enumFromN 0 (MU.rows mat)
    rowAt = unsafeTakeRow mat
    (labels, centroids, cost) =
        kmeans' seeds (kmeansMaxIter opts) rowIds rowAt
    -- Only materialise the grouped rows when the caller asked for them.
    groups
        | kmeansClusters opts = Just (decode labels (MU.toRows mat))
        | otherwise = Nothing
    seeds = runST $ do
        gen <- initialize (kmeansSeed opts)
        case kmeansMethod opts of
            Centers c -> return c
            KMeansPP -> kmeansPP gen k rowIds rowAt
            Forgy -> forgy gen k rowIds rowAt
{-# INLINE kmeans #-}

-- | Cluster arbitrary items with the K-means algorithm, given a function
-- that maps each item to its feature vector.
--
-- The starting centroids are produced according to 'kmeansMethod', using a
-- random generator seeded from 'kmeansSeed'. When the method is 'Centers',
-- the supplied matrix is used as-is and the cluster count argument is not
-- consulted here. Calls 'error' if any extracted feature is NaN.
kmeansBy :: G.Vector v a
         => Int                   -- ^ The number of clusters
         -> v a                   -- ^ Input data
         -> (a -> U.Vector Double)
         -> KMeansOpts
         -> KMeans a
kmeansBy k dat fn opts
    | G.any (U.any isNaN . fn) dat = error "Input data contains NaN."
    | otherwise = KMeans labels centroids groups cost
  where
    (labels, centroids, cost) =
        kmeans' seeds (kmeansMaxIter opts) dat fn
    -- Only materialise the grouped items when the caller asked for them.
    groups
        | kmeansClusters opts = Just (decode labels (G.toList dat))
        | otherwise = Nothing
    seeds = runST $ do
        gen <- initialize (kmeansSeed opts)
        case kmeansMethod opts of
            Centers c -> return c
            KMeansPP -> kmeansPP gen k dat fn
            Forgy -> forgy gen k dat fn
{-# INLINE kmeansBy #-}

-- | K-means algorithm: alternate an assignment step and an update step,
-- starting from the given centroids, until the membership vector stops
-- changing or the iteration limit is reached.
--
-- Returns the membership vector (one centroid index per input item), the
-- final centroid matrix, and the sum over all items of
-- @sqrt (sumSquares item centroid)@.
--
-- Calls 'error' when there are no centroids, when the input is empty, or
-- when the feature vector of the first item does not have as many entries
-- as the centroid matrix has columns. Only the first item is checked.
--
-- If the iteration limit is not positive, no assignment is performed and
-- the returned membership vector is empty.
kmeans' :: G.Vector v a
        => MU.Matrix Double         -- ^ Initial set of k centroids
        -> Int                      -- ^ Maximum number of iterations
        -> v a                      -- ^ Input data
        -> (a -> U.Vector Double)   -- ^ Feature extraction function
        -> (U.Vector Int, MU.Matrix Double, Double)
kmeans' initial maxiter dat fn
    | k < 1 = error "The number of clusters must be positive."
    | G.null dat = error "Input data is empty."
    | U.length (fn $ G.head dat) /= d = error "Dimension mismatched."
    | otherwise = (member, centers, cost)
  where
    -- NOTE(review): callers bind this value as @sse'@, which suggests a sum
    -- of squared errors, yet a square root is applied to every term. Kept
    -- as-is to preserve results; confirm which quantity is intended.
    cost = U.sum $ U.imap ( \i x -> sqrt $ sumSquares
        (fn $ dat G.! i) (centers `MU.takeRow` x) ) member

    (member, centers) = loop 0 initial U.empty
    loop !iter means membership
        | iter >= maxiter || membership' == membership = (membership, means)
        | otherwise = loop (iter+1) (update means membership') membership'
      where
        membership' = assign means

    -- Assignment step: pick, for every item, the centroid with the smallest
    -- 'sumSquares' value (the first one wins on ties). The fold is seeded
    -- with centroid 0 rather than a sentinel index, so the result is always
    -- a valid row even when no distance compares below infinity (e.g. NaN
    -- or infinite distances); the update step indexes with it unchecked.
    assign means = U.generate n $ \i ->
        let x = fn $ G.unsafeIndex dat i
            dist j = sumSquares x $ means `unsafeTakeRow` j
            f (!min', !j') j = let dj = dist j
                               in if dj < min' then (dj, j) else (min', j')
        in snd $ foldl' f (dist 0, 0) [1..k-1]

    -- Update step: every centroid becomes the mean of the items assigned to
    -- it. A centroid with no items keeps its previous position; dividing the
    -- zero sum by a zero count would turn it into NaN.
    update prev membership = MU.create $ do
        m <- MM.replicate (k,d) 0.0
        count <- UM.replicate k (0 :: Int)
        -- accumulate per-cluster sums and sizes
        forM_ [0..n-1] $ \i -> do
            let x = membership `U.unsafeIndex` i
                vec = fn $ dat `G.unsafeIndex` i
            UM.unsafeModify count (+1) x
            forM_ [0..d-1] $ \j ->
                MM.unsafeRead m (x,j) >>=
                    MM.unsafeWrite m (x,j) . (+ (vec `U.unsafeIndex` j))
        -- normalize
        forM_ [0..k-1] $ \i -> do
            c <- UM.unsafeRead count i
            let old = prev `unsafeTakeRow` i
            forM_ [0..d-1] $ \j ->
                if c == 0
                    then MM.unsafeWrite m (i,j) (old `U.unsafeIndex` j)
                    else MM.unsafeRead m (i,j) >>=
                        MM.unsafeWrite m (i,j) . (/fromIntegral c)
        return m
    n = G.length dat
    k = MU.rows initial
    d = MU.cols initial
{-# INLINE kmeans' #-}

-- | Group data items by the cluster index recorded in a membership vector.
--
-- The i-th list of the result holds the items whose membership entry is
-- @i@. Items are prepended as they are encountered, so every group lists
-- its members in reverse input order. The number of groups is the largest
-- index in the membership vector plus one, so trailing clusters that
-- received no items are absent from the result. An empty membership vector
-- yields an empty list.
--
-- Membership entries must be non-negative: they are used as unchecked
-- indices. If the two inputs differ in length, the surplus of the longer
-- one is ignored.
decode :: U.Vector Int -> [a] -> [[a]]
decode member xs
    -- 'U.maximum' is partial on empty vectors, so handle that case up front.
    | U.null member = []
    | otherwise = V.toList $ V.create $ do
        v <- VM.replicate n []
        forM_ (zip (U.toList member) xs) $ \(i,x) ->
            VM.unsafeRead v i >>= VM.unsafeWrite v i . (x:)
        return v
  where
    n = U.maximum member + 1
{-# INLINE decode #-}

-- $references
--
-- Arthur, D. and Vassilvitskii, S. (2007). k-means++: the advantages of careful
-- seeding. Proceedings of the eighteenth annual ACM-SIAM symposium on Discrete
-- algorithms. Society for Industrial and Applied Mathematics, Philadelphia, PA,
-- USA. pp. 1027–1035.