-- | Mutable, 'ST'-based distance matrix between clusters.
--
-- The module provides the 'Cluster' type, a 'DistMatrix' that stores the
-- pairwise distances of the currently active clusters, and the operations
-- to build the matrix from an item distance function ('fromDistance'),
-- find the closest pair of active clusters ('findMin') and merge two
-- clusters while updating the stored distances ('mergeClusters').
module Data.Clustering.Hierarchical.Internal.DistanceMatrix
(Cluster(..)
,Item
,DistMatrix(..)
,ClusterDistance
,fromDistance
,findMin
,mergeClusters
) where
import qualified Data.IntMap as IM
import Control.Monad (forM_, when)
import Control.Monad.ST (ST)
import Data.Array.ST (STArray, newArray, newListArray, readArray, writeArray)
import Data.List (delete, tails)
import Data.STRef (STRef, newSTRef, readSTRef, writeSTRef)
-- | Raise an 'error' whose message is prefixed with this module's
-- fully-qualified name, so the failing function is easy to locate.
mkErr :: String -> a
mkErr msg = error ("Data.Clustering.Hierarchical.Internal.DistanceMatrix." ++ msg)
-- | A cluster of items, identified by one of its items (its 'key').
--
-- The field order is significant: it determines the derived 'Ord' and
-- 'Show' instances and the positional form of the constructor.
data Cluster = Cluster
  { key  :: !Item
    -- ^ Item that identifies the cluster.  After a merge this is the
    -- smaller of the two merged keys.
  , more :: [Item]
    -- ^ The remaining items of the cluster (everything but 'key').
    -- Kept lazy.
  , size :: !Int
    -- ^ Number of items in the cluster.
  }
  deriving (Eq, Ord, Show)
-- | An item is identified by an 'IM.Key'; items also serve as indices
-- into the arrays held by a 'DistMatrix'.
type Item = IM.Key
-- | Build the one-element cluster holding just the given item: the item
-- is the 'key', there are no 'more' items and the 'size' is 1.
singleton :: Item -> Cluster
singleton item = Cluster item [] 1
-- | Join two clusters into one.
--
-- The smaller of the two keys becomes the 'key' of the joined cluster.
-- The larger key is put at the front of 'more' and is also returned as
-- the second component, so the caller knows which key no longer
-- identifies a cluster.
merge :: Cluster -> Cluster -> (Cluster, Item)
merge c1 c2 = (joined, dropped)
  where
    kept    = min (key c1) (key c2)
    dropped = max (key c1) (key c2)
    joined  = Cluster
      { key  = kept
      , more = dropped : more c1 ++ more c2
      , size = size c1 + size c2
      }
-- | Mutable distance matrix living in the 'ST' state thread @s@, with
-- distances of type @d@.
--
-- The positional order of the 'DM' fields (matrix, active, clusters) is
-- relied upon by pattern matches and must not change.
data DistMatrix s d = DM
  { matrix   :: STArray s (Item, Item) d
    -- ^ Distances between clusters, indexed by a pair of cluster keys.
  , active   :: STRef s [Item]
    -- ^ Keys of the clusters that are still active.
  , clusters :: STArray s Item Cluster
    -- ^ The cluster currently stored under each key.
  }
-- | All unordered pairs of elements taken from different positions of
-- the list.  In every pair the first component occurs before the second
-- one in the input, and pairs are produced grouped by their first
-- component, in input order.
combinations :: [a] -> [(a,a)]
combinations []         = []
combinations (x : rest) = map ((,) x) rest ++ combinations rest
-- | Create a distance matrix for the items @1 .. n@, using the given
-- function to compute the distance between two items.
--
-- Every item starts as its own singleton cluster and all of them are
-- active.  The distance function is only applied to pairs @(i, j)@ with
-- @i < j@, and only those cells are filled.  The array bounds
-- @((1,2), (n-1,n))@ form a rectangle that also contains cells with
-- @i >= j@; these keep an error placeholder and must never be read.
--
-- Calls 'error' when @n < 2@, since there is no pair to store.
fromDistance :: Ord d => (Item -> Item -> d) -> Item -> ST s (DistMatrix s d)
fromDistance _ n | n < 2 = mkErr "fromDistance: n < 2 is meaningless"
fromDistance dist n = do
  -- First index ranges over 1 .. n-1, second over 2 .. n.
  matrix_ <- newArray ((1, 2), (n - 1, n)) (mkErr "fromDistance: undef element")
  active_ <- newSTRef [1 .. n]
  forM_ (combinations [1 .. n]) $ \x -> writeArray matrix_ x (uncurry dist x)
  clusters_ <- newListArray (1, n) (map singleton [1 .. n])
  return $ DM { matrix   = matrix_
              , active   = active_
              , clusters = clusters_
              }
-- | Find the pair of active clusters with the smallest stored distance.
--
-- Returns both clusters together with their distance.  Pairs are visited
-- in the order produced by 'combinations' over the active keys; on ties
-- the pair visited first wins, because a candidate only replaces the
-- current best when it is strictly smaller.
--
-- Calls 'error' when the active keys yield no pair at all.
findMin :: Ord d => DistMatrix s d -> ST s ((Cluster, Cluster), d)
findMin dm = readSTRef (active dm) >>= start . combinations
  where
    dists = matrix dm

    -- Keep the better of the current best and the candidate @(p, d)@.
    pick best p d
      | d < snd best = (p, d)
      | otherwise    = best

    start []       = mkErr "findMin: empty DistMatrix"
    start (p : ps) = do
      d <- readArray dists p
      scan ps (p, d)

    -- The running minimum is forced at every step; otherwise a chain of
    -- 'pick' thunks as long as the number of pairs would build up before
    -- anything gets compared.
    scan (p : ps) best = best `seq` do
      d <- readArray dists p
      scan ps (pick best p d)
    scan [] ((i, j), d) = do
      c1 <- readArray (clusters dm) i
      c2 <- readArray (clusters dm) j
      return ((c1, c2), d)
-- | Function used to compute the distance between a cluster @A@ and the
-- union of two clusters @B1@ and @B2@ that are being merged.
--
-- The arguments, in order, as supplied by 'mergeClusters':
--
--   1. the cluster @A@;
--   2. the cluster @B1@ and the stored distance between @A@ and @B1@;
--   3. the cluster @B2@ and the stored distance between @A@ and @B2@;
--   4. the merged cluster.
--
-- The result is stored as the distance between @A@ and the merged
-- cluster.
type ClusterDistance d =
     Cluster       -- cluster A
  -> (Cluster, d)  -- cluster B1 and the distance between A and B1
  -> (Cluster, d)  -- cluster B2 and the distance between A and B2
  -> Cluster       -- the merged cluster
  -> d             -- distance between A and the merged cluster
-- | Merge two clusters of the matrix and update the stored distances.
--
-- For every other active cluster @A@ the given 'ClusterDistance' is
-- applied to @A@, both old clusters (with their stored distances to
-- @A@) and the merged cluster.  The result is written to the cell shared
-- by @A@ and the key of the merged cluster.
--
-- The merged cluster is then stored under its own key, which 'merge'
-- makes the smaller of the two keys.  The slot of the other key is
-- replaced by an error value and that key is removed from the active
-- list.
--
-- Returns the merged cluster.
mergeClusters :: (Ord d)
              => ClusterDistance d
              -> DistMatrix s d
              -> (Cluster, Cluster)
              -> ST s Cluster
mergeClusters cdist (DM matrix_ active_ clusters_) (b1, b2) = do
  let (merged, droppedKey) = b1 `merge` b2
      keptKey = key merged
      b1k = key b1
      b2k = key b2
      -- Only cells whose first index is the smaller one are in use.
      cell i j
        | i < j     = (i, j)
        | otherwise = (j, i)

  -- Recompute the distance from every other active cluster.
  activeKeys <- readSTRef active_
  forM_ activeKeys $ \k -> when (k /= b1k && k /= b2k) $ do
    a     <- readArray clusters_ k
    dToB1 <- readArray matrix_ (cell k b1k)
    dToB2 <- readArray matrix_ (cell k b2k)
    -- Both old distances are read before the write, which reuses one of
    -- the two cells.  The new distance is forced so the array does not
    -- hold a thunk that retains the old clusters and distances.
    writeArray matrix_ (cell k keptKey) $! cdist a (b1, dToB1) (b2, dToB2) merged

  -- Store the merged cluster and invalidate the key that disappeared.
  writeArray clusters_ keptKey merged
  writeArray clusters_ droppedKey $ mkErr "mergeClusters: invalidated"
  writeSTRef active_ $ delete droppedKey activeKeys

  return merged