-- While working on this module you are encouraged to remove it and fix
-- any warnings in the module. See
--     http://hackage.haskell.org/trac/ghc/wiki/WorkingConventions#Warnings
-- for details  

-----------------------------------------------------------------------------
-- |
-- Module      :  Clustering
-- Copyright   :  (c) Philipp Pribbernow
-- License     :  BSD-style (see the file libraries/base/LICENSE)
-- 
-- Maintainer  :  libraries@haskell.org
-- Stability   :  experimental
-- Portability :  portable
--
-- Hieraclus is a library that supports clustering of arbitrary elements in Haskell. The difference to the already
-- existing cluster library /hierarchical-clustering/ is the ability to work with abort criteria, which allow an
-- \"intelligent\" clustering. With the help of abort criteria the user can specify conditions that must be fulfilled
-- in order to stop the clustering process.
-- 
-- Another motivation for creating this library was to make the cluster process run in /O(n^2)/. However, the current
-- implementation runs in /O(n^2 * log n)/. It has to be mentioned that the real runtime tends to grow
-- faster, presumably due to memory management. Some profiling showed that quite a large amount of memory is
-- spent managing the maps. The principal idea was not to work with a matrix, but with two maps instead. The
-- first map holds the mappings from cluster pairs to distances, the second map vice versa, thus allowing to find
-- the minimal distance in /O(log n)/ and not in /O(n^2)/. To make things more efficient, the data to be clustered
-- is initially transformed to vector space, as all clustering operations work in vector space. The actual clustering
-- is thus done with the vector representations of the input data, which are finally transformed back.
--
-- The above mentioned information for the abort criteria, the maps and the element mappings are carried through
-- the cluster process in a cluster state. So the actual cluster process takes place within the state monad.
-- However, the library offers a function 'cluster' that is purely functional as it returns a tuple.
-- The first element of the tuple is the cluster result - simply implemented as a list of lists.
-- The second element of the tuple holds the cluster information used by the abort criteria.
-----------------------------------------------------------------------------
{-# LANGUAGE DoAndIfThenElse #-}
module Numeric.Statistics.Clustering.Clustering (
                    -- * Cluster State
                    ClusterState(..),
                    ClusterInfo(..),
                    ClusterResult,
                    
                    -- * Cluster Map
                    Cluster(..),
                    ClusterMap(..),
                    ID,
                    singleton,
                    fromList,
                    getCluster,
                    getClusterUnsafe,
                    mergeClusters,
                    extractClusterElements,
      
                    -- * Minimum and Combination Map
                    MinimumMap(..),
                    CombinationMap(..),
                    Pair(..),
                    
                    -- * Abort Criterias
                    noAbort,
                    maxTotal,
                    nCluster,
                    nSteps,
                    calinski,
                    ellbow,
                    
                    -- * Cluster Methods
                    DistanceFunction(..),
                    SimilarityFunction(..),
                    singleLinkage,
                    completeLinkage,
                    averageLinkage,
                    wardLinkage,
                    
                    -- ** Cluster Method Construction
                    pairwise,
                    clusterwise,
                    
                    -- ** Cost Functions
                    addition,
                    varianceSum,
                   
                    -- * Clustering Process
                    Transformation(..),
                    cluster,
                    runCluster    
                
                  ) where

-- this data structure is used to map cluster ids to clusters and has a
-- space complexity of /O(n)/.
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap

-- this data structure is used to map pairs of cluster ids to their distance
-- and has a space complexity of /O(n^2)/.
import Data.Map (Map)
import qualified Data.Map as Map

-- this data structure is used to store the calculated distances between the
-- clusters and thus represents the distance matrix
import Data.MultiSet (    
                        MultiSet (..)
                     )  
import qualified Data.MultiSet as MS
import Control.Monad.State
import Data.Maybe (fromJust)
import Numeric.Statistics.Clustering.VectorUtils (
                      Vector(..), 
                      meanSquareV,
                      average
                   )
import qualified Numeric.Statistics.Clustering.VectorUtils as VU

{----------------------------------------------------------------------------
  ClusterMap
-----------------------------------------------------------------------------} 

-- | The cluster map represents the current partition of the elements: it maps
-- each cluster 'ID' to the 'Cluster' stored under that id.
type ClusterMap a = IntMap (Cluster a)

-- | Unique id of a cluster; it is the key type of the underlying 'IntMap'.
type ID = IntMap.Key

-- | A cluster is represented as the list of vectors it contains.
newtype Cluster a = Cluster {
                      vals :: [Vector a]  -- ^ the member vectors of the cluster
                    } deriving (Show)

-- | The resulting clusters are represented as a list of lists: one inner list
-- of elements per cluster.
type ClusterResult a = [[a]]
                    
-- | /O(1)/
-- Creates a cluster from an optional vector: @Just e@ yields a cluster that
-- contains only @e@, @Nothing@ yields a cluster without any elements.
singleton :: Maybe (Vector a) -> Cluster a
singleton (Just e) = Cluster [e]
-- Match 'Nothing' explicitly instead of binding a catch-all variable named
-- @otherwise@, which would shadow the Prelude guard of the same name.
singleton Nothing  = Cluster []

-- | Builds the initial cluster map from a list of vectors. Every vector
-- becomes its own one-element cluster; ids are assigned in list order,
-- starting at 1.
fromList :: [Vector a] -> ClusterMap a
fromList vectors =
    IntMap.fromList [ (key, singleton (Just v)) | (key, v) <- zip [1 ..] vectors ]

-- | /O(min(n,W))/
-- Looks up the cluster stored under the given 'ID'. Returns 'Nothing' when
-- the id is not a key of the cluster map.
getCluster :: ClusterMap a -> ID -> Maybe (Cluster a)
getCluster clusters key = key `IntMap.lookup` clusters
                        
-- | /O(min(n,W))/
-- Partial version of 'getCluster': the result is unwrapped with 'fromJust',
-- so asking for an id that is not in the map raises an error.
getClusterUnsafe :: ClusterMap a -> ID -> (Cluster a)
getClusterUnsafe clusters key = fromJust (getCluster clusters key)
    
-- | Merges the two clusters given by their ids and returns a triple:
--
-- 1. the newly created cluster (the values of the first cluster followed by
--    the values of the second one),
--
-- 2. the cluster map with both ids removed,
--
-- 3. the cluster map in which the new cluster is stored under the first id.
--
-- The state is neither read nor modified. If either id is missing from the
-- map (this includes passing the same id twice) an error is raised via
-- 'mkError'.
mergeClusters ::
    ID ->
    ID ->
    ClusterMap a ->
    State (ClusterState a b) (Cluster a, ClusterMap a, ClusterMap a)
mergeClusters i1 i2 m =
    -- take the second cluster out of the map first
    case IntMap.lookup i2 m of
      Nothing -> notFound i2
      Just absorbed ->
        let withoutSecond = IntMap.delete i2 m
        in
          -- then take out the first cluster
          case IntMap.lookup i1 withoutSecond of
            Nothing -> notFound i1
            Just kept ->
              let rest       = IntMap.delete i1 withoutSecond
                  -- the new cluster contains all values of i1 and i2
                  newCluster = Cluster (vals kept ++ vals absorbed)
              in return (newCluster, rest, IntMap.insert i1 newCluster rest)
  where
    -- the id is separated by spaces so the message stays readable
    notFound key = mkError $ "Cluster " ++ show key ++ " not found"
 
 
-- | Translates every cluster of the cluster map back into the original
-- elements. It runs in the state monad because the mapping from vectors to
-- original values ('idents') is kept in the cluster info of the state.
-- A vector without an entry in that mapping makes 'fromJust' fail.
extractClusterElements :: Ord a =>
      ClusterMap a ->
      State (ClusterState a b) [[b]]
extractClusterElements clumap = do
    assocs <- gets (idents . cinfo)
    let original v = fromJust (Map.lookup v assocs)
    return [ map original (vals c) | c <- IntMap.elems clumap ]

    
{----------------------------------------------------------------------------
  MinimumMap
-----------------------------------------------------------------------------}

-- | The minimum map stores the distance matrix as a multiset of
-- (distance, id pair) entries; a multiset is used because the same distance
-- can occur more than once. Because the distance is the first tuple
-- component, the smallest entry of the set is the pair with the minimum
-- distance, which can be found in /O(log n)/.
-- Note: Alternatively one could use a kind of binary heap to find
-- the minimum distance in /O(1)/.
-- Storage complexity is /O(n^2)/.
type MinimumMap a = MultiSet (a, Pair ID)

-- | A homogeneous pair. Instantiated with 'ID' it identifies the two
-- clusters a distance belongs to.
type Pair a = (a,a)

-- | Like the minimum map but with the id pairs as the keys, thus allowing
-- to find the distance of a given pair in /O(log n)/.
-- Storage complexity is /O(n^2)/.
type CombinationMap a = Map (Pair ID) a

-- | A cluster function calculates the distance between two clusters.
type ClusterFunction a = (Cluster a -> Cluster a -> a)

-- | The cluster state contains all maps that are needed for the clustering
-- together with the information about the clustering process. The
-- 'ClusterState' is passed around within the state monad.
data ClusterState a b = CS {    
                          minmap :: MinimumMap a,       -- ^ holds the mappings from distances to pairs 
                          combis :: CombinationMap a,   -- ^ holds the mappings from pairs to distances
                          cinfo  :: ClusterInfo a b     -- ^ holds information of the clustering process that is needed by the abort criteria
                        } deriving (Show)

-- | The cluster process produces information about the clustering after each
-- step. This information is given to the abort criteria, which decide whether
-- the cluster process may continue or has to stop and return its results.
data ClusterInfo a b = CI {
                         idents :: Map (Vector a) b,          -- ^ holds the mapping from the representation vectors to their actual objects
                         nElems :: Int,                       -- ^ the number of elements to be clustered 
                         cNew :: (Cluster a, [Cluster a]),    -- ^ the newly created cluster and all other clusters (those not part of the merge)
                         costs :: a,                        -- ^ the \"costs\" of the current merge; the clustering routine stores the distance of the merged pair here
                         total :: a,                       -- ^ the accumulated costs as returned by the cost function
                         cStep :: Int,                        -- ^ the current clustering step
                         cHistory :: [a]                      -- ^ history of the accumulated costs ('total') of all steps, most recent first
                       } deriving (Show) 





{----------------------------------------------------------------------------
  Abort Criterias
-----------------------------------------------------------------------------}   

-- | An abort criterion is a constraint for the clustering process that
-- decides how many cluster steps are to be done. The criterion is consulted
-- with the current 'ClusterInfo'; /True/ means that the clustering is aborted.
type AbortCriterium a b = ClusterInfo a b -> Bool

-- | The criterion that never aborts early: the cluster process is only
-- limited by its maximum number of possible steps. It fires as soon as the
-- current step number has reached the number of elements to be clustered.
noAbort :: AbortCriterium a b
noAbort info = nElems info <= cStep info

-- | Limits the accumulated \"costs\" of the clustering: the criterion fires
-- once the 'total' of the cluster info is greater than the given bound. What
-- the total measures is determined by the cost function in use.
maxTotal :: Ord a => a -> AbortCriterium a b
maxTotal bound = (> bound) . total
 
-- | Aborts by cluster count: the criterion fires when the given number is
-- greater than the number of elements minus the current step number.
nCluster :: Int -> AbortCriterium a b
nCluster n info = n > remaining
  where
    remaining = nElems info - cStep info

-- | Limits the number of steps: the criterion fires once the current step
-- number is greater than the given number.
nSteps :: Int -> AbortCriterium a b
nSteps limit = (> limit) . cStep
   
-- | Defines a tolerance for the homogeneity of the clusters, after the
-- criterion developed by Calinski and Harabasz. The criterion fires when
--
-- > (outerV / innerV) * ((n - k) * (k - 1)) > tol
--
-- where @k@ is the current step, @n@ is the number of elements minus @k@,
-- @innerV@ sums 'meanSquareV' over every cluster that was not part of the
-- latest merge, and @outerV@ sums 'meanSquareV' over each of those clusters
-- joined with the newly created one.
--
-- NOTE(review): @n@ already has @k@ subtracted and @(n - k)@ subtracts it
-- again; confirm against the published index that this is intended.
calinski :: (Ord a, Floating a) => a -> AbortCriterium a b
calinski tol cInfo = score > tol
  where
    (merged, others) = cNew cInfo
    k = fromIntegral (cStep cInfo)
    n = fromIntegral (nElems cInfo) - k
    innerV = sum [ meanSquareV (vals c) | c <- others ]
    outerV = sum [ meanSquareV (vals merged ++ vals c) | c <- others ]
    score = (outerV / innerV) * ((n - k) * (k - 1))
             

-- | The elbow criterion looks for a cluster step whose increase in costs is
-- above average.
--
-- The first parameter is the number of steps that are tolerated as a kind of
-- stabilization phase: the criterion can only fire when the current step is
-- greater than it. The second parameter is the maximum allowed multiple of
-- the average inclination.
--
-- The inclinations are the differences between consecutive entries of the
-- cost history; a history with a single entry counts as one inclination equal
-- to that entry. The criterion fires when the first inclination is greater
-- than @factor@ times the average of the remaining ones. If there are no
-- remaining inclinations, the first inclination itself serves as the average.
ellbow :: (Ord a, Num a, Floating a) => Int -> a -> AbortCriterium a b
ellbow minSteps factor cInfo = pastWarmUp && hasHistory && aboveAverage
  where
    pastWarmUp   = cStep cInfo > minSteps
    hasHistory   = not (null history)
    aboveAverage = currInc > factor * histAvg oldIncls

    history = cHistory cInfo
    -- lazy pattern binding: only forced after the checks above succeeded
    (currInc : oldIncls) = inclinations history

    inclinations [x] = [x]
    inclinations xs  = zipWith (-) xs (drop 1 xs)

    histAvg incls = case incls of
      []  -> currInc
      [x] -> x
      _   -> average incls
             
             
             
{----------------------------------------------------------------------------
  Cluster Methods
-----------------------------------------------------------------------------}

-- | A distance function determines how the distance between two vectors is
-- calculated.
type DistanceFunction a = Vector a -> Vector a -> a

-- | A similarity function rates a list of vectors as a whole, e.g. by the sum
-- of their variances. It is used to compare two clusters by evaluating it on
-- the vectors of both clusters together (see 'clusterwise').
type SimilarityFunction a = [Vector a] -> a


-- | Single linkage: the distance between two clusters is the smallest
-- distance between any element of the first and any element of the second
-- cluster. Fails if one of the clusters is empty, since there is no minimum
-- of an empty list.
singleLinkage :: (Ord a, Eq a) => DistanceFunction a -> ClusterFunction a
singleLinkage df left right = minimum distances
  where
    distances = pairwise df left right

-- | Complete linkage: the distance between two clusters is the largest
-- distance between any element of the first and any element of the second
-- cluster. Fails if one of the clusters is empty, since there is no maximum
-- of an empty list.
completeLinkage :: (Ord a, Eq a) => DistanceFunction a -> ClusterFunction a
completeLinkage df left right = maximum distances
  where
    distances = pairwise df left right

-- | Average linkage: the distance between two clusters is the 'average' of
-- the distances between all element pairs of the two clusters.
averageLinkage :: (Ord a, Floating a) => DistanceFunction a -> ClusterFunction a
averageLinkage df left right = average distances
  where
    distances = pairwise df left right

-- | Ward linkage: the distance between two clusters is the given similarity
-- function evaluated on the elements of both clusters taken together. This is
-- exactly 'clusterwise'.
wardLinkage :: (Ord a) => SimilarityFunction a -> ClusterFunction a
wardLinkage = clusterwise

{----------------------------------------------------------------------------
  Cluster Methods Construction
-----------------------------------------------------------------------------}

-- | Evaluates the given distance function for every combination of an element
-- of the first cluster with an element of the second cluster. The results are
-- ordered by the elements of the first cluster, then by those of the second.
pairwise :: Ord a => DistanceFunction a -> Cluster a -> Cluster a -> [a]
pairwise f e1 e2 = concatMap (\x -> map (f x) (vals e2)) (vals e1)
 
-- | Evaluates the given similarity function on the elements of both clusters:
-- the values of the first cluster followed by the values of the second one.
clusterwise :: SimilarityFunction a -> ClusterFunction a
clusterwise f c1 c2 = f (concatMap vals [c1, c2])

{----------------------------------------------------------------------------
  Cost functions
-----------------------------------------------------------------------------} 
-- | A cost function has to decide how the single results produced after each
-- clustering step are accumulated.
-- As called by the clustering routine, the first argument is the total
-- accumulated so far, the second argument is the distance of the pair merged
-- in the current step, and the third argument holds the vectors of every
-- cluster after the merge. The result becomes the new accumulated total.
type CostFunction a = a -> a -> [[Vector a]] -> a

-- | The costs of the single clustering steps are simply added up: the result
-- is the first argument plus the second one. The clusters are ignored.
addition :: Num a => CostFunction a
addition accumulated step _clusters = accumulated + step
      
-- | The costs are determined from the clusters alone: the result is the sum
-- of 'meanSquareV' over all given clusters. The first two arguments are
-- ignored.
varianceSum :: Floating a => CostFunction a
varianceSum _ _ clusters = sum [ meanSquareV c | c <- clusters ]
                       


{----------------------------------------------------------------------------
  Clustering
-----------------------------------------------------------------------------} 

-- | A transformation turns an element of the input data into its vector
-- representation.
type Transformation a b = (a -> Vector b)

-- | Executes the cluster process purely: 'runCluster' followed by
-- 'extractClusterElements' is run on the empty state. The first component of
-- the result holds the clusters as lists of the original elements, the second
-- component is the cluster info of the final state.
cluster :: (Ord a, Num a) => 
          Transformation b a -> 
          ClusterFunction a -> 
          CostFunction a ->
          [AbortCriterium a b] -> 
          [b] -> 
          (ClusterResult b, ClusterInfo a b)
cluster toVector f cf ac cs = (result, cinfo finalState)
  where
    (result, finalState) = runState pipeline emptyState
    pipeline = do
      clumap <- runCluster toVector f cf ac cs
      extractClusterElements clumap


{----------------------------------------------------------------------------
  Internal Functions
-----------------------------------------------------------------------------}
 
-- | /O(n^2)/
-- Builds all pairs @(x,y)@ of list elements with @x < y@, i.e. the strict
-- upper triangle of the pair matrix. Pairs are ordered by their first
-- component in list order, then by their second component in list order.
allPairs :: Ord a => [a] -> [Pair a]
allPairs xs = concatMap (\x -> map ((,) x) (filter (x <) xs)) xs

-- | Evaluates the cluster function for every given pair of cluster ids and
-- returns each distance together with its pair. The clusters are fetched with
-- 'getClusterUnsafe', so every id has to be present in the cluster map. The
-- state is neither read nor modified.
evalPairs :: Ord a => 
        ClusterMap a ->
        ClusterFunction a -> 
        [Pair ID] -> 
        State (ClusterState a b) ([(a, Pair ID)])
evalPairs clumap f tupels = return (map evaluate tupels)
  where
    evaluate p@(id1, id2) =
      (f (getClusterUnsafe clumap id1) (getClusterUnsafe clumap id2), p)
 
 
-- | The main cluster routine that does most of the work.
--
-- The first argument is the number of the current step, followed by the
-- cluster function, the cost function, the abort criteria and the current
-- cluster map. One call performs one step:
--
-- 1. The cluster info of the state is taken with 'cStep' set to the current
--    step. If any abort criterion or 'noAbort' fires, the current cluster map
--    is returned.
--
-- 2. Otherwise the pair with the minimum distance is merged and the cluster
--    info is extended by the new cluster, the distance, the new total and the
--    history entry. If a criterion fires on this extended info, the cluster
--    map from /before/ the merge is returned and the state is left unchanged.
--
-- 3. Otherwise the maps and the cluster info of the state are updated and the
--    routine continues with the next step on the merged cluster map.
clustering :: (Ord a, Num a) => 
        Int -> 
        ClusterFunction a -> 
        CostFunction a ->
        [AbortCriterium a b] ->
        ClusterMap a -> 
        State (ClusterState a b) (ClusterMap a)
clustering n f cf ac xs = do
    cinfo' <- gets (\s -> (cinfo s) { cStep = n })
    if shouldStop cinfo'
      then return xs
      else do
        -- every nested block is indented deeper than its parent, so the
        -- layout does not depend on NondecreasingIndentation
        (dist, (k1, k2)) <- findMin
        (newCluster, rest, xs') <- mergeClusters k1 k2 xs
        let
          total'   = cf (total cinfo') dist (map vals $ IntMap.elems xs')
          -- the merged pair itself plus every pair that involves k1 or k2
          toUpdate = (k1, k2) : updatePairs (IntMap.keys xs') k1 k2
          cinfo''  = cinfo' {
                  cNew = (newCluster, IntMap.elems rest),
                  costs = dist,
                  total = total',
                  cHistory = total' : cHistory cinfo'}
        if shouldStop cinfo''
          then return xs
          else do
            adjustMaps xs' toUpdate f
            modify $ \s -> s { cinfo = cinfo'' }
            clustering (n + 1) f cf ac xs'
  where
    -- True as soon as one criterion fires, checked from left to right; an
    -- empty criteria list leaves only the built-in step limit
    shouldStop info = any ($ info) ac || noAbort info

  
-- | Updates the combination map and the minimum map after a clustering step.
--
-- The first argument is the cluster map after the merge. The second argument
-- is the list of pairs to update: its head is the pair @(k1,k2)@ that was
-- merged, the tail holds the remaining affected pairs. The third argument is
-- the cluster function used to recompute distances.
--
-- The pattern only matches a non-empty pair list.
--
-- NOTE(review): the state snapshot @cstate@ is taken before
-- 'getClusterDistance' runs. Entries that 'getClusterDistance' adds to the
-- combination map are therefore not part of @combis cstate@ and are
-- overwritten by the final 'modify' - confirm that this is intended.
adjustMaps :: (Num a, Ord a) => 
        ClusterMap a -> 
        [Pair ID] -> 
        ClusterFunction a -> 
        State (ClusterState a b) ()
adjustMaps clumap allP@(cPair@(k1,k2):ks) f = do
          cstate <- get
          let 
            -- pairs of the tail that involve the first / the second merged id
            pairsWithKey1 = filter (\(a,b) -> a == k1 || b == k1) ks
            pairsWithKey2 = filter (\(a,b) -> a == k2 || b == k2) ks                
          -- recompute the distances of all pairs that involve k1
          updatedPairs <- evalPairs clumap f pairsWithKey1
          -- fetch the distances of all pairs in the list; they form the entries to be deleted
          upDistMV <- mapM (getClusterDistance clumap f) allP
          let 
            -- remove the fetched (distance, pair) entries from the minimum map
            minmap' = foldl (flip MS.delete) (minmap cstate) $ map swap upDistMV
            -- insert the recomputed entries
            minmap''= foldl (flip MS.insert) minmap' updatedPairs
            -- drop the merged pair and every pair that involves k2
            combis' = foldl (flip Map.delete) (combis cstate) $ cPair : pairsWithKey2
            -- overwrite the distances of the recomputed pairs; Map.update leaves missing keys untouched
            combis''= foldl (\m (v,pos) -> Map.update (const $ Just v) pos m) combis' updatedPairs 
          modify $ \s -> s{minmap = minmap'', combis = combis''}              
          return ()
    
-- | /O(n)/
-- Calculates the pairs of clusters that have to be updated, given the list of
-- cluster ids and the two recently combined cluster ids. Every id that
-- differs from both combined ids is paired with each of them; the smaller id
-- always comes first in a pair.
updatePairs :: [ID] -> ID -> ID -> [Pair ID]
updatePairs xs a b =
    [ ordered x y | x <- xs, x /= a, x /= b, y <- [a, b] ]
  where
    ordered x y
      | x < y     = (x, y)
      | otherwise = (y, x)


-- | The initial state: an empty minimum map, an empty combination map and
-- the default cluster info ('emptyInfo').
emptyState :: Num a => ClusterState a b            
emptyState = CS MS.empty Map.empty emptyInfo
-- | The default cluster info: no element mappings, all counters and costs set
-- to zero, an empty history and an empty cluster as the newest cluster.
emptyInfo :: Num a => ClusterInfo a b
emptyInfo = CI
    { idents   = Map.empty
    , nElems   = 0
    , cNew     = (singleton Nothing, [])
    , costs    = 0
    , total    = 0
    , cStep    = 0
    , cHistory = []
    }

-- | A wrapper for the actual clustering function that runs in the state
-- monad. It maps the input elements into vector space, builds the initial
-- cluster map (one cluster per element, ids starting at 1), evaluates the
-- distances of all id pairs and stores them in both maps of the state,
-- resets the cluster info (remembering the number of elements and the mapping
-- from vectors back to elements) and starts 'clustering' at step 1.
runCluster :: (Ord a, Num a) => 
          (b -> Vector a) ->
          ClusterFunction a -> 
          CostFunction a ->
          [AbortCriterium a b] ->
          [b] -> 
          State (ClusterState a b) (ClusterMap a)
runCluster toVector f cf ac xs = do
    let
      count   = length xs
      -- map values into vector space
      vectors = map toVector xs
      clumap  = fromList vectors
      info    = emptyInfo { nElems = count
                          , idents = Map.fromList (zip vectors xs)
                          }
    pairs <- evalPairs clumap f (allPairs [1 .. count])
    modify $ \s -> s { combis = Map.fromList (map swap pairs)
                     , minmap = MS.fromList pairs
                     , cinfo  = info
                     }
    clustering 1 f cf ac clumap
         


-- | /O(log n)/
-- Returns the smallest entry of the minimum map of the current state, i.e.
-- the minimum distance together with its pair of cluster ids.
findMin :: State (ClusterState a b) (a, Pair ID)
findMin = gets (MS.findMin . minmap)


-- | Determines the distance between two clusters given by their ids. If the
-- combination map of the state already holds a distance for the pair, that
-- distance is returned. Otherwise the distance is calculated with the cluster
-- function, stored in the combination map and returned. The calculation uses
-- 'getClusterUnsafe', so in that case both ids have to be present in the
-- cluster map.
getClusterDistance :: ClusterMap a -> 
          ClusterFunction a -> 
          (ID,ID) -> 
          State (ClusterState a b) (Pair ID,a)
getClusterDistance clumap f pair = do
    known <- gets (Map.lookup pair . combis)
    case known of
      Just e  -> return (pair, e)
      Nothing -> do
        let
          res = f (getClusterUnsafe clumap (fst pair))
                  (getClusterUnsafe clumap (snd pair))
        modify $ \s -> s { combis = Map.insert pair res (combis s) }
        return (pair, res)

                                      
{----------------------------------------------------------------------------
  Helper functions
-----------------------------------------------------------------------------}                                       

-- | Exchanges the two components of a tuple.
swap :: (a,b) -> (b, a)
swap pair = case pair of
  (l, r) -> (r, l)
                
-- | Raises an error whose message is the given text prefixed with
-- @\"Clustering: \"@.
mkError :: String -> a
mkError msg = error ("Clustering: " ++ msg)