module AI.Clustering.KMeans
( kmeans
, kmeansWith
, forgyMethod
) where
import Control.Monad (forM_)
import Control.Monad.Primitive (PrimMonad, PrimState)
import qualified Data.Matrix.Generic as M
import qualified Data.Matrix.Generic.Mutable as MM
import Data.Ord (comparing)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as GM
import Data.List (minimumBy, nub)
import System.Random.MWC (uniformR, Gen)
-- | K-means clustering of the rows of a matrix.
--
-- The initial centroids are chosen with 'forgyMethod' and then refined by
-- 'kmeansWith'. The result is the cluster index of every row together with
-- the final centroid matrix (one centroid per row).
kmeans :: (G.Vector v Double, G.Vector v Int, Eq (v Int), Eq (v Double), PrimMonad m)
       => Gen (PrimState m)               -- ^ random number generator
       -> Int                             -- ^ number of clusters
       -> M.Matrix v Double               -- ^ data set, one point per row
       -> m (v Int, M.Matrix v Double)    -- ^ (membership vector, centroids)
kmeans g k mat = do
    initial <- forgyMethod g k mat
    -- 'kmeansWith' takes the centroids first and the data second; passing
    -- them the other way round trips its @k > n@ sanity check.
    return $ kmeansWith initial mat
-- | K-means clustering starting from a caller-supplied set of centroids.
--
-- Alternates between assigning every point to its nearest centroid (squared
-- Euclidean distance) and recomputing each centroid as the mean of its
-- members, until the membership vector no longer changes.
--
-- Calls 'error' when the centroids and the data have a different number of
-- columns, or when there are more centroids than data points.
kmeansWith :: (G.Vector v Double, G.Vector v Int, Eq (v Int))
           => M.Matrix v Double             -- ^ initial centroids, one per row
           -> M.Matrix v Double             -- ^ data set, one point per row
           -> (v Int, M.Matrix v Double)    -- ^ (membership vector, centroids)
kmeansWith initial mat
    | d /= M.cols initial || k > n = error "check input"
    | otherwise = loop initial G.empty
  where
    -- Iterate until two consecutive assignments agree.
    loop means membership
        | membership' == membership = (membership, means)
        | otherwise = loop (update means membership') membership'
      where
        membership' = assign means

    -- Index of the nearest centroid for every data row.
    assign means = G.generate n $ \i ->
        let x = M.takeRow mat i
        in fst $ minimumBy (comparing snd) $ zip [0 .. k - 1] $ map (dist x) centers
      where
        -- Extracted once per iteration rather than once per data point.
        centers = M.toRows means

    -- Recompute every centroid as the mean of the points assigned to it.
    update means membership = MM.create $ do
        m <- MM.replicate (k, d) 0.0
        count <- UM.replicate k (0 :: Int)
        -- Accumulate per-cluster coordinate sums and member counts.
        forM_ [0 .. n - 1] $ \i -> do
            let x = membership G.! i
            GM.unsafeRead count x >>= GM.unsafeWrite count x . (+ 1)
            forM_ [0 .. d - 1] $ \j ->
                MM.unsafeRead m (x, j) >>= MM.unsafeWrite m (x, j) . (+ mat M.! (i, j))
        -- Turn sums into means.
        forM_ [0 .. k - 1] $ \i -> do
            c <- GM.unsafeRead count i
            forM_ [0 .. d - 1] $ \j ->
                if c == 0
                    -- An empty cluster has no mean; dividing by zero would
                    -- give NaN, so keep its previous centroid instead.
                    then MM.unsafeWrite m (i, j) (means M.! (i, j))
                    else MM.unsafeRead m (i, j) >>= MM.unsafeWrite m (i, j) . (/ fromIntegral c)
        return m

    -- Squared Euclidean distance between two vectors.
    dist :: G.Vector v Double => v Double -> v Double -> Double
    dist xs = G.sum . G.zipWith (\x y -> let e = x - y in e * e) xs

    n = M.rows mat      -- number of data points
    k = M.rows initial  -- number of clusters
    d = M.cols mat      -- dimensionality of the data
-- | Forgy initialisation: use @k@ randomly chosen, pairwise distinct rows of
-- the data matrix as the initial centroids.
--
-- Calls 'error' when @k@ exceeds the number of rows, or when the matrix does
-- not contain @k@ distinct rows (in which case no valid draw exists).
forgyMethod :: (PrimMonad m, G.Vector v a, Eq (v a))
            => Gen (PrimState m)    -- ^ random number generator
            -> Int                  -- ^ number of centroids to pick
            -> M.Matrix v a         -- ^ data set, one point per row
            -> m (M.Matrix v a)     -- ^ the chosen rows, as a matrix
forgyMethod g k mat
    | k > n = error "k is larger than sample size"
    -- Without this check 'iter' would retry forever. 'nub' is lazy, so the
    -- scan stops as soon as k distinct rows have been seen.
    | length (take k (nub (M.toRows mat))) < k =
        error "fewer than k distinct rows in the data"
    | otherwise = iter
  where
    -- Draw k row indices without replacement; redraw if two of the selected
    -- rows happen to hold identical values.
    iter = do
        vec <- sample g k . U.enumFromN 0 $ n
        let xs = map (M.takeRow mat) . G.toList $ vec
        if length (nub xs) == length xs
            then return . M.fromRows $ xs
            else iter
    n = M.rows mat
-- | Draw @k@ elements from a vector without replacement.
--
-- Performs the first @k@ steps of a Fisher-Yates shuffle on a mutable copy
-- of the input and returns the shuffled prefix. The caller must ensure that
-- @k@ does not exceed the length of the vector: the swaps are unchecked.
sample :: (PrimMonad m, G.Vector v a) => Gen (PrimState m) -> Int -> v a -> m (v a)
sample g k xs = do
    v <- G.thaw xs
    forM_ [0 .. k - 1] $ \i -> do
        -- Pick the element for slot i from the not-yet-fixed suffix [i, lst].
        j <- uniformR (i, lst) g
        GM.unsafeSwap v i j
    G.unsafeFreeze . GM.take k $ v
  where
    -- Index of the last element.
    lst = G.length xs - 1