module AI.Clustering.Hierarchical.Types
    ( Distance
    , DistFn
    , Size
    , Dendrogram(..)
    , size
    , DistanceMat(..)
    , (!)
    , idx
    , computeDists
    , computeDists'
    ) where

import Control.Monad (liftM, liftM4)
import Control.Parallel.Strategies (rdeepseq, parMap)
import Data.Binary (Binary, put, get, getWord8)
import Data.Bits (shiftR)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Generic as G
import Data.Word (Word8)

-- | A distance between two items, as produced by a 'DistFn'.
type Distance = Double
-- | A function measuring the distance between two values of type @a@.
type DistFn a = a -> a -> Distance
-- | The size annotation stored in a 'Branch' node (see 'size').
type Size = Int

-- | A binary tree whose leaves carry values of type @a@. All fields are
-- strict.
data Dendrogram a = Leaf !a
                  -- ^ A single item.
                  | Branch !Size !Distance !(Dendrogram a) !(Dendrogram a)
                  -- ^ An internal node: the size reported by 'size', a
                  -- 'Distance' (presumably the distance at which the two
                  -- subtrees were merged -- TODO confirm against the
                  -- clustering code), then the left and right subtrees.
    deriving (Show, Eq)

-- | Serialises a dendrogram as a one-byte constructor tag followed by the
-- constructor's fields in declaration order: tag 0 for 'Leaf', tag 1 for
-- 'Branch'. Decoding any other tag fails inside the 'Data.Binary.Get.Get'
-- monad.
instance Binary a => Binary (Dendrogram a) where
    put (Leaf a) = do put (0 :: Word8)
                      put a
    put (Branch s d l r) = do put (1 :: Word8)
                              put s
                              put d
                              put l
                              put r

    get = do tag <- getWord8
             case tag of
                 0 -> liftM Leaf get
                 1 -> liftM4 Branch get get get get
                 -- Report an unknown tag through the decoder's own failure
                 -- channel rather than 'error', so that callers using
                 -- 'Data.Binary.decodeOrFail' receive a 'Left' (with the
                 -- byte offset) instead of an exception thrown from pure
                 -- code.
                 _ -> fail "fail to decode the dendrogram"

-- | Maps a function over every leaf value; the size and distance stored in
-- each 'Branch' are carried over unchanged.
instance Functor Dendrogram where
    fmap g tree = case tree of
        Leaf a                 -> Leaf (g a)
        Branch sz dist lhs rhs -> Branch sz dist (fmap g lhs) (fmap g rhs)

-- | /O(1)/. The size of a dendrogram: 1 for a 'Leaf', and for a 'Branch' the
-- size stored in the node itself (the subtrees are not traversed).
size :: Dendrogram a -> Int
size tree = case tree of
    Leaf _             -> 1
    Branch total _ _ _ -> total
{-# INLINE size #-}

-- | Pairwise distances stored as a flat strict upper triangle: only entries
-- for pairs @(i, j)@ with @i < j@ are kept, row by row, in the order
-- produced by 'computeDists'. The 'Int' is the number of items the matrix
-- was built from (the @n@ handed to 'idx'); the vector holds the distances.
data DistanceMat = DistanceMat !Int !(U.Vector Double) deriving (Show)

-- | Look up the distance stored for a pair of item indices. The pair may be
-- given in either order, because 'idx' sorts its two arguments before
-- computing the position in the flat vector.
(!) :: DistanceMat -> (Int, Int) -> Double
DistanceMat n dists ! (row, col) = dists U.! idx n row col
{-# INLINE (!) #-}

-- | Position of the pair @(i, j)@ inside the flat vector of a 'DistanceMat'
-- built from @n@ items. The two indices are put in ascending order first, so
-- @idx n i j == idx n j i@.
--
-- Row @lo@ starts at half of @lo * (2 * n - lo - 1)@ and the entry for column
-- @hi@ sits @hi - lo - 1@ places further on; the expression below is that sum
-- simplified. The product being halved is always even, so the shift is exact.
--
-- The diagonal is not stored, so @i == j@ does not address an entry of its
-- own.
idx :: Int -> Int -> Int -> Int
idx n i j = (lo * (2 * n - lo - 3)) `shiftR` 1 + hi - 1
  where
    lo = min i j
    hi = max i j
{-# INLINE idx #-}

-- | Build the 'DistanceMat' for a vector of items by applying the distance
-- function to every pair @(i, j)@ with @i < j@, row by row. An input of
-- fewer than two items yields an empty distance vector.
computeDists :: G.Vector v a => DistFn a -> v a -> DistanceMat
computeDists f vec = DistanceMat n (U.fromList dists)
  where
    n = G.length vec
    -- Both indices range over [0 .. n - 1], so the unchecked lookup stays
    -- inside the vector.
    at k = vec `G.unsafeIndex` k
    dists = [ f (at i) (at j) | i <- [0 .. n - 1], j <- [i + 1 .. n - 1] ]
{-# INLINE computeDists #-}

-- | Like 'computeDists', but evaluates the rows of the triangle in parallel:
-- each row @i@ (the distances from item @i@ to every later item) is one
-- 'parMap' task, fully evaluated with 'rdeepseq', and the rows are then
-- concatenated in order.
computeDists' :: G.Vector v a => DistFn a -> v a -> DistanceMat
computeDists' f vec = DistanceMat n . U.fromList . concat $ parMap rdeepseq row [0 .. n - 1]
  where
    n = G.length vec
    -- Both indices range over [0 .. n - 1], so the unchecked lookup stays
    -- inside the vector.
    at k = vec `G.unsafeIndex` k
    row i = [ f (at i) (at j) | j <- [i + 1 .. n - 1] ]
{-# INLINE computeDists' #-}