--------------------------------------------------------------------------------
-- |
-- Module      :  AI.Clustering.Hierarchical.Internal
-- Copyright   :  (c) 2015 Kai Zhang
-- License     :  MIT
--
-- Maintainer  :  kai@kzhang.org
-- Stability   :  experimental
-- Portability :  portable
--
--------------------------------------------------------------------------------
module AI.Clustering.Hierarchical.Internal
{-# WARNING "To be used by developer only" #-}
    ( nnChain
    , single
    , complete
    , average
    , weighted
    , ward
    ) where

import Control.Monad (forM_, when)
import qualified Data.Map as M
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import AI.Clustering.Hierarchical.Types

-- | Clusters that have not been merged away yet, keyed by node index.
-- 'nnChain' seeds it with one 'Leaf' per index and stores every merged
-- 'Branch' under the larger index of the merged pair.
type ActiveNodeSet = M.Map Int (Dendrogram Int)
-- | Distance-matrix update applied after a merge. As called from 'nnChain',
-- the arguments are: the smaller index of the merged pair, the larger index
-- (which keeps representing the merged cluster), the active node set as it
-- was before the merge, and the distance matrix to update.
type DistUpdateFn = Int -> Int -> ActiveNodeSet -> DistanceMat -> DistanceMat

-- | Nearest-neighbor chain algorithm.
--
-- Starting from one 'Leaf' per index in @[0 .. n-1]@, a chain of successive
-- nearest neighbors is grown until the nearest neighbor of its tip is the
-- node right before the tip. That pair is merged into a 'Branch', the
-- distance matrix is refreshed with the supplied update function, and the
-- walk resumes from what is left of the chain. The single node remaining at
-- the end is the result.
--
-- The input vector goes through 'U.force' before the first update —
-- presumably so that the destructive update functions never touch the
-- caller's vector; confirm that 'U.force' copies in this situation.
--
-- Calls 'error' when @n <= 0@, because there is nothing to cluster.
nnChain :: DistanceMat -> DistUpdateFn -> Dendrogram Int
nnChain (DistanceMat n dist) fn
    | n <= 0 = error "nnChain: no observations to cluster (n <= 0)"
    | otherwise = go (DistanceMat n $ U.force dist) initSet []
  where
    -- The chain is kept newest-first: @b@ is its tip and @a@ the node that
    -- was visited just before it.
    go ds activeNodes chain@(b:a:rest)
        -- Only one node is left: it is the finished dendrogram.
        | M.size activeNodes == 1 = snd $ M.elemAt 0 activeNodes
        -- The nearest neighbor of @b@ is its predecessor @a@: merge the pair
        -- and continue from the remainder of the chain.
        | c == a = go ds' activeNodes' rest
        -- Otherwise keep walking towards the nearest neighbor of the tip.
        | otherwise = go ds activeNodes $ c : chain
      where
        -- @a@ is passed as the preferred node so that it wins distance ties.
        (c,d) = nearestNeighbor ds b a activeNodes

        -- We always remove the node with the smaller index. The other one
        -- will be used to represent the merged result.
        activeNodes' = M.insert hi (Branch (size1+size2) d c1 c2)
                     . M.delete lo $ activeNodes

        -- The update function receives the node set as it was before the
        -- merge, so the sizes of both merged clusters are still available.
        ds' = fn lo hi activeNodes ds
        c1 = activeNode lo
        c2 = activeNode hi
        size1 = size c1
        size2 = size c2
        (lo,hi) = if a <= b then (a,b) else (b,a)

        -- Both merged nodes come off the chain and are therefore expected to
        -- be active; fail with a descriptive message if that ever breaks.
        activeNode k = M.findWithDefault
            (error $ "nnChain: node " ++ show k ++ " is not in the active set")
            k activeNodes
    -- Fewer than two nodes on the chain: restart it from the active node with
    -- the smallest index and that node's nearest neighbor. (-1) is never a
    -- key, so no node is favoured on ties. With a single active node the
    -- neighbor is the (-1) seed, and the size check above ends the recursion
    -- before it is ever used.
    go ds activeNodes _ = go ds activeNodes [b,a]
      where
        a = fst $ M.elemAt 0 activeNodes
        b = fst $ nearestNeighbor ds a (-1) activeNodes

    initSet = M.fromList . map (\i -> (i, Leaf i)) $ [0..n-1]
{-# INLINE nnChain #-}

-- | Scan the given nodes for the one closest to the query node.
--
-- The fold starts from the seed @(-1, 1/0)@ and skips the query itself. A
-- strictly smaller distance always replaces the current best; an equal
-- distance replaces it only when the candidate is the preferred node. The
-- seed is returned unchanged if no candidate qualifies.
nearestNeighbor :: DistanceMat                  -- ^ distance matrix
                -> Int                          -- ^ query
                -> Int                          -- ^ this would be selected if
                                                -- it achieves the minimal distance
                -> M.Map Int (Dendrogram Int)   -- ^ nodes to search
                -> (Int, Double)                -- ^ best node and its distance
nearestNeighbor dist i preference = M.foldlWithKey' pick (-1,1/0)
  where
    pick best@(_,bestDist) j _
        | j == i = best  -- never pair the query with itself
        | closer || tiedWithPreferred = (j,candidate)
        | otherwise = best
      where
        candidate = dist ! (i,j)
        closer = candidate < bestDist
        tiedWithPreferred = candidate == bestDist && j == preference
{-# INLINE nearestNeighbor #-}

-- NOTE: every update function below overwrites the distance matrix it is
-- given in place, and hence should not be called by end users.

-- | Single-linkage update: after @lo@ and @hi@ are merged, the entry between
-- every other node and @hi@ becomes the smaller of that node's entries for
-- @lo@ and @hi@.
--
-- The vector of the given matrix is thawed with 'U.unsafeThaw' and
-- overwritten in place.
single :: DistUpdateFn
single lo hi nodeset (DistanceMat n dist) = DistanceMat n $ U.create $ do
    buf <- U.unsafeThaw dist
    forM_ (M.keys nodeset) $ \k ->
        -- the two merged nodes themselves are left untouched
        when (not (k == lo || k == hi)) $ do
            let slot = idx n k hi
            viaLo <- UM.unsafeRead buf (idx n k lo)
            viaHi <- UM.unsafeRead buf slot
            UM.unsafeWrite buf slot (min viaLo viaHi)
    return buf
{-# INLINE single #-}

-- | Complete-linkage update: after @lo@ and @hi@ are merged, the entry
-- between every other node and @hi@ becomes the larger of that node's entries
-- for @lo@ and @hi@.
--
-- The vector of the given matrix is thawed with 'U.unsafeThaw' and
-- overwritten in place.
complete :: DistUpdateFn
complete lo hi nodeset (DistanceMat n dist) = DistanceMat n $ U.create $ do
    buf <- U.unsafeThaw dist
    forM_ (M.keys nodeset) $ \k ->
        -- the two merged nodes themselves are left untouched
        when (not (k == lo || k == hi)) $ do
            let slot = idx n k hi
            viaLo <- UM.unsafeRead buf (idx n k lo)
            viaHi <- UM.unsafeRead buf slot
            UM.unsafeWrite buf slot (max viaLo viaHi)
    return buf
{-# INLINE complete #-}

-- | Average-linkage update: after @lo@ and @hi@ are merged, the entry between
-- every other node and @hi@ becomes the mean of that node's entries for @lo@
-- and @hi@, weighted by the sizes of the two merged clusters.
--
-- The vector of the given matrix is thawed with 'U.unsafeThaw' and
-- overwritten in place. Calls 'error' if @lo@ or @hi@ is missing from the
-- node set.
average :: DistUpdateFn
average lo hi nodeset (DistanceMat n dist) = DistanceMat n $ U.create $ do
    buf <- U.unsafeThaw dist
    forM_ (M.keys nodeset) $ \k ->
        -- the two merged nodes themselves are left untouched
        when (not (k == lo || k == hi)) $ do
            let slot = idx n k hi
            viaLo <- UM.unsafeRead buf (idx n k lo)
            viaHi <- UM.unsafeRead buf slot
            UM.unsafeWrite buf slot (wLo * viaLo + wHi * viaHi)
    return buf
  where
    -- Size of an active cluster; a missing key is reported by name instead of
    -- surfacing as a bare 'undefined'.
    clusterSize k = fromIntegral . size $ M.findWithDefault
        (error $ "average: node " ++ show k ++ " is not in the active set")
        k nodeset
    nLo = clusterSize lo
    nHi = clusterSize hi
    wLo = nLo / (nLo+nHi)
    wHi = nHi / (nLo+nHi)
{-# INLINE average #-}

-- | Weighted-linkage update: after @lo@ and @hi@ are merged, the entry
-- between every other node and @hi@ becomes the plain arithmetic mean of that
-- node's entries for @lo@ and @hi@, regardless of cluster sizes.
--
-- The vector of the given matrix is thawed with 'U.unsafeThaw' and
-- overwritten in place.
weighted :: DistUpdateFn
weighted lo hi nodeset (DistanceMat n dist) = DistanceMat n $ U.create $ do
    buf <- U.unsafeThaw dist
    forM_ (M.keys nodeset) $ \k ->
        -- the two merged nodes themselves are left untouched
        when (not (k == lo || k == hi)) $ do
            let slot = idx n k hi
            viaLo <- UM.unsafeRead buf (idx n k lo)
            viaHi <- UM.unsafeRead buf slot
            UM.unsafeWrite buf slot ((viaLo + viaHi) / 2)
    return buf
{-# INLINE weighted #-}

-- | Ward-linkage update: after @lo@ and @hi@ are merged, the entry between
-- every other node @k@ and @hi@ becomes
--
-- > ((s_lo+s_k)*d(lo,k) + (s_hi+s_k)*d(hi,k) - s_k*d(lo,hi)) / (s_lo+s_hi+s_k)
--
-- where @s_x@ is the size of cluster @x@ and @d(lo,hi)@ is read once before
-- any entry is overwritten.
--
-- NOTE(review): the formula is linear in the stored values, which matches the
-- Lance-Williams recurrence for Ward's method on squared distances — confirm
-- what the caller stores in the matrix.
--
-- The vector of the given matrix is thawed with 'U.unsafeThaw' and
-- overwritten in place. Calls 'error' if @lo@ or @hi@ is missing from the
-- node set.
ward :: DistUpdateFn
ward lo hi nodeset (DistanceMat n dist) = DistanceMat n $ U.create $ do
    buf <- U.unsafeThaw dist
    dLoHi <- UM.unsafeRead buf (idx n lo hi)
    forM_ (M.toList nodeset) $ \(k,subtree) ->
        -- the two merged nodes themselves are left untouched
        when (not (k == lo || k == hi)) $ do
            let nK = fromIntegral (size subtree)
                slot = idx n k hi
            dLoK <- UM.unsafeRead buf (idx n k lo)
            dHiK <- UM.unsafeRead buf slot
            UM.unsafeWrite buf slot $
                ((nLo+nK)*dLoK + (nHi+nK)*dHiK - nK*dLoHi) / (nLo+nHi+nK)
    return buf
  where
    -- Size of an active cluster; a missing key is reported by name instead of
    -- surfacing as a bare 'undefined'.
    clusterSize k = fromIntegral . size $ M.findWithDefault
        (error $ "ward: node " ++ show k ++ " is not in the active set")
        k nodeset
    nLo = clusterSize lo
    nHi = clusterSize hi
{-# INLINE ward #-}

{-
-- O(n^2) time, O(n) space. Minimum spanning tree algorithm for single linkage
mst :: [a] -> DistFn a -> Dendrogram a
mst xs fn = undefined
-}