module Math.Statistics.KMeans ( euclideanDist , distanceToCenters , assignCluster , cloudCenter , selectFrom , kMeans ) where import qualified Data.Vector as V import qualified Data.Vector.Unboxed as UV -- | Euclidean Distance between two points euclideanDist :: (RealFloat a, UV.Unbox a) => UV.Vector a -> UV.Vector a -> a euclideanDist x y = sqrt . UV.sum $ UV.zipWith (\xi yi -> (xi - yi) ^ (2 :: Integer)) x y -- | Distance from a Point to a set of Centers distanceToCenters :: (t -> a -> b) -> V.Vector a -> t -> V.Vector b distanceToCenters distance centers point = V.map (distance point) centers -- | Assign Points to a Cluster based on the Minimum distance to each Center assignCluster :: Ord a => V.Vector (V.Vector a) -> V.Vector Int assignCluster = V.map V.minIndex -- | Calculates the Center of a Cloud of Points cloudCenter :: (Fractional a, UV.Unbox a) => V.Vector (UV.Vector a) -> UV.Vector a cloudCenter cloud = UV.map (/ fromIntegral(V.length cloud)) $ V.foldl1 (UV.zipWith (+)) cloud -- | Selects elements of a Vector given its indices selectFrom :: V.Vector a -> V.Vector Int -> V.Vector a selectFrom x = V.map (x V.!) -- | Checks if a Vector has duplicated elements hasDuplicates :: (Eq a) => V.Vector a -> Bool hasDuplicates a | (a == V.empty) = False | V.any (\y -> h == y) t == True = True | otherwise = hasDuplicates t where h = a V.! 0 t = V.tail a -- | k-Means classifier for a given Distance, Variation Guard and Cloud kMeans :: (RealFloat a, UV.Unbox a) => (UV.Vector a -> UV.Vector a -> a) -> a -> V.Vector (UV.Vector a) -> V.Vector (UV.Vector a) -> V.Vector (UV.Vector a) kMeans distance varGuard centers cloud | hasDuplicates centers == True = error "Non-unique centers provided, aborting." | otherwise = let dists = V.map (distanceToCenters distance centers) cloud assigned = assignCluster dists pointAssign = V.map (selectFrom cloud) $ V.fromList [V.elemIndices x assigned | x <- [0..(V.length centers - 1)] ] newcenters = V.map cloudCenter pointAssign variation = V.sum $ V.zipWith distance centers newcenters in (if variation > varGuard then kMeans distance varGuard newcenters cloud else newcenters)