{-# LANGUAGE BangPatterns, ScopedTypeVariables #-}
{- |
Module : Math.KMeans
Copyright : (c) Alp Mestanogullari, Ville Tirronen, 2011-2014
License : BSD3
Maintainer : Alp Mestanogullari
Stability : experimental
An implementation of the k-means clustering algorithm based on the efficient vector package.
-}
module OldKMeans (kmeans, Point, Cluster(..), computeClusters) where
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector as G
import qualified Data.List as L
import Data.Function (on)
--- * K-Means clustering algorithm
-- | Type holding an object of any type and its associated feature vector
type Point a = (V.Vector Double, a)
-- | Type representing a cluster (group) of vectors by its center and an id
data Cluster = Cluster {
cid :: {-# UNPACK #-} !Int, -- ^ an identifier for the cluster
center :: !(V.Vector Double) -- ^ the 'position' of the center of the cluster
} -- deriving (Show,Eq)
-- genVec = V.fromList `fmap` vectorOf 3 arbitrary
-- genPts = (flip zip) [0..] `fmap` replicateM 10 genVec
-- genClusters = do
-- cs <- replicateM 5 genVec
-- return (zipWith Cluster [0.. ] cs)
--
-- prop_regroup = forAll genClusters $ \c ->
-- forAll genPts $ \v ->
-- s (regroupPoints c v) == s (regroupPoints' c v)
-- where
-- same xs = length (L.nub xs) == length xs
-- s = map L.sort
{-# INLINE distance #-}
distance :: Point a -> V.Vector Double -> Double
distance (u,_) v = V.sum $ V.zipWith (\a b -> (a - b)^2) u v
partition :: Int -> [a] -> [[a]]
partition k vs = go vs
where go vs = case L.splitAt n vs of
(vs', []) -> [vs']
(vs', vss) -> vs' : go vss
n = (length vs + k - 1) `div` k
{-#INLINE computeClusters#-}
computeClusters :: [[V.Vector Double]] -> [Cluster]
computeClusters = zipWith Cluster [0..] . map f
where f (x:xs) = let (n, v) = L.foldl' (\(k, s) v' -> (k+1, V.zipWith (+) s v')) (1, x) xs
in V.map (\x -> x / (fromIntegral n)) v
{-#INLINE regroupPoints#-}
regroupPoints :: forall a. [Cluster] -> [Point a] -> [[Point a]]
regroupPoints clusters points = L.filter (not.null) . G.toList . G.accum (flip (:)) (G.replicate (length clusters) []) . map closest $ points
where
closest p = (cid (L.minimumBy (compare `on` (distance p . center)) clusters),p)
regroupPoints' :: [Cluster] -> [Point a] -> [[Point a]]
regroupPoints' clusters points = go points
where go points = map (map snd) . L.groupBy ((==) `on` fst) . L.sortBy (compare `on` fst) $ map (\p -> (closest p, p)) points
closest p = cid $ L.minimumBy (compare `on` (distance p . center)) clusters
kmeansStep :: [Point a] -> [[Point a]] -> [[Point a]]
kmeansStep points pgroups =
regroupPoints (computeClusters . map (map fst) $ pgroups) points
kmeansAux :: [Point a] -> [[Point a]] -> [[Point a]]
kmeansAux points pgroups = let pss = kmeansStep points pgroups in
-- has anything changed since the last step?
-- even a point jumping from one cluster to another is enough to
-- enter the 'False' case
case map (map fst) pss == map (map fst) pgroups of
True -> pgroups -- nothing's changed, we're done
False -> kmeansAux points pss -- something has changed, so let's try again
-- | Performs the k-means clustering algorithm
-- trying to use 'k' clusters on the given list of points
kmeans :: Int -> [Point a] -> [[Point a]]
kmeans k points = kmeansAux points pgroups
where pgroups = partition k points
{-# INLINE kmeans #-}