{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE BangPatterns #-}
module AI.Clustering.KMeans
( KMeans(..)
, KMeansOpts(..)
, defaultKMeansOpts
, kmeans
, kmeansBy
, Method(..)
, decode
) where
import Control.Monad (forM_)
import Control.Monad.Primitive (PrimMonad, PrimState)
import qualified Data.Matrix.Unboxed as MU
import Data.Matrix.Class (unsafeTakeRow)
import qualified Data.Matrix.Unboxed.Mutable as MM
import Data.Ord (comparing)
import qualified Data.Vector as V
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Mutable as VM
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import Data.List (minimumBy, foldl')
import System.Random.MWC (Gen, initialize)
import Control.Monad.ST (runST)
import AI.Clustering.KMeans.Types
import AI.Clustering.KMeans.Internal (sumSquares, forgy, kmeansPP)
-- | Cluster the rows of a matrix into @k@ groups with k-means.
--
-- Every row of @mat@ is one observation. Rows are addressed by their row
-- index while clustering, and when 'kmeansClusters' is set the returned
-- groups hold the original row vectors. Calls 'error' if any entry of the
-- matrix is NaN.
kmeans :: Int                  -- ^ number of clusters @k@
       -> MU.Matrix Double     -- ^ observations, one per row
       -> KMeansOpts
       -> KMeans (U.Vector Double)
kmeans k mat opts
    | U.any isNaN (MU.flatten mat) = error "Input data contains NaN."
    | otherwise = KMeans labels finalMeans groups cost
  where
    -- Items are row indices; the feature vector of an item is its row.
    rowIdx = U.enumFromN 0 (MU.rows mat)
    getRow = unsafeTakeRow mat

    (labels, finalMeans, cost) =
        kmeans' seeds (kmeansMaxIter opts) rowIdx getRow

    groups
        | kmeansClusters opts = Just (decode labels (MU.toRows mat))
        | otherwise = Nothing

    -- Starting centers, drawn with a generator seeded from the options
    -- (or taken verbatim when the caller supplies them).
    seeds = runST $ initialize (kmeansSeed opts) >>= \gen ->
        case kmeansMethod opts of
            Forgy     -> forgy gen k rowIdx getRow
            KMeansPP  -> kmeansPP gen k rowIdx getRow
            Centers c -> return c
{-# INLINE kmeans #-}
-- | Cluster arbitrary items into @k@ groups with k-means.
--
-- Each item is turned into a feature vector with @fn@, and clustering runs
-- on those vectors. When 'kmeansClusters' is set, the returned groups hold
-- the original items. Calls 'error' if any feature vector contains NaN.
kmeansBy :: G.Vector v a
         => Int                       -- ^ number of clusters @k@
         -> v a                       -- ^ items to cluster
         -> (a -> U.Vector Double)    -- ^ feature extractor
         -> KMeansOpts
         -> KMeans a
kmeansBy k dat fn opts
    | containNaN = error "Input data contains NaN."
    | otherwise = KMeans member cs grps sse'
  where
    -- 'G.any' stops at the first offending item. The previous lazy
    -- 'G.foldl' with '||' built one thunk per item before any test ran.
    containNaN = G.any (U.any isNaN . fn) dat

    (member, cs, sse') = kmeans' initial (kmeansMaxIter opts) dat fn

    grps
        | kmeansClusters opts = Just $ decode member $ G.toList dat
        | otherwise = Nothing

    -- Starting centers, drawn with a generator seeded from the options
    -- (or taken verbatim when the caller supplies them).
    initial = runST $ do
        gen <- initialize $ kmeansSeed opts
        case kmeansMethod opts of
            Forgy -> forgy gen k dat fn
            KMeansPP -> kmeansPP gen k dat fn
            Centers c -> return c
{-# INLINE kmeansBy #-}
-- | Lloyd's k-means iteration shared by 'kmeans' and 'kmeansBy'.
--
-- The rows of @initial@ are the starting centers. The row count is the
-- number of clusters @k@. The column count @d@ is the length every feature
-- vector @fn x@ must have; only the first item is checked.
--
-- Each round assigns every item to its nearest center and recomputes each
-- center as the mean of its items. Iteration stops once the assignment no
-- longer changes or after @maxiter@ rounds.
--
-- Returns the per-item cluster indices, the final centers, and the sum
-- over items of the square-rooted 'sumSquares' to their assigned center.
-- Calls 'error' on empty input or a dimension mismatch.
kmeans' :: G.Vector v a
        => MU.Matrix Double          -- ^ initial centers, one per row
        -> Int                       -- ^ maximum number of rounds
        -> v a                       -- ^ items to cluster
        -> (a -> U.Vector Double)    -- ^ feature extractor
        -> (U.Vector Int, MU.Matrix Double, Double)
kmeans' initial maxiter dat fn
    | G.null dat = error "Input data is empty."
    | U.length (fn $ G.head dat) /= d = error "Dimension mismatched."
    | otherwise = (member, centers, cost)
  where
    -- NOTE(review): assuming 'sumSquares' is the squared Euclidean distance,
    -- the 'sqrt' makes each term a plain distance. The result is therefore a
    -- sum of distances, not a sum of squared errors; confirm which one the
    -- 'KMeans' result is meant to carry.
    cost = U.sum $ U.imap (\i x -> sqrt $ sumSquares
        (fn $ dat G.! i) (centers `MU.takeRow` x)) member

    (member, centers) = loop 0 initial U.empty

    -- Stops when the fresh assignment equals the previous one, or after
    -- 'maxiter' rounds. The initial empty membership never equals a
    -- non-empty assignment. With 'maxiter' <= 0 no round runs, and the
    -- result is the empty membership with the initial centers.
    loop !iter means membership
        | iter >= maxiter || membership' == membership = (membership, means)
        | otherwise = loop (iter + 1) (update means membership') membership'
      where
        membership' = assign means

    -- Index of the nearest center for every item. The comparison is
    -- strict, so ties go to the lower index.
    -- The fold starts from cluster 0 with an infinite best distance, so
    -- the result is always a valid index. The old fallback of -1 survived
    -- whenever no distance was below +Infinity (e.g. an item containing
    -- Infinity), and 'update' then used it as an unchecked index.
    assign means = U.generate n $ \i ->
        let x = fn $ G.unsafeIndex dat i
            f (!best, !bestJ) j =
                let dist = sumSquares x $ means `unsafeTakeRow` j
                in if dist < best then (dist, j) else (best, bestJ)
        in snd $ foldl' f (1/0, 0) [0 .. k - 1]

    -- New centers: the mean of the items assigned to each cluster.
    -- A cluster that received no items keeps its previous center from
    -- 'prev'. Dividing its zero sum by a zero count would give a NaN row,
    -- and a NaN distance never passes the '<' test in 'assign', so that
    -- cluster could never be chosen again.
    update prev membership = MU.create $ do
        m <- MM.replicate (k, d) 0.0
        count <- UM.replicate k (0 :: Int)
        forM_ [0 .. n - 1] $ \i -> do
            let x = membership `U.unsafeIndex` i
                vec = fn $ dat `G.unsafeIndex` i
            UM.unsafeModify count (+1) x
            forM_ [0 .. d - 1] $ \j ->
                MM.unsafeRead m (x, j) >>=
                    MM.unsafeWrite m (x, j) . (+ (vec `U.unsafeIndex` j))
        forM_ [0 .. k - 1] $ \i -> do
            c <- UM.unsafeRead count i
            if c == 0
                then do
                    let old = prev `unsafeTakeRow` i
                    forM_ [0 .. d - 1] $ \j ->
                        MM.unsafeWrite m (i, j) (old `U.unsafeIndex` j)
                else do
                    let c' = fromIntegral c
                    forM_ [0 .. d - 1] $ \j ->
                        MM.unsafeRead m (i, j) >>= MM.unsafeWrite m (i, j) . (/ c')
        return m

    n = G.length dat
    k = MU.rows initial
    d = MU.cols initial
{-# INLINE kmeans' #-}
-- | Group items by their cluster label.
--
-- The @i@-th label in @member@ is the cluster of the @i@-th element of
-- @xs@. Labels and items are paired with 'zip', so any surplus in the
-- longer input is ignored.
--
-- The result has @maximum member + 1@ groups, and group @c@ holds the
-- items labelled @c@. Because each item is prepended, items within a group
-- appear in the reverse of their input order.
--
-- An empty label vector yields @[]@. A negative label makes the
-- bounds-checked update call 'error'.
decode :: U.Vector Int -> [a] -> [[a]]
decode member xs
    | U.null member = []   -- 'U.maximum' below is undefined on empty input
    | otherwise = V.toList $ V.create $ do
        groups <- VM.replicate n []
        forM_ (zip (U.toList member) xs) $ \(label, x) ->
            -- Bounds-checked on purpose: 'decode' is exported and may
            -- receive labels that did not come from the clustering code.
            VM.modify groups (x:) label
        return groups
  where
    n = U.maximum member + 1
{-# INLINE decode #-}