module Data.KMeans (kmeans, kmeans')
where
import Data.List (transpose, sort, groupBy, minimumBy)
import Data.Function (on)
import Data.Ord (comparing)
-- | A vector is represented as a plain list of its components.
type Vector a = [a]
-- | Euclidean distance between two vectors given as lists of components.
--
-- Components are paired positionally with 'zipWith', so if the lists differ
-- in length the surplus components of the longer one are ignored. Two empty
-- lists (or any list paired with an empty one) are at distance 0.
dist :: Floating a => [a] -> [a] -> a
dist a b = sqrt . sum $ zipWith (\x y -> (x - y) ^ (2 :: Int)) a b
-- | Component-wise arithmetic mean of a list of points.
--
-- The points are transposed into columns; each column is summed and divided
-- by the number of points. The divisor is always @length points@, so the
-- points are expected to all have the same number of components. An empty
-- list of points yields an empty centroid.
centroid :: Fractional a => [[a]] -> [a]
centroid points = map mean (transpose points)
  where
    -- Number of points, converted once to the component type.
    count = fromIntegral (length points)
    mean column = sum column / count
-- | Select, from @points@, the candidate with the smallest 'dist' to @point@.
--
-- Uses 'minimumBy', so @points@ must be non-empty.
closest points point = minimumBy byDistance points
  where
    -- Order two candidates by how far each lies from the reference point.
    byDistance p q = compare (dist point p) (dist point q)
-- | Assign every point to its nearest centroid and collect the points that
-- share a centroid into one cluster.
--
-- Each point is tagged with the centroid chosen by 'closest'; the tagged
-- pairs are sorted so that equal centroids become adjacent, grouped on the
-- centroid, and the tags are then dropped. Centroids that attract no point
-- therefore produce no cluster in the result.
recluster' :: (Floating a, Ord a) => [[a]] -> [[a]] -> [[[a]]]
recluster' centroids points = map (map snd) $ groupBy ((==) `on` fst) reclustered
  where
    -- (nearest centroid, point) pairs, sorted so groupBy sees each centroid once.
    reclustered = sort [(closest centroids a, a) | a <- points]
-- | Perform one refinement pass: compute the centroid of every cluster and
-- redistribute all points among those centroids with @recluster'@.
recluster :: (Floating a, Ord a) => [[[a]]] -> [[[a]]]
recluster clusters = recluster' centroids (concat clusters)
  where
    -- An empty cluster has the empty list as its centroid, and the
    -- zipWith-based distance from any point to an empty list is 0, so such a
    -- centroid would attract every point. Empty clusters are skipped instead.
    centroids = map centroid (filter (not . null) clusters)
-- | Split a list into consecutive chunks of @x@ elements; the final chunk
-- may be shorter.
--
-- An empty input yields a single empty chunk (@[[]]@). A non-positive chunk
-- size cannot consume any input, so the whole list is returned as one chunk
-- rather than recursing forever.
part :: Int -> [a] -> [[a]]
part x ys
  | x <= 0    = [ys]
  | null zs'  = [zs]
  | otherwise = zs : part x zs'
  where
    (zs, zs') = splitAt x ys
-- | Refine an initial clustering until it stops changing.
--
-- Each step recomputes the clusters with 'recluster'; iteration ends as soon
-- as a pass returns clusters equal to its input, and that clustering is the
-- result.
kmeans' :: (Floating a, Ord a) => [[Vector a]] -> [[Vector a]]
kmeans' clusters
  | clusters == clusters' = clusters
  | otherwise             = kmeans' clusters'
  where
    clusters' = recluster clusters
-- | Cluster @points@ starting from an initial partition sized by @k@.
--
-- The points are cut, in input order, into consecutive chunks with 'part',
-- and the chunks are handed to @kmeans'@ as the initial clusters. @k@ is
-- expected to be positive; @k == 0@ makes the 'div' below raise a
-- divide-by-zero error.
kmeans :: (Floating a, Ord a) => Int -> [Vector a] -> [[Vector a]]
kmeans k points = kmeans' $ part l points
  where
    -- Chunk size: length points / k rounded up, so that a positive k yields
    -- at most k initial chunks.
    l = (length points + k - 1) `div` k