{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}

module AI.Clustering.KMeans.Internal
{-# WARNING "To be used by developer only" #-}
    ( forgy
    , kmeansPP
    , sumSquares
    ) where

import Control.Monad.Primitive (PrimMonad, PrimState)
import qualified Data.HashSet as S
import Data.List (nub)
import qualified Data.Matrix.Unboxed as MU
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as U
import System.Random.MWC (Gen, uniformR)
import System.Random.MWC.Distributions (categorical)

-- | Forgy-style initialisation: pick @k@ distinct positions of the input at
-- random and use their feature vectors as the initial centres, one per row.
--
-- A draw is rejected and repeated whenever two of the selected feature
-- vectors are equal. Consequently this does not terminate if the data hold
-- fewer than @k@ distinct feature vectors.
--
-- Calls 'error' when @k@ exceeds the number of samples.
forgy :: (PrimMonad m, G.Vector v a)
      => Gen (PrimState m)
      -> Int                     -- ^ The number of clusters
      -> v a                     -- ^ Input data
      -> (a -> U.Vector Double)  -- ^ Feature extraction function
      -> m (MU.Matrix Double)
forgy g k dat fn
    | k > n = error "k is larger than sample size"
    | otherwise = loop
  where
    n = G.length dat
    loop = do
        -- k distinct indices into the data
        vec <- uniformRN (0, n - 1) k g
        let xs = map (fn . G.unsafeIndex dat) vec
        -- distinct indices may still map to equal feature vectors; retry then
        if length (nub xs) == length xs
            then return $ MU.fromRows xs
            else loop
{-# INLINE forgy #-}

-- | k-means++ style seeding. The first centre is a uniformly chosen sample;
-- every further centre is drawn with 'categorical', weighting each sample by
-- the squared distance ('sumSquares') to its nearest centre chosen so far.
--
-- The rows of the result are the chosen centres, most recently chosen first.
--
-- Calls 'error' when @k@ is smaller than 1 or exceeds the number of samples.
kmeansPP :: (PrimMonad m, G.Vector v a)
         => Gen (PrimState m)
         -> Int                     -- ^ The number of clusters
         -> v a                     -- ^ Input data
         -> (a -> U.Vector Double)  -- ^ Feature extraction function
         -> m (MU.Matrix Double)
kmeansPP g k dat fn
    -- Without this guard the loop below, which starts counting at 1 and only
    -- stops on @k' == k@, would never terminate for k <= 0.
    | k < 1 = error "k must be positive"
    | k > n = error "k is larger than sample size"
    | otherwise = do
        c1 <- uniformR (0, n - 1) g
        loop [feature c1] 1
  where
    n = G.length dat

    -- Feature vector of the i-th sample.
    feature i = fn (G.unsafeIndex dat i)

    -- @centers@ holds the feature vectors of the centres chosen so far
    -- (never empty), @k'@ is how many there are.
    loop centers !k'
        | k' == k = return $ MU.fromRows centers
        | otherwise = do
            let -- squared distance from sample i to its nearest centre;
                -- the sample's features are extracted once, not per centre
                weight i =
                    let x = feature i
                    in minimum $ map (sumSquares x) centers
            c' <- categorical (U.generate n weight) g
            loop (feature c' : centers) (k' + 1)
{-# INLINE kmeansPP #-}

-- | Sum of squared element-wise differences of two vectors, i.e. the squared
-- Euclidean distance. The vectors are paired up with 'U.zipWith', so if the
-- lengths differ the surplus elements of the longer one are ignored.
sumSquares :: U.Vector Double -> U.Vector Double -> Double
sumSquares xs = U.sum . U.zipWith (\x y -> (x - y) * (x - y)) xs
{-# INLINE sumSquares #-}

-- | Generate N distinct, uniformly distributed random integers in a given range.
--
-- The range @(lo, hi)@ is inclusive on both ends. Values are drawn one at a
-- time and collected in a hash set until it holds @n@ of them, so repeated
-- draws are simply absorbed. The result is the set's element list; its order
-- is whatever the hash set yields, not the order of drawing.
--
-- Calls 'error' when the range contains fewer than @n@ integers.
uniformRN :: PrimMonad m => (Int, Int) -> Int -> Gen (PrimState m) -> m [Int]
uniformRN (lo, hi) n g
    | hi - lo + 1 < n = error "Range is too narrow!"
    | otherwise = collect S.empty
  where
    collect seen
        | S.size seen < n = do
            candidate <- uniformR (lo, hi) g
            -- inserting a value that is already present leaves the set as is
            collect (S.insert candidate seen)
        | otherwise = return (S.toList seen)