{-# LANGUAGE BangPatterns #-}
module Data.Clustering.Hierarchical.Internal.DistanceMatrix
    ( singleLinkage
    , completeLinkage
    , upgma
    , fakeAverageLinkage
    ) where

-- from base
import Control.Monad (forM_)
import Control.Monad.ST (ST, runST)
import Data.Array (listArray, (!))
import Data.Array.ST (STArray, STUArray, newArray_, newListArray, readArray, writeArray)
import Data.Function (on)
import Data.List (delete, tails, (\\))
import Data.STRef (STRef, newSTRef, readSTRef, writeSTRef)

-- from containers
import qualified Data.IntMap as IM

-- from this package
import Data.Clustering.Hierarchical.Internal.Types


-- | Fails with 'error', prefixing the message with the fully
-- qualified name of this module.
mkErr :: String -> a
mkErr msg = error ("Data.Clustering.Hierarchical.Internal.DistanceMatrix." ++ msg)


-- | A cluster of one or more items, as used internally by this
-- package.  The cluster is identified by its @key@, which should not
-- be greater than any element of the cluster.
data Cluster = Cluster
    { key  :: {-# UNPACK #-} !Item  -- ^ Representative element naming the cluster.
    , size :: {-# UNPACK #-} !Int   -- ^ Number of elements; at least one (the @key@).
    } deriving (Eq, Ord, Show)

-- | An element of a cluster, identified by an 'IM.Key'.
type Item = IM.Key

-- | The cluster containing only the given item.
singleton :: Item -> Cluster
singleton k = Cluster k 1

-- | /O(1)/. Unites two clusters.  The result is keyed by the smaller
-- of the two keys; the larger key, which no longer names a cluster,
-- is returned alongside it.  ('Cluster' is not a 'Monoid': there is
-- no meaningful empty cluster to serve as 'mempty'.)
merge :: Cluster -> Cluster -> (Cluster, Item)
merge c1 c2 = (Cluster { key = min k1 k2, size = size c1 + size c2 }, max k1 k2)
  where
    k1 = key c1
    k2 = key c2


-- | Mutable state of the clustering algorithm, living in 'ST'.
data DistMatrix s =
    DM { matrix   :: {-# UNPACK #-} !(STUArray s (Item, Item) Distance)
         -- ^ Distances, indexed by pairs @(i, j)@ with @i < j@.
       , active   :: {-# UNPACK #-} !(STRef s [Item])
         -- ^ Keys of the clusters that are still live.
       , clusters :: {-# UNPACK #-} !(STArray s Item Cluster)
         -- ^ The cluster stored under each key.
       }


-- | /O(n^2)/. Every pair @(x, y)@ such that @x@ occurs before @y@ in
-- the input, listed in order of @x@ and then of @y@.
combinations :: [a] -> [(a,a)]
combinations xs = concatMap pairsWithHead (tails xs)
  where
    pairsWithHead (y:ys) = map ((,) y) ys
    pairsWithHead []     = []


-- | /O(n^2)/. Builds the distance matrix for the items @1..n@, using
-- the given function to fill every entry.  Each item starts out as
-- its own active singleton cluster.  Fails when @n < 2@.
fromDistance :: (Item -> Item -> Distance) -> Item -> ST s (DistMatrix s)
fromDistance dist n
  | n < 2     = mkErr "fromDistance: n < 2 is meaningless"
  | otherwise = do
      let items = [1..n]
      mat  <- newArray_ ((1,2), (n-1,n))
      live <- newSTRef items
      forM_ (combinations items) $ \(i, j) ->
        writeArray mat (i, j) (dist i j)
      cls  <- newListArray (1,n) (map singleton items)
      return DM { matrix = mat, active = live, clusters = cls }


-- | /O(n^2)/. Finds the two active clusters that are closest to each
-- other, together with their distance.  The first cluster returned
-- has the smaller key.  When several pairs share the minimum
-- distance, the first one in scan order wins.
findMin :: DistMatrix s -> ST s ((Cluster, Cluster), Distance)
findMin dm = do
  ks <- readSTRef (active dm)
  case combinations ks of
    []        -> mkErr "findMin: empty DistMatrix"
    (p0 : ps) -> do
      d0 <- readArray mat p0
      ((k1, k2), dMin) <- scan (p0, d0) ps
      c1 <- readArray (clusters dm) k1
      c2 <- readArray (clusters dm) k2
      return ((c1, c2), dMin)
  where
    mat = matrix dm
    -- A later pair replaces the current best only when strictly closer,
    -- so ties are resolved in favour of the earliest pair.
    scan best [] = return best
    scan best@(_, dBest) (p : ps) = do
      d <- readArray mat p
      scan (if d < dBest then (p, d) else best) ps


-- | Type for functions that calculate distances between clusters.
-- The first argument is a cluster @B1@ paired with the distance from
-- some cluster @A@ to it; the second is @B2@ paired with the distance
-- from @A@ to @B2@.  The result is the distance from @A@ to the union
-- @B1 U B2@.
type ClusterDistance
    =  (Cluster, Distance)
    -> (Cluster, Distance)
    -> Distance
-- Some cluster distances

-- | Single linkage: the distance to the union is the smaller of the
-- two distances.
cdistSingleLinkage :: ClusterDistance
cdistSingleLinkage = \(_, d1) (_, d2) -> d1 `min` d2

-- | Complete linkage: the distance to the union is the larger of the
-- two distances.
cdistCompleteLinkage :: ClusterDistance
cdistCompleteLinkage = \(_, d1) (_, d2) -> d1 `max` d2

-- | UPGMA: the mean of the two distances, weighted by the sizes of
-- @B1@ and @B2@.
cdistUPGMA :: ClusterDistance
cdistUPGMA = \(b1,d1) (b2,d2) ->
  let n1 = fromIntegral (size b1)
      n2 = fromIntegral (size b2)
  in (n1 * d1 + n2 * d2) / (n1 + n2)

-- | \"Fake\" average linkage: the unweighted mean of the two
-- distances, ignoring the sizes of the clusters.
cdistFakeAverageLinkage :: ClusterDistance
cdistFakeAverageLinkage = \(_, d1) (_, d2) -> (d1 + d2) / 2


-- | /O(n)/ in the number of active clusters.  Merges two clusters
-- and returns the merged cluster.
--
-- The distance matrix is updated in place; no new matrix is
-- returned.  For every other active cluster @A@, the distance from
-- @A@ to the union is computed with the given 'ClusterDistance' and
-- stored under the union's key, which is the smaller of the two keys.
-- The slot of the other key is invalidated, and that key is removed
-- from the active list.
mergeClusters :: ClusterDistance -> DistMatrix s -> (Cluster, Cluster)
              -> ST s Cluster
mergeClusters cdist (DM matrix_ active_ clusters_) (b1, b2) = do
  let (merged, dropped) = b1 `merge` b2
      b1k  = key b1
      b2k  = key b2
      kept = key merged
      -- Matrix entries are only stored with the smaller key first.
      ix i j | i < j     = (i,j)
             | otherwise = (j,i)
  -- Calculate new distances
  activeV <- readSTRef active_
  forM_ (activeV \\ [b1k, b2k]) $ \k -> do
    d_a_b1 <- readArray matrix_ $ ix k b1k
    d_a_b2 <- readArray matrix_ $ ix k b2k
    let d = cdist (b1, d_a_b1) (b2, d_a_b2)
    writeArray matrix_ (ix k kept) $! d
  -- Save the new cluster under the surviving key, invalidate the other
  writeArray clusters_ kept merged
  writeArray clusters_ dropped $ mkErr "mergeClusters: invalidated"
  writeSTRef active_ $ delete dropped activeV
  -- Return new cluster.
  return merged


-- | Worker function to create dendrograms based on a
-- 'ClusterDistance'.
--
-- Items are numbered @1..n@ in list order.  Each of the @n-1@
-- iterations does the following:
--
-- * finds the closest pair of active clusters;
-- * merges them in the distance matrix;
-- * combines their dendrograms into a 'Branch'.
--
-- The dendrograms are kept in an 'IM.IntMap' indexed by cluster key.
-- Fails on an empty input list.
dendrogram' :: ClusterDistance -> [a] -> (a -> a -> Distance) -> Dendrogram a
dendrogram' _ [] _ = mkErr "dendrogram': empty input list"
dendrogram' _ [x] _ = Leaf x
dendrogram' cdist items dist = runST (act ())
  where
    n = length items
    -- The dummy argument, as its name suggests, works around the
    -- monomorphism restriction so that 'act' stays polymorphic in the
    -- state thread required by 'runST'.
    act _noMonomorphismRestrictionPlease = do
      let xs = listArray (1, n) items
          im = IM.fromDistinctAscList $ zip [1..] $ map Leaf items
      fromDistance (dist `on` (xs !)) n >>= go (n-1) im
    -- @i@ is the number of merges still to perform; @ds@ maps the key
    -- of every active cluster to its dendrogram.
    go !i !ds !dm = do
      ((c1,c2), distance) <- findMin dm
      cu <- mergeClusters cdist dm (c1,c2)
      -- Remove a cluster's dendrogram from the map, returning it too.
      let dendro c = IM.updateLookupWithKey (\_ _ -> Nothing) (key c)
          (Just d1, !ds') = dendro c1 ds
          (Just d2, !ds'') = dendro c2 ds'
          du = Branch distance d1 d2
      case i of
        1 -> return du
        _ -> let !ds''' = IM.insert (key cu) du ds''
             in du `seq` go (i-1) ds''' dm


-- | /O(n^3)/ time and /O(n^2)/ space. Calculates a complete,
-- rooted dendrogram for a list of items using single linkage
-- with the naïve algorithm using a distance matrix.
singleLinkage :: [a] -> (a -> a -> Distance) -> Dendrogram a
singleLinkage = dendrogram' cdistSingleLinkage

-- | /O(n^3)/ time and /O(n^2)/ space. Calculates a complete,
-- rooted dendrogram for a list of items using complete linkage
-- with the naïve algorithm using a distance matrix.
completeLinkage :: [a] -> (a -> a -> Distance) -> Dendrogram a
completeLinkage = dendrogram' cdistCompleteLinkage

-- | /O(n^3)/ time and /O(n^2)/ space. Calculates a complete,
-- rooted dendrogram for a list of items using UPGMA with the
-- naïve algorithm using a distance matrix.
upgma :: [a] -> (a -> a -> Distance) -> Dendrogram a
upgma = dendrogram' cdistUPGMA

-- | /O(n^3)/ time and /O(n^2)/ space. Calculates a complete,
-- rooted dendrogram for a list of items using fake average
-- linkage with the naïve algorithm using a distance matrix.
fakeAverageLinkage :: [a] -> (a -> a -> Distance) -> Dendrogram a
fakeAverageLinkage = dendrogram' cdistFakeAverageLinkage