module AI.Clustering.Hierarchical.Types
( Distance
, DistFn
, Size
, Dendrogram(..)
, size
, DistanceMat(..)
, (!)
, idx
, computeDists
, computeDists'
) where
import Control.Monad (liftM, liftM4)
import Control.Parallel.Strategies (rdeepseq, parMap)
import Data.Binary (Binary, put, get, getWord8)
import Data.Bits (shiftR)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Generic as G
import Data.Word (Word8)
-- | A dissimilarity value between two observations.
type Distance = Double
-- | A function that computes the 'Distance' between two values of type @a@.
type DistFn a = a -> a -> Distance
-- | The size recorded in a 'Branch' node and reported by 'size'.
-- NOTE(review): presumably the number of leaves under the node ('size'
-- returns 1 for a 'Leaf') -- confirm against the code that builds the tree.
type Size = Int
-- | A binary tree whose leaves carry values of type @a@.
--
-- All fields are strict.  A 'Branch' stores a 'Size', a 'Distance' and its
-- two subtrees, in that order.
-- NOTE(review): the 'Distance' is presumably the dissimilarity at which the
-- two subtrees were merged -- confirm against the clustering code.
data Dendrogram a
    = Leaf !a
    | Branch !Size !Distance !(Dendrogram a) !(Dendrogram a)
    deriving (Show, Eq)
-- | Binary encoding of a 'Dendrogram': a one-byte constructor tag (@0@ for
-- 'Leaf', @1@ for 'Branch') followed by the constructor's fields in
-- declaration order.
instance Binary a => Binary (Dendrogram a) where
    put (Leaf a) = do
        put (0 :: Word8)
        put a
    put (Branch s d l r) = do
        put (1 :: Word8)
        put s
        put d
        put l
        put r

    get = do
        tag <- getWord8
        case tag of
            0 -> liftM Leaf get
            1 -> liftM4 Branch get get get get
            -- Signal an unknown tag through the decoder's own failure
            -- channel so that 'Data.Binary.decodeOrFail' can report it,
            -- instead of throwing from pure code via 'error'.
            _ -> fail "fail to decode the dendrogram"
-- | Applies a function to the value in every 'Leaf'.  The shape of the tree
-- and the size and distance stored in each 'Branch' are left unchanged.
instance Functor Dendrogram where
    fmap g tree = case tree of
        Leaf x -> Leaf (g x)
        Branch n d lhs rhs -> Branch n d (fmap g lhs) (fmap g rhs)
-- | The size of a dendrogram: the value stored in a 'Branch', or 1 for a
-- single 'Leaf'.  Runs in constant time because nothing is recomputed.
size :: Dendrogram a -> Int
size tree = case tree of
    Branch total _ _ _ -> total
    Leaf _ -> 1
-- | A pairwise distance matrix in condensed form.
--
-- The first field is the number of observations @n@.  The second field is a
-- flat unboxed vector of distances; as built by 'computeDists' it holds the
-- entries for every pair @(i, j)@ with @i < j@, row by row, and no diagonal
-- entries.  Use '!' (or 'idx') to address it.
data DistanceMat
    = DistanceMat !Int !(U.Vector Double)
    deriving Show
-- | Look up the distance stored for a pair of observation indices.
--
-- The flat position is computed by 'idx', which treats @(i, j)@ and @(j, i)@
-- alike.  The matrix built by 'computeDists' stores no diagonal entries, so
-- callers should pass two distinct indices.  The lookup itself is
-- bounds-checked by 'U.!'.
(!) :: DistanceMat -> (Int, Int) -> Double
(!) (DistanceMat n dists) (row, col) = dists U.! offset
  where
    offset = idx n row col
-- | Position of the pair @(i, j)@ in the condensed (strict upper-triangle,
-- row-major) layout of an @n@-by-@n@ distance matrix.
--
-- The argument order of @i@ and @j@ does not matter.  For @lo < hi@ the
-- position is
--
-- > lo * (2 * n - lo - 3) / 2 + hi - 1
--
-- which is the number of entries in rows @0 .. lo - 1@ plus the offset
-- @hi - lo - 1@ inside row @lo@.  The product @lo * (2 * n - lo - 3)@ is
-- always even, so the right shift is an exact division by two.
--
-- The diagonal (@i == j@) is not part of the layout; the value returned for
-- it does not correspond to a stored entry.
idx :: Int -> Int -> Int -> Int
idx n i j = (lo * (2 * n - lo - 3)) `shiftR` 1 + hi - 1
  where
    lo = min i j
    hi = max i j
-- | Build the condensed distance matrix of all observations in a vector.
--
-- The distance function is applied once to every pair @(i, j)@ with
-- @i < j@, in row-major order, which is the layout 'idx' addresses.  The
-- result therefore holds @n * (n - 1) / 2@ entries, and is empty when the
-- input has fewer than two elements.
computeDists :: G.Vector v a => DistFn a -> v a -> DistanceMat
computeDists f vec = DistanceMat n $ U.fromList
    [ f (at i) (at j) | i <- [0 .. n - 1], j <- [i + 1 .. n - 1] ]
  where
    n = G.length vec
    -- Unchecked indexing is safe here: i and j never leave [0, n - 1].
    at = G.unsafeIndex vec
-- | Parallel variant of 'computeDists'.
--
-- Each row of the upper triangle (all @j > i@ for a fixed @i@) is evaluated
-- to normal form as its own spark via @'parMap' 'rdeepseq'@.  The rows are
-- then concatenated in order, so the entries and their layout are the same
-- as those produced by 'computeDists'.
computeDists' :: G.Vector v a => DistFn a -> v a -> DistanceMat
computeDists' f vec =
    DistanceMat n . U.fromList . concat $ parMap rdeepseq row [0 .. n - 1]
  where
    n = G.length vec
    -- Unchecked indexing is safe here: i and j never leave [0, n - 1].
    at = G.unsafeIndex vec
    -- Distances from observation i to every later observation.
    row i = [ f (at i) (at j) | j <- [i + 1 .. n - 1] ]