module Data.Clustering.Hierarchical.Internal.Optimal
( singleLinkage
, completeLinkage
) where
import Prelude hiding (pi)
import Control.Applicative ((<$>))
import Control.Arrow (first)
import Control.Monad (forM_, liftM3, when)
import Control.Monad.ST (ST, runST)
import Data.Array (Array, listArray, (!))
import Data.Array.ST (STUArray, newArray_, newListArray,
readArray, writeArray,
getElems, getBounds)
import Data.List (sortBy)
import Data.Maybe (fromMaybe)
import qualified Data.IntMap as IM
import Data.Clustering.Hierarchical.Internal.Types
-- | Abort via 'error', prefixing the given message with this module's
-- qualified name so the failing function can be located.
mkErr :: String -> a
mkErr msg = error (modulePrefix ++ msg)
  where
    modulePrefix = "Data.Clustering.Hierarchical.Internal.Optimal."
-- | Position of an element in the arrays used by this module, which are
-- created with bounds @(1, n)@.
type Index = Int
-- | Working state produced by 'slink' and 'clink' and consumed by
-- 'buildDendrogram': three mutable unboxed arrays (allocated by 'initPR'
-- with bounds @(1, n)@) plus the immutable array of input elements.
--
-- NOTE(review): the field names presumably mirror the Pi / Lambda / M
-- arrays of the \"pointer representation\" from the SLINK/CLINK papers --
-- confirm against the literature before relying on that reading.
data PointerRepresentation s a =
  PR { pi     :: !(STUArray s Index Index)
       -- ^ For each index @j@, the index it is merged with;
       -- 'buildDendrogram' joins the cluster of @j@ with that of @pi j@.
     , lambda :: !(STUArray s Index Distance)
       -- ^ For each index @j@, the distance stored on the 'Branch' that
       -- 'buildDendrogram' creates for @j@.
     , em     :: !(STUArray s Index Distance)
       -- ^ Scratch array; at step @i@ it is first filled with the
       -- distances from every @j < i@ to element @i@.
     , elm    :: !(Array Index a)
       -- ^ The input elements, used to build the 'Leaf' nodes.
     }
-- | Allocate a fresh 'PointerRepresentation' for @n@ elements.
--
-- The three mutable arrays get bounds @(1, n)@ and are left uninitialised
-- ('newArray_'); the element array is stored as given.
initPR :: Index -> Array Index a -> ST s (PointerRepresentation s a)
initPR n xs' = do
  piArr     <- newArray_ bnds
  lambdaArr <- newArray_ bnds
  emArr     <- newArray_ bnds
  return (PR piArr lambdaArr emArr xs')
  where
    bnds = (1, n)
-- | Prepare a list of elements for index-based processing.
--
-- Returns the number of elements, the elements packed into an array with
-- bounds @(1, n)@, and the distance function lifted to work on indices
-- into that array.  The count and the array are forced before the tuple
-- is returned.
indexDistance :: [a] -> (a -> a -> Distance)
              -> (Index, Array Index a, Index -> Index -> Distance)
indexDistance xs dist = count `seq` elems `seq` (count, elems, byIndex)
  where
    count = length xs
    elems = listArray (1, count) xs
    byIndex p q = dist (elems ! p) (elems ! q)
-- | The value @1 / 0@ at type 'Distance'.  'slink' and 'clink' store it in
-- the @lambda@ array for every index when that index is first added.
--
-- NOTE(review): this is positive infinity only if 'Distance' is an IEEE
-- floating-point type -- confirm against the Types module.
infinity :: Distance
infinity = 1 / 0
-- | Single-linkage pass: builds the 'PointerRepresentation' of the input
-- by adding the elements one at a time (the name suggests Sibson's SLINK
-- algorithm -- confirm against the paper).
--
-- The distance function is called once for every pair @(j, i)@ with
-- @j < i@, so the work is quadratic in the number of elements.
--
-- Fix: the inner loops must range over @[1 .. i - 1]@, i.e. every index
-- that was added before @i@; the previous text referred to an undefined
-- name where @i - 1@ was intended.
slink :: [a] -> (a -> a -> Distance) -> ST s (PointerRepresentation s a)
slink xs dist = initPR n xs' >>= go 1
  where
    (n, xs', dist') = indexDistance xs dist

    -- Add element @i@; stops once every element has been added.
    go !i !pr
      | i == n + 1 = return pr
      | otherwise = do
          -- The new element starts out pointing at itself.
          writeArray (pi pr) i i
          writeArray (lambda pr) i infinity
          -- Distances from every earlier element to the new one.
          forM_ [1 .. i - 1] $ \j ->
            writeArray (em pr) j (dist' j i)
          -- Update lambda/pi of earlier elements, pushing the leftover
          -- distance on to the element each one points at.
          forM_ [1 .. i - 1] $ \j -> do
            lambda_j <- readArray (lambda pr) j
            em_j     <- readArray (em pr) j
            pi_j     <- readArray (pi pr) j
            em_pi_j  <- readArray (em pr) pi_j
            if lambda_j >= em_j
              then do
                writeArray (em pr) pi_j (em_pi_j `min` lambda_j)
                writeArray (lambda pr) j em_j
                writeArray (pi pr) j i
              else
                writeArray (em pr) pi_j (em_pi_j `min` em_j)
          -- Re-point every element whose lambda is not below that of its
          -- current target.
          forM_ [1 .. i - 1] $ \j -> do
            pi_j        <- readArray (pi pr) j
            lambda_j    <- readArray (lambda pr) j
            lambda_pi_j <- readArray (lambda pr) pi_j
            when (lambda_j >= lambda_pi_j) $
              writeArray (pi pr) j i
          go (i + 1) pr
-- | Complete-linkage pass: builds the 'PointerRepresentation' of the input
-- by adding the elements one at a time (the name suggests Defays' CLINK
-- algorithm -- confirm against the paper).
--
-- The distance function is called once for every pair @(j, i)@ with
-- @j < i@.
--
-- Fix: every loop bound and recursive step that must be @i - 1@ / @j - 1@
-- is now written as such; the previous text used undefined names (and, in
-- @go_b_loop@, applied an index to @1@) where the subtraction was
-- intended.
clink :: [a] -> (a -> a -> Distance) -> ST s (PointerRepresentation s a)
clink xs dist = initPR n xs' >>= go 1
  where
    (n, xs', dist') = indexDistance xs dist

    -- Add element @i@; stops once every element has been added.
    go !i !pr
      | i == n + 1 = return pr
      | i == 1 = do
          -- The first element has nothing to be compared with.
          writeArray (pi pr) 1 1
          writeArray (lambda pr) 1 infinity
          go 2 pr
      | otherwise = do
          -- First part: initialise the new element and collect the
          -- distances from every earlier element to it.
          writeArray (pi pr) i i
          writeArray (lambda pr) i infinity
          forM_ [1 .. i - 1] $ \j ->
            writeArray (em pr) j (dist' j i)
          forM_ [1 .. i - 1] $ \j -> do
            lambda_j <- readArray (lambda pr) j
            em_j     <- readArray (em pr) j
            when (lambda_j < em_j) $ do
              pi_j    <- readArray (pi pr) j
              em_pi_j <- readArray (em pr) pi_j
              writeArray (em pr) pi_j (em_pi_j `max` em_j)
              writeArray (em pr) j infinity

          -- Loop a: walk down from @i - 1@ looking for the index with the
          -- smallest remaining distance.
          a <- readArray (em pr) (i - 1) >>= go_a_loop (i - 1) pr (i - 1)

          -- Loop b: attach @a@ to @i@, then shift the former (pi, lambda)
          -- values along the chain that started at @a@.
          b <- readArray (pi pr) a
          c <- readArray (lambda pr) a
          writeArray (pi pr) a i
          writeArray (lambda pr) a =<< readArray (em pr) a
          go_b_loop i pr a b c

          -- Final part: re-point elements whose target now points at @i@
          -- and whose lambda is not below the target's.
          forM_ [1 .. i - 1] $ \j -> do
            pi_j    <- readArray (pi pr) j
            pi_pi_j <- readArray (pi pr) pi_j
            when (pi_pi_j == i) $ do
              lambda_j    <- readArray (lambda pr) j
              lambda_pi_j <- readArray (lambda pr) pi_j
              when (lambda_j >= lambda_pi_j) $
                writeArray (pi pr) j i
          go (i + 1) pr

    -- @go_a_loop j pr a em_a@ scans @j, j - 1, .. 1@ and returns the best
    -- candidate found; @a@ is the current candidate and @em_a@ its
    -- distance.  Entries that fail the lambda test are set to 'infinity'.
    go_a_loop 0 _ a _ = return a
    go_a_loop !j !pr !a !em_a = do
      pi_j     <- readArray (pi pr) j
      lambda_j <- readArray (lambda pr) j
      em_pi_j  <- readArray (em pr) pi_j
      if lambda_j >= em_pi_j
        then do
          em_j <- readArray (em pr) j
          if em_j < em_a
            then go_a_loop (j - 1) pr j em_j
            else go_a_loop (j - 1) pr a em_a
        else do
          writeArray (em pr) j infinity
          go_a_loop (j - 1) pr a em_a

    -- @go_b_loop i pr a b c@ points @b@ at @i@ with lambda @c@ and
    -- continues with the values @b@ held before; it does nothing when
    -- @a@ is already @i - 1@ (or beyond) and stops after handling
    -- @b >= i - 1@.
    go_b_loop !i !pr !a !b !c
      | a >= i - 1 = return ()
      | b < i - 1 = do
          pi_b     <- readArray (pi pr) b
          lambda_b <- readArray (lambda pr) b
          writeArray (pi pr) b i
          writeArray (lambda pr) b c
          go_b_loop i pr a pi_b lambda_b
      | otherwise = do
          writeArray (pi pr) b i
          writeArray (lambda pr) b c
-- | Turn a finished 'PointerRepresentation' into a 'Dendrogram'.
--
-- The @(index, lambda, pi)@ triples are sorted by ascending @lambda@ and
-- the first @n - 1@ of them are processed in that order; step @i@ joins
-- the cluster currently holding index @j@ with the one holding @pi j@
-- into a 'Branch' labelled with @lambda j@.
--
-- Fixes:
--
-- * The number of merge steps is @n - 1@; the previous text referred to
--   an undefined name where @n - 1@ was intended.
--
-- * The bounds are no longer matched against the literal pattern
--   @(1, n)@.  A failable pattern in @do@ needs a 'MonadFail' instance,
--   which strict 'ST' does not provide on current GHC releases; the lower
--   bound is checked explicitly instead.
buildDendrogram :: PointerRepresentation s a
                -> ST s (Dendrogram a)
buildDendrogram pr = do
  (lo, n) <- getBounds (lambda pr)
  when (lo /= 1) $
    mkErr "buildDendrogram: arrays must be 1-based"
  lambdas <- getElems (lambda pr)
  pis     <- getElems (pi pr)
  let sorted = sortBy (\(_, l1, _) (_, l2, _) -> l1 `compare` l2) $
               zip3 [1 ..] lambdas pis
  -- @index@ maps an element index either to itself (positive: still a
  -- bare leaf) or to the negated number of the merge step whose result
  -- currently lives at that index.
  index <- newListArray (1, n) [1 ..]
  let -- @im@ holds the not-yet-consumed subtrees keyed by merge step.
      go im [] =
        case IM.toList im of
          [(_, x)] -> return x
          _        -> mkErr "buildDendrogram: final never here"
      go im ((i, (j, lambda_j, pi_j)) : rest) = do
        left_i  <- readArray index j
        right_i <- readArray index pi_j
        -- The merged cluster is recorded at @pi_j@; 'asTypeOf' pins the
        -- otherwise ambiguous array type of @index@.
        writeArray (index `asTypeOf` pi pr) pi_j (negate i)
        let (left, im')
              | left_i > 0 = (Leaf $ elm pr ! left_i, im)
              | otherwise  = first (fromMaybe e1) $
                  IM.updateLookupWithKey (\_ _ -> Nothing) ix im
              where
                ix = negate left_i
            (right, im'')
              | right_i > 0 = (Leaf $ elm pr ! right_i, im')
              | otherwise   = first (fromMaybe e2) $
                  IM.updateLookupWithKey (\_ _ -> Nothing) ix im'
              where
                ix = negate right_i
            im''' = IM.insert i (Branch lambda_j left right) im''
            e1 = mkErr "buildDendrogram: never here 1"
            e2 = mkErr "buildDendrogram: never here 2"
        go im''' rest
  go IM.empty (zip [1 .. n - 1] sorted)
-- | Dendrogram of the given elements, computed with 'slink'.
--
-- An empty list is an error; a single element yields a 'Leaf' without
-- the distance function ever being called.
singleLinkage :: [a] -> (a -> a -> Distance) -> Dendrogram a
singleLinkage xs dist =
  case xs of
    []  -> mkErr "singleLinkage: empty input"
    [x] -> Leaf x
    _   -> runST $ do
      pr <- slink xs dist
      buildDendrogram pr
-- | Dendrogram of the given elements, computed with 'clink'.
--
-- An empty list is an error; a single element yields a 'Leaf' without
-- the distance function ever being called.
completeLinkage :: [a] -> (a -> a -> Distance) -> Dendrogram a
completeLinkage xs dist =
  case xs of
    []  -> mkErr "completeLinkage: empty input"
    [x] -> Leaf x
    _   -> runST $ do
      pr <- clink xs dist
      buildDendrogram pr