-- |
-- Module      : ConClusion.Numeric.Statistics
-- Description : Statistical Functions
-- Copyright   : Phillip Seeber, 2021
-- License     : AGPL-3
-- Maintainer  : phillip.seeber@googlemail.com
-- Stability   : experimental
-- Portability : POSIX, Windows
module ConClusion.Numeric.Statistics
  ( -- * PCA
    PCA (..),
    pca,

    -- * Variance
    normalise,
    meanDeviation,
    covariance,

    -- * Distance Metrics
    DistFn,
    lpNorm,
    manhattan,
    euclidean,
    mahalanobis,

    -- * Cluster Algorithms
    Clusters,

    -- ** DBScan
    DistanceInvalidException (..),
    dbscan,

    -- ** Hierarchical Cluster Analysis
    Dendrogram,
    JoinStrat (..),
    hca,
    cutDendroAt,
  )
where

import ConClusion.Array.Conversion
import ConClusion.Array.Util hiding (normalise)
import ConClusion.BinaryTree
import Data.Aeson hiding (Array)
import Data.Complex
import qualified Data.HashPSQ as PQ
import qualified Data.IntSet as IntSet
import Data.Massiv.Array as Massiv
import Data.Massiv.Array.Unsafe as Massiv
import qualified Numeric.LinearAlgebra as LA
import RIO hiding (Vector)
import System.IO.Unsafe (unsafePerformIO)

----------------------------------------------------------------------------------------------------
-- Others/Helpers

-- | Solves the eigenvalue problem of a square matrix and obtains its eigenvalues and eigenvectors.
--
-- The Massiv matrix is converted to its hmatrix counterpart, diagonalised with 'LA.eig' and both
-- results are converted back to Massiv arrays. Eigenvalues and eigenvectors are returned as complex
-- numbers, exactly as delivered by 'LA.eig'.
--
-- Throws an 'IndexException' if the input matrix is not square.
{-# SCC eig #-}
eig ::
  ( LA.Field e,
    Manifest r3 e,
    Manifest r1 (Complex Double),
    Manifest r2 (Complex Double),
    Load r1 Ix1 (Complex Double),
    Load r2 Ix1 (Complex Double),
    Load r3 Ix1 e,
    MonadThrow m
  ) =>
  Matrix r3 e ->
  m (Vector r1 (Complex Double), Matrix r2 (Complex Double))
eig covM
  | m /= n = throwM $ IndexException "eigenvalue problems can only be solved for square matrix"
  | otherwise = return . bimap vecH2M matH2M . LA.eig $ cov
  where
    Sz (m :. n) = size covM
    -- hmatrix representation of the input matrix.
    cov = matM2H covM

-- | Sort eigenvalues and eigenvectors by the eigenvalues in descending order (largest eigenvalues
-- first). Eigenvectors are the columns of the input matrix. The comparison uses the 'Ord' instance
-- of the eigenvalues directly, i.e. no absolute value is taken.
--
-- Throws an 'IndexException' if the eigenvector matrix is not square or if the number of
-- eigenvalues differs from the number of eigenvectors.
{-# SCC eigSort #-}
eigSort ::
  ( Load r2 Ix2 e,
    MonadThrow m,
    Source r1 e,
    Source r2 e,
    Manifest r1 e,
    Manifest r2 e,
    Unbox e,
    Ord e
  ) =>
  (Vector r1 e, Matrix r2 e) ->
  m (Vector r1 e, Matrix r2 e)
eigSort (vec, mat)
  | m /= n = throwM $ IndexException "matrix of the eigenvectors is not a square matrix"
  | n /= n' = throwM $ IndexException "different number of eigenvalues and eigenvectors"
  | otherwise = do
    let -- Pair every eigenvalue with its original position, so that the permutation found by
        -- sorting the eigenvalues can be applied to the eigenvector columns afterwards.
        ixedEigenvalues = Massiv.zip vec ixVec
        -- Ascending sort of the (eigenvalue, index) pairs, split again into values and indices.
        (eigenValSortAsc, ixSort) =
          (\a -> (get fst a, get snd a)) . quicksort . compute @U $ ixedEigenvalues
        -- Reorder the columns of the eigenvector matrix according to the sorted indices.
        eigenVecSortAsc = backpermute' (Sz $ m :. n) (\(r :. c) -> r :. (ixSort ! c)) mat
        -- Reversing the ascending order yields the requested descending order. For the matrix
        -- the reversal runs along the column index, as the eigenvectors are stored column-wise.
        eigenValSort = reverse' (Dim 1) eigenValSortAsc
        eigenVecSort = reverse' (Dim 1) eigenVecSortAsc
    return (compute eigenValSort, compute eigenVecSort)
  where
    Sz (m :. n) = size mat
    Sz n' = size vec
    -- Vector of the linear indices @0 .. n' - 1@.
    ixVec = makeArrayLinear @D Seq (Sz n') id
    -- Project every element with the accessor and materialise the result as unboxed array.
    get acc = compute @U . Massiv.map acc

-- | Adjust function for priority queues. Updates the priority at a given key if present. If the
-- key is not part of the queue, the queue is returned unchanged. The value stored at the key is
-- never modified.
pqAdjust :: (Ord k, Hashable k, Ord p) => (p -> p) -> k -> PQ.HashPSQ k p v -> PQ.HashPSQ k p v
pqAdjust f k q = snd $ PQ.alter adjust k q
  where
    -- 'PQ.alter' demands an auxiliary result next to the updated entry. It is not needed here and
    -- discarded by 'snd' above.
    adjust Nothing = (False, Nothing)
    adjust (Just (p, v)) = (False, Just (f p, v))

----------------------------------------------------------------------------------------------------
-- Principal Component Analysis

-- | Results of a principal component analysis as obtained from 'pca'.
data PCA = PCA
  { -- | Original feature matrix.
    x :: Matrix U Double,
    -- | Feature matrix in mean deviation form.
    x' :: Matrix U Double,
    -- | Transformed data.
    y :: Matrix U Double,
    -- | Transformation matrix to transform feature matrix into PCA result matrix.
    a :: Matrix U Double,
    -- | Mean squared error introduced by PCA.
    mse :: Double,
    -- | Percentage of the behaviour captured in the remaining dimensions.
    remaining :: Double,
    -- | All eigenvalues from the diagonalisation of the covariance matrix.
    allEigenValues :: Vector U Double,
    -- | Eigenvalues that were kept for PCA.
    pcaEigenValues :: Vector U Double,
    -- | All eigenvectors from the diagonalisation of the covariance matrix.
    allEigenVecs :: Matrix U Double,
    -- | Eigenvectors that were kept for PCA.
    pcaEigenVecs :: Matrix U Double
  }

-- | Transform the input values with a transformation matrix \(\mathbf{A}\), where \(\mathbf{A}\) is
-- constructed from the eigenvectors associated to the largest eigenvalues.
--
-- \(\mathbf{A}\) is the transpose of the first @nDim@ columns of the eigenvector matrix and the
-- transformed data is the matrix product of \(\mathbf{A}\) and the feature matrix.
--
-- Throws an 'IndexException' if
--
--   * the eigenvector matrix is not square,
--   * the number of dimensions to keep is not positive,
--   * more dimensions than eigenvectors are requested (keeping all of them is allowed),
--   * the number of rows of the eigenvector matrix and of the feature matrix differ.
{-# SCC transformToPCABasis #-}
transformToPCABasis ::
  ( Manifest r e,
    Numeric r e,
    MonadThrow m
  ) =>
  -- | Number of dimensions to keep from PCA.
  Int ->
  -- | Matrix of the eigenvectors, sorted descendingly by eigenvalues, where the eigenvectors are
  -- the columns of the matrix.
  Matrix r e ->
  -- | Feature matrix in mean deviation form.
  Matrix r e ->
  -- | Input data transformed by PCA to lower dimensions, and the transformation matrix
  -- \(\mathbf{A}\).
  m (Matrix r e, Matrix r e)
transformToPCABasis nDim eigenVecMat featureMat
  | mE /= nE = throwM $ IndexException "the matrix of the eigenvectors must be a quadratic matrix"
  | nDim <= 0 = throwM $ IndexException "the number of dimensions of the PCA is smaller than or zero"
  -- Only reject requests that exceed the number of available eigenvectors. Keeping all of them
  -- (@nDim == nE@) is a valid, lossless change of basis.
  | nDim > nE = throwM $ IndexException "more than the possible amount of dimensions has been selected"
  | mE /= mF = throwM $ IndexException "eigenvector matrix and feature matrix have mismatching dimensions"
  | otherwise = do
    -- The first nDim eigenvectors (columns) become the rows of the transformation matrix.
    matA <- compute . transpose <$> extractM (0 :. 0) (Sz $ mE :. nDim) eigenVecMat
    pcaData <- matA .><. featureMat
    return (pcaData, matA)
  where
    Sz (mE :. nE) = size eigenVecMat
    Sz (mF :. _nF) = size featureMat

-- | Performs a PCA on the feature matrix \(\mathbf{X}\) by solving the eigenproblem of the
-- covariance matrix. The function takes the feature matrix directly and performs the conversion
-- to mean deviation form (followed by a row-wise normalisation, see 'normalise'), the calculation
-- of the covariance matrix and the eigenvalue problem automatically.
--
-- Only the real parts of the eigenvalues and eigenvectors are used. Exceptions of the individual
-- steps ('eig', 'eigSort', 'transformToPCABasis' and the array operations) are passed on through
-- 'MonadThrow'.
{-# SCC pca #-}
pca ::
  ( Numeric r Double,
    Manifest r Double,
    Load r Ix1 Double,
    Load r Ix2 Double,
    MonadThrow m
  ) =>
  -- | Dimensionality after PCA transformation.
  Int ->
  -- | \(m \times n\) Feature matrix \(\mathbf{X}\), with \(m\) different measurements (rows) in
  -- \(n\) different trials (columns).
  Matrix r Double ->
  m PCA
pca dim x = do
  -- Calculate the mean deviation form of the feature matrix and the covariance matrix from it.
  let x' = normalise . meanDeviation $ x
      cov = covariance x'

  -- Obtain eigenvalues and eigenvectors of the covariance matrix and sort them.
  (eigValsC :: Vector U (Complex Double), eigVecsC :: Matrix U (Complex Double)) <- eig cov
  let eigValsR = compute @U . Massiv.map realPart $ eigValsC
      eigVecsR = compute . Massiv.map realPart $ eigVecsC
  (eValS, eVecS) <- eigSort (eigValsR, eigVecsR)

  -- Use the subset of the eigenvectors with the largest eigenvalues to transform the features in
  -- mean deviation form into the result matrix Y.
  (pcaData, matA) <- transformToPCABasis dim eVecS x'

  -- Reconstruct the original data from lower dimensions and calculate the mean squared deviation,
  -- which is the sum of the squared differences divided by the number of columns.
  reconstructX <- (compute . transpose $ matA) .><. pcaData
  mse <- (/ fromIntegral n) . Massiv.sum . Massiv.map (** 2) <$> (x' .-. reconstructX)

  -- For output give the eigenvalues and eigenvectors that were kept.
  pcaEigenValues <- extractM 0 (Sz dim) eValS
  pcaEigenVecs <- extractM (0 :. 0) (Sz $ m :. dim) eVecS

  -- Calculate the amount of behaviour that could be kept, as percentage of the sum of all
  -- eigenvalues.
  let remaining = (Massiv.sum pcaEigenValues / Massiv.sum eValS) * 100

  return
    PCA
      { x = compute x,
        x' = compute x',
        y = compute pcaData,
        a = compute matA,
        mse = mse,
        remaining = remaining,
        allEigenValues = eValS,
        pcaEigenValues = compute pcaEigenValues,
        allEigenVecs = compute eVecS,
        pcaEigenVecs = compute pcaEigenVecs
      }
  where
    Sz (m :. n) = size x

----------------------------------------------------------------------------------------------------
-- Variance

-- | Brings the feature matrix to mean deviation form: for every row the mean value over all
-- columns of this row is calculated and subtracted from each element of the row.
{-# SCC meanDeviation #-}
meanDeviation ::
  ( Source r e,
    Fractional e,
    Unbox e,
    Numeric r e,
    Manifest r e
  ) =>
  Matrix r e ->
  Matrix r e
meanDeviation mat = mat !-! compute meanMat
  where
    Sz (_ :. n) = Massiv.size mat
    -- Sum along the inner (column) dimension and divide by the number of columns.
    featureMean = Massiv.foldlInner (+) 0 mat .* (1 / fromIntegral n)
    -- Repeat the vector of means n times along the inner dimension, so that it matches the size
    -- of the input matrix.
    meanMat = expandInner (Sz n) const . compute @U $ featureMean

-- | Obtains the covariance matrix \(\mathbf{C_X}\) from the feature matrix \(\mathbf{X}\).
-- \[
--   \mathbf{C_X} \equiv \frac{1}{n - 1} \mathbf{X} \mathbf{X}^T
-- \]
-- where \(n\) is the number of columns in the matrix.
--
-- The feature matrix should be in mean deviation form, see 'meanDeviation'. For a matrix with a
-- single column the prefactor is a division by zero.
{-# SCC covariance #-}
covariance :: (Numeric r e, Manifest r e, Fractional e) => Matrix r e -> Matrix r e
covariance x = (1 / (fromIntegral n - 1)) *. (x !><! (compute . transpose $ x))
  where
    Sz (_ :. n) = size x

-- | Normalise each value so that the maximum absolute value in each row becomes one.
--
-- Rows that consist of zeros only have no meaningful scaling factor. They are left unchanged,
-- instead of being turned into @NaN@ by a division by zero.
normalise ::
  ( Ord e,
    Unbox e,
    Numeric r e,
    Fractional e,
    Manifest r e
  ) =>
  Array r Ix2 e ->
  Array r Ix2 e
normalise mat =
  let absMat = Massiv.map abs mat
      -- Largest absolute value of each row.
      maxPerRow = compute @U . foldlInner max 0 $ absMat
      -- Matrix of the same size as the input, holding the scaling factor of each row.
      divMat = compute . Massiv.map scaleFactor . expandInner @U @Ix2 (Sz n) const $ maxPerRow
   in divMat !*! mat
  where
    Sz (_ :. n) = size mat
    -- A row maximum of zero means that all elements of the row are zero. Scaling by one keeps
    -- them at zero, whereas @1 / 0@ would introduce infinities and NaNs.
    scaleFactor rowMax
      | rowMax == 0 = 1
      | otherwise = 1 / rowMax

----------------------------------------------------------------------------------------------------
-- Distance Measures

-- | Distance matrix generator functions. A function of this type takes a matrix of points and
-- returns a matrix of the same representation and element type, that holds the distances between
-- the points. The generators of this module expect the points as column vectors of the input.
type DistFn r e = Matrix r e -> Matrix r e

-- | Builds the distance measures in a permutation matrix/distance matrix.
--
-- The points are expanded to two rank-3 arrays, so that zipping them combines each point with each
-- other point coordinate-wise. The innermost dimension of the zipped array is then folded to
-- a single value per pair of points.
buildDistMat ::
  (Manifest r e) =>
  -- | Zip function to combine the elements of vectors \(\mathbf{a}\) \(\mathbf{b}\). Usually @(-)@.
  -- \( f(\mathbf{a}_i, \mathbf{b}_i) = \mathbf{c} \)
  (e -> e -> a) ->
  -- | Fold the vector \(\mathbf{c}\) elementwise to a distance \(d\).
  (a -> a -> a) ->
  -- | Accumulator of the fold function.
  a ->
  -- | \(m \times n\) matrix, with \(n\) \(m\)-dimensional points (column vectors of the matrix).
  Matrix r e ->
  -- | Resulting distance matrix.
  Matrix D a
buildDistMat zipFn foldFn acc mat = foldlInner foldFn acc combined
  where
    Sz (_ :. n) = size mat
    -- The point matrix replicated n times along a new outer dimension and rearranged.
    a = transposeOuter @D @Ix3 . expandOuter (Sz n) const $ mat
    -- Same data as a, but with two dimensions swapped, so that zipping a and b pairs up
    -- different points.
    b = transposeInner a
    combined = Massiv.zipWith zipFn a b

-- | The \(L_p\) norm between two vectors. Generalisation of Manhattan and Euclidean distances.
-- \[
--   d(\mathbf{a}, \mathbf{b}) = \left( \sum \limits_{i=1}^n \lvert \mathbf{a}_i - \mathbf{b}_i \rvert ^p \right) ^ \frac{1}{p}
-- \]
--
-- The exponent \(p\) is expected to be positive.
{-# SCC lpNorm #-}
lpNorm :: (Manifest r e, Floating e) => Int -> DistFn r e
lpNorm p = compute . Massiv.map root . buildDistMat zipFn (+) 0
  where
    -- \(\lvert a_i - b_i \rvert ^ p\) for a single coordinate.
    zipFn a b = (^ p) . abs $ a - b
    -- The p-th root must be taken exactly once, after all coordinates have been summed up.
    -- Taking it in every step of the fold would nest the roots and give wrong distances for
    -- points with more than one coordinate whenever \(p > 1\).
    root = (** (1 / fromIntegral p))

-- | The Manhattan distance between two vectors. Specialisation of the \(L_p\) norm for \(p = 1\),
-- see 'lpNorm'.
-- \[
--   d(\mathbf{a}, \mathbf{b}) = \sum \limits_{i=1}^n \lvert \mathbf{a}_i - \mathbf{b}_i \rvert
-- \]
{-# SCC manhattan #-}
manhattan :: (Manifest r e, Floating e) => DistFn r e
manhattan = lpNorm 1

-- | The Euclidean distance between two vectors. Specialisation of the \(L_p\) norm for \(p = 2\),
-- see 'lpNorm'.
-- \[
--   d(\mathbf{a}, \mathbf{b}) = \sqrt{\sum \limits_{i=1}^n (\mathbf{a}_i - \mathbf{b}_i)^2}
-- \]
{-# SCC euclidean #-}
euclidean :: (Manifest r e, Floating e) => DistFn r e
euclidean = lpNorm 2

-- | Mahalanobis distance between points. In contrast to the \(L_p\) norms it weights the
-- differences with the inverse covariance matrix of the data and thereby takes the scaling of and
-- the correlation between the axes into account.
-- \[
--   d(\mathbf{a}, \mathbf{b}) = \sqrt{(\mathbf{a} - \mathbf{b})^T \mathbf{S}^{-1} (\mathbf{a} - \mathbf{b})}
-- \]
-- where \(\mathbf{S}\) is the covariance matrix.
--
-- The covariance matrix is calculated from the points themselves (after bringing them to mean
-- deviation form) and inverted with 'LA.inv'.
{-# SCC mahalanobis #-}
mahalanobis :: (Unbox e, Numeric r e, Manifest r e, LA.Field e, Load r Ix1 e) => DistFn r e
mahalanobis points = compute . Massiv.map sqrt $ distMat
  where
    Sz (_ :. n) = size points
    -- Covariance matrix of the points and its inverse.
    cov = covariance . meanDeviation $ points
    covInv = matH2M . LA.inv . matM2H $ cov
    -- Two rank-3 expansions of the points, arranged so that their elementwise difference holds
    -- the difference vector of each pair of points (same construction as in 'buildDistMat').
    a = transposeOuter @D @Ix3 . expandOuter (Sz n) const $ points
    b = transposeInner a
    abDiff = compute @U $ a !-! b
    -- Every element of this array is its own index. Mapping over it visits all pairs of points.
    ixArray = makeArray @U @Ix2 @Ix2 Par (Sz $ n :. n) id
    -- Squared distances: difference vector times inverse covariance matrix times difference
    -- vector.
    distMat =
      Massiv.map
        ( \(x :. y) ->
            let ab = compute @U $ abDiff !> x !> y
             in ab ><! covInv !.! ab
        )
        ixArray

----------------------------------------------------------------------------------------------------
-- DBScan

-- | Exception for invalid search distances. Carries the offending distance value.
newtype DistanceInvalidException e = DistanceInvalidException e deriving (Show, Eq)

instance (Typeable e, Show e) => Exception (DistanceInvalidException e)

-- | Representation of clusters. A boxed vector, in which every element is one cluster, given as an
-- 'IntSet'. Presumably the sets hold the indices of the points that belong to the cluster
-- (TODO confirm against the cluster algorithms).
type Clusters = Vector B IntSet

-- | DBScan algorithm.
--
-- The distance matrix of all points is built with the supplied distance function. For every point
-- the set of points within the search radius \(\epsilon\) is collected, all neighbourhoods that
-- share at least one point are merged until no two remaining sets overlap, and finally only sets
-- with at least the requested number of members are returned.
--
-- Throws
--
--   * 'SizeEmptyException' if the input matrix is empty,
--   * 'SizeNegativeException' if the minimal cluster size is smaller than 1,
--   * 'DistanceInvalidException' if the search radius is not positive.
{-# SCC dbscan #-}
dbscan ::
  ( MonadThrow m,
    Ord e,
    Num e,
    Typeable e,
    Show e,
    Source r e
  ) =>
  -- | Distance measure to build the distance matrix of all points.
  DistFn r e ->
  -- | Minimal number of members in a cluster.
  Int ->
  -- | Search radius \(\epsilon\)
  e ->
  -- | \(n\) \(m\)-dimensional data points as column vectors of a \(m \times n\) matrix.
  Matrix r e ->
  -- | Resulting clusters.
  m Clusters
dbscan distFn nPoints epsilon points
  | isEmpty points = throwM $ SizeEmptyException (Sz 0 :: Sz1)
  | nPoints < 1 = throwM $ SizeNegativeException (Sz nPoints)
  | epsilon <= 0 = throwM $ DistanceInvalidException epsilon
  | otherwise =
    let -- For each row of the distance matrix the set of column indices within the search radius.
        pointNeighbours = ifoldlInner collectNeighbours mempty distMat
        -- Merge all neighbourhoods that share points.
        allClusters = joinOverlapping . compute @B $ pointNeighbours
        -- Discard everything that has fewer members than requested.
        largeClusters = sfilter ((>= nPoints) . IntSet.size) allClusters
     in return $ compute largeClusters
  where
    -- The distance matrix in the measure chosen by the distance function.
    distMat = distFn points

    -- Fold step collecting the neighbours of a point: the inner index is added to the accumulator
    -- if the distance found there does not exceed the search radius epsilon.
    {-# SCC collectNeighbours #-}
    collectNeighbours (_ :. n) acc d
      | d <= epsilon = IntSet.insert n acc
      | otherwise = acc

    -- Compare every set of the vector with every other set (and itself) by a binary predicate. The
    -- vector is expanded to a square matrix and zipped with its own transpose.
    compareSets :: (IntSet -> IntSet -> Bool) -> Vector B IntSet -> Matrix D Bool
    compareSets fn clVec = Massiv.zipWith fn expanded (transpose expanded)
      where
        expanded = expandOuter (size clVec) const clVec

    -- Overlap matrix. Checks if two sets have any overlap. Non-empty sets overlap with themselves.
    overlap :: Vector B IntSet -> Matrix D Bool
    overlap = compareSets (\a b -> not $ IntSet.disjoint a b)

    -- Check if any entry off the main diagonal of a comparison matrix is set, i.e. if any set
    -- overlaps with **any** other set.
    anyOffDiagonal :: Matrix U Bool -> Bool
    anyOffDiagonal = Massiv.any id . imap (\(m :. n) v -> m /= n && v)

    -- Check if two sets are identical. Sets are identical to themselves.
    same :: Vector B IntSet -> Matrix D Bool
    same = compareSets (==)

    -- Join all overlapping clusters recursively. The recursion stops as soon as no cluster overlaps
    -- with a different cluster any more.
    {-# SCC joinOverlapping #-}
    joinOverlapping :: Vector B IntSet -> Vector B IntSet
    joinOverlapping clVec
      | anyOffDiagonal ovlpMat = joinOverlapping nonRed
      | otherwise = clVec
      where
        -- The overlap matrix of the clusters. Computed once and shared between the termination
        -- check and the joining step.
        ovlpMat :: Matrix U Bool
        ovlpMat = compute @U . overlap $ clVec

        -- Join all sets that have overlap but keep them redundantly (no reduction of the amount
        -- of clusters). Materialised once, as it is consumed twice below.
        joined :: Vector B IntSet
        joined =
          compute @B $
            ifoldlInner
              (\(_ :. n) acc ovlp -> if ovlp then (clVec ! n) <> acc else acc)
              mempty
              ovlpMat

        -- Find all sets at different indices that are the same. Only the strict upper triangle of
        -- the comparison matrix is kept; the main diagonal and everything below is False.
        sameMat :: Matrix U Bool
        sameMat = compute @U . imap (\(m :. n) v -> m < n && v) . same $ joined

        -- Remove all redundant sets: a set is dropped if an identical set exists at a higher
        -- index, so that exactly one copy of each joined set survives.
        nonRed :: Vector B IntSet
        nonRed = compute @B . sifilter (\ix _ -> not . Massiv.any id $ sameMat !> ix) $ joined

----------------------------------------------------------------------------------------------------
-- Hierarchical Cluster Analysis

-- | Payload stored at every node of a dendrogram.
data DendroNode e = DendroNode
  { -- | Distance value associated with this node. Leaves created by 'hca' carry a distance of 0.
    distance :: e,
    -- | Indices of the points that are members of this node.
    cluster :: IntSet
  }
  deriving (Eq, Show, Generic)

instance (FromJSON e) => FromJSON (DendroNode e)

instance (ToJSON e) => ToJSON (DendroNode e)

-- | A dendrogram: a binary tree whose nodes are 'DendroNode's. 'unDendro' unwraps the underlying
-- tree.
newtype Dendrogram e = Dendrogram
  { unDendro :: BinTree (DendroNode e)
  }
  deriving (Show, Eq, Generic)

instance ToJSON e => ToJSON (Dendrogram e)

instance FromJSON e => FromJSON (Dendrogram e)

-- | An accumulator to finally build a dendrogram by a bottom-up algorithm. Not to be exposed in the
-- API.
type DendroAcc e = Vector B (Dendrogram e)

-- | Mutable version of the dendrogram accumulator: a boxed mutable vector of dendrograms living in
-- the state token of the monad @m@.
type DendroAccM m e = MArray (PrimState m) B Ix1 (Dendrogram e)

-- | Cut a 'Dendrogram' at a given distance and obtain all clusters from it. The branches of the
-- tree are selected by 'takeLeafyBranchesWhile' with the predicate "node distance is at least the
-- cut distance"; of every selected node only the set of member points is kept.
cutDendroAt :: Ord e => Dendrogram e -> e -> Clusters
cutDendroAt dendro dist = compute @B . Massiv.map cluster . compute @B $ selected
  where
    -- Nodes picked from the tree for the given cut distance.
    selected = takeLeafyBranchesWhile atOrAboveCut (unDendro dendro)

    -- Predicate handed to the tree traversal.
    atOrAboveCut DendroNode {distance} = distance >= dist

-- | A strategy/distance measure for clusters. Each strategy selects the coefficients
-- \(\alpha_1, \alpha_2, \beta, \gamma\) used by the Lance-Williams distance update.
data JoinStrat e
  = -- | Single linkage.
    SingleLinkage
  | -- | Complete linkage.
    CompleteLinkage
  | -- | Median linkage.
    Median
  | -- | Unweighted average linkage; coefficients depend on the cluster sizes.
    UPGMA
  | -- | Weighted average linkage; coefficients are independent of the cluster sizes.
    WPGMA
  | -- | Centroid linkage.
    Centroid
  | -- | Ward's criterion.
    Ward
  | -- | One-parameter family: \(\alpha_1 = \alpha_2 = (1 - \beta) / 2\), \(\gamma = 0\), with the
    -- given \(\beta\).
    LWFB e
  | -- | Fully user defined coefficients \(\alpha_1\), \(\alpha_2\), \(\beta\) and \(\gamma\), in
    -- this order.
    LW e e e e
  deriving (Eq, Show)

-- | Lance-Williams formula to update distances after two clusters \(A\) and \(B\) were joined:
--
-- \[
--   d(A \cup B, C) = \alpha_1 d(A, C) + \alpha_2 d(B, C) + \beta d(A, B)
--                    + \gamma \left| d(A, C) - d(B, C) \right|
-- \]
--
-- The coefficients are selected by the 'JoinStrat'. With \(N = n_A + n_B + n_C\) Ward's criterion
-- uses \(\alpha_1 = (n_A + n_C) / N\), \(\alpha_2 = (n_B + n_C) / N\), \(\beta = - n_C / N\) and
-- \(\gamma = 0\).
{-# SCC lanceWilliams #-}
lanceWilliams ::
  Fractional e =>
  -- | How to calculate distance between clusters of points.
  JoinStrat e ->
  -- | Number of points in cluster \(A\).
  Int ->
  -- | Number of points in cluster \(B\)
  Int ->
  -- | Number of points in cluster \(C\)
  Int ->
  -- | \(d(A, B)\)
  e ->
  -- | \(d(A, C)\)
  e ->
  -- | \(d(B, C)\)
  e ->
  -- | Updated distance \(d(A \cup B, C)\)
  e
lanceWilliams js nA nB nC dAB dAC dBC =
  alpha1 * dAC + alpha2 * dBC + beta * dAB + gamma * abs (dAC - dBC)
  where
    -- Cluster sizes converted to the numeric type of the distances.
    nA' = fromIntegral nA
    nB' = fromIntegral nB
    nC' = fromIntegral nC
    -- Size of all three clusters together; only needed by Ward's criterion.
    nABC = nA' + nB' + nC'
    (alpha1, alpha2, beta, gamma) = case js of
      SingleLinkage -> (1 / 2, 1 / 2, 0, - 1 / 2)
      CompleteLinkage -> (1 / 2, 1 / 2, 0, 1 / 2)
      Median -> (1 / 2, 1 / 2, - 1 / 4, 0)
      UPGMA -> (nA' / (nA' + nB'), nB' / (nA' + nB'), 0, 0)
      WPGMA -> (1 / 2, 1 / 2, 0, 0)
      Centroid -> (nA' / (nA' + nB'), nB' / (nA' + nB'), - (nA' * nB') / ((nA' + nB') ^ (2 :: Int)), 0)
      -- alpha2 weights d(B, C) and therefore depends on the size of B (not A); beta only depends
      -- on the size of C.
      Ward -> ((nA' + nC') / nABC, (nB' + nC') / nABC, negate nC' / nABC, 0)
      LWFB b -> ((1 - b) / 2, (1 - b) / 2, b, 0)
      LW a b c d -> (a, b, c, d)

----------------------------------------------------------------------------------------------------
-- Müllner Generic Hierarchical Clustering

-- | A neighbourlist. At index @i@ of the vector it contains a tuple @(distance, index)@ with the
-- minimal distance of cluster @i@ to any other cluster and the index of that other cluster.
type Neighbourlist r e = Vector r (e, Ix1)

-- | A distance matrix.
type DistanceMatrix r e = Matrix r e

-- | Performance improved hierarchical clustering algorithm. @GENERIC_LINKAGE@ from figure 3,
-- <https://arxiv.org/pdf/1109.2378.pdf>.
--
-- The distance matrix of the points is built with the given distance function, the nearest
-- neighbour list, the priority queue and the dendrogram accumulator (every point as its own leaf
-- with distance 0) are initialised, and the clusters are then joined by 'agglomerate' until a
-- single dendrogram remains.
--
-- Throws 'SizeEmptyException' if the matrix of points is empty. The agglomeration itself runs in
-- 'IO' behind 'unsafePerformIO'; exceptions raised there surface when the resulting dendrogram is
-- evaluated instead of in @m@.
{-# SCC hca #-}
hca ::
  ( MonadThrow m,
    Manifest r e,
    Manifest r (e, Ix1),
    Load r Ix1 e,
    Ord e,
    Unbox e,
    Fractional e
  ) =>
  -- | Distance measure to build the distance matrix of all points.
  DistFn r e ->
  -- | Strategy to calculate distances between clusters.
  JoinStrat e ->
  -- | Data points; the number of points is taken from the second dimension of the matrix.
  Matrix r e ->
  m (Dendrogram e)
hca distFn joinStrat points
  | Massiv.isEmpty points = throwM $ SizeEmptyException (Sz nPoints)
  | otherwise = do
    let -- The distance matrix from the points.
        distMat = distFn points

    -- Initial vector of nearest neighbour to each point.
    nNghbr <- nearestNeighbours distMat

    let -- Initial priority queue of points: key is the point index, priority the distance to its
        -- nearest neighbour and value the index of that neighbour.
        pq = PQ.fromList . Massiv.toList . Massiv.imap (\k (d, n) -> (k, d, n)) $ nNghbr
        -- Set of points not joined yet. Initially all points.
        s = IntSet.fromDistinctAscList [0 .. nPoints - 1]
        -- Initial dendrogram accumulator. The vector of all points as their own cluster.
        dendroAcc =
          makeArray @B @Ix1
            Par
            (Sz nPoints)
            (\p -> Dendrogram . Leaf $ DendroNode {distance = 0, cluster = IntSet.singleton p})

    -- The mutable working copies are created by 'thaw' (which copies its argument) and are only
    -- ever touched inside this single 'unsafePerformIO' region. No mutable state escapes, the
    -- immutable inputs are never modified and the result only depends on them, which makes the
    -- computation observably pure.
    return . unsafePerformIO $ do
      distMatM <- thaw distMat
      nNghbrM <- thaw nNghbr
      dendroAccM <- thaw dendroAcc
      agglomerate joinStrat distMatM nNghbrM pq s dendroAccM
  where
    Sz (_mFeatures :. nPoints) = size points

-- | Agglomerative clustering by the improved generic linkage algorithm. This is the main loop
-- recursion L 10-43. Every pass selects the closest pair of clusters, joins them, updates the
-- distance matrix, the neighbour list and the priority queue, and recurses until only a single
-- cluster is left.
--
-- Throws an 'IndexException' if called with an empty set of remaining clusters.
{-# SCC agglomerate #-}
agglomerate ::
  ( MonadThrow m,
    PrimMonad m,
    MonadUnliftIO m,
    PrimState m ~ RealWorld,
    Manifest r e,
    Manifest r (e, Ix1),
    Shape r Ix1,
    Fractional e,
    Ord e
  ) =>
  -- | Join strategy for clusters and therefore how to calculate cluster-cluster distances.
  JoinStrat e ->
  -- | Distance matrix.
  MArray (PrimState m) r Ix2 e ->
  -- | List of nearest neighbours for each point.
  MArray (PrimState m) r Ix1 (e, Ix1) ->
  -- | Priority queue with the distances as priorities and the cluster index as keys.
  PQ.HashPSQ Ix1 e Ix1 ->
  -- | A set \(S\), that keeps track which clusters have already been joined.
  IntSet ->
  -- | Accumulator of the dendrogram. The final dendrogram is read from it at the index of the last
  -- joined cluster.
  DendroAccM m e ->
  -- | The final dendrogram, after all clusters have been joined.
  m (Dendrogram e)
agglomerate joinStrat distMat nNghbr pq s dendroAcc
  | IntSet.null s = throwM $ IndexException "No clusters left. This must never happen."
  | otherwise = do
    -- Candidates for the two clusters to join together with the minimal distance found in the
    -- priority queue.
    guess <- getJoinCandidates nNghbr pq

    -- The neighbour list may be outdated: if the distance between the candidates is not the
    -- minimal distance reported by the priority queue, it is recalculated until both agree.
    (a, b, delta, nghbrChecked, pqChecked) <- recalculateNghbr guess s distMat nNghbr pq

    -- Join clusters a and b. This yields the updated set of remaining clusters, the updated
    -- priority queue and the accumulator with the joined cluster.
    (remaining, pqJoined, accJoined) <- joinClusters a b delta s pqChecked dendroAcc

    -- Update the distance matrix in the row and column of b but not at (b,b) and not at (a,b) and
    -- (b,a).
    distMatJoined <- updateDistMat joinStrat a b remaining distMat accJoined

    -- Redirect neighbours to b, if they previously pointed to a.
    nghbrRedirected <- redirectNeighbours a b remaining distMatJoined nghbrChecked

    -- Preserve a lower bound in priority queue and update the nearest neighbour list.
    (nghbrBounded, pqBounded) <-
      updateWithNewBDists b remaining distMatJoined nghbrRedirected pqJoined

    -- Update the neighbourlist and priority queue with the new distances to b.
    -- NOTE(review): this step receives the set from *before* the join (s), while the neighbouring
    -- steps receive the updated set - confirm against 'updateBNeighbour' that this is intended.
    (nghbrFinal, pqFinal) <- updateBNeighbour b s distMatJoined nghbrBounded pqBounded

    -- With a single cluster left the algorithm is done and the final dendrogram is found in the
    -- accumulator at index b. Otherwise continue joining.
    case IntSet.size remaining of
      1 -> readM accJoined b
      _ -> agglomerate joinStrat distMatJoined nghbrFinal pqFinal remaining accJoined

-- | Obtain candidates for the clusters to join by looking at the minimal distance in the priority
-- queue and the neighbourlist. L 11-13
--
-- The result is @(a, b, d)@: @a@ is the key of the minimal element of the priority queue, @d@ its
-- priority and @b@ the neighbour stored for @a@ in the neighbourlist. Throws an 'IndexException'
-- if the priority queue is empty.
{-# SCC getJoinCandidates #-}
getJoinCandidates ::
  ( MonadThrow m,
    PrimMonad m,
    Manifest r (e, Ix1),
    Ord e
  ) =>
  MArray (PrimState m) r Ix1 (e, Ix1) ->
  PQ.HashPSQ Ix1 e Ix1 ->
  m (Ix1, Ix1, e)
getJoinCandidates nNghbr pq = do
  -- Minimal element of the queue; the value stored with it is not needed.
  (a, d, _) <- maybe (throwM $ IndexException "Empty priority queue") return (PQ.findMin pq)
  -- The neighbour currently recorded for a.
  (_, b) <- readM nNghbr a
  return (a, b, d)

-- | If the minimal distance @d@ found is not the distance between @a@ and @b@ recalculate the
-- neighbour list, update the priority queue and obtain a new set of a,b and a distance between them.
-- L 14-20.
--
-- The check is repeated recursively with the new candidates until the distance reported by the
-- priority queue and the entry of the distance matrix agree. The neighbourlist is updated in place
-- and returned together with the adjusted priority queue.
{-# SCC recalculateNghbr #-}
recalculateNghbr ::
  ( MonadThrow m,
    PrimMonad m,
    MonadUnliftIO m,
    PrimState m ~ RealWorld,
    Manifest r (e, Ix1),
    Manifest r e,
    Shape r Ix1,
    Ord e
  ) =>
  (Ix1, Ix1, e) ->
  IntSet ->
  MArray (PrimState m) r Ix2 e ->
  MArray (PrimState m) r Ix1 (e, Ix1) ->
  PQ.HashPSQ Ix1 e Ix1 ->
  m (Ix1, Ix1, e, MArray (PrimState m) r Ix1 (e, Ix1), PQ.HashPSQ Ix1 e Ix1)
recalculateNghbr (cA, cB, d) s distMat nNghbr pq = do
  dAB <- readM distMat (cA :. cB)
  if d == dAB
    then return (cA, cB, d, nNghbr, pq)
    else refresh
  where
    refresh = do
      -- Recalculate the nearest neighbours just on index cA. Consider only clusters, that were not
      -- merged yet.
      rowM <- searchRow cA s distMat
      rowA <- unsafeFreeze Par rowM
      closest <- minimumM rowA
      writeM nNghbr cA closest

      -- Update the priority queue at key cA with the new distance.
      let adjustedPQ = pqAdjust (const (fst closest)) cA pq

      -- Determine new a, b and d from the updated neighbour list and priority queue.
      (a, newD, _) <-
        maybe (throwM $ IndexException "Empty priority queue") return (PQ.findMin adjustedPQ)
      (_, b) <- readM nNghbr a
      recalculateNghbr (a, b, newD) s distMat nNghbr adjustedPQ

-- | Merges the clusters \(A\) and \(B\). The dendrogram stored at index @b@ of the accumulator is
-- replaced by a new node that has the dendrograms of @a@ and @b@ as children, the joining distance
-- @d@ and the union of both member sets as its cluster. The entry at index @a@ is left in place, so
-- the accumulator never shrinks. In addition the minimum element is dropped from the priority
-- queue and @a@ is removed from the index set @s@.
-- L 21-24
{-# SCC joinClusters #-}
joinClusters ::
  ( MonadThrow m,
    PrimMonad m,
    Ord e
  ) =>
  Ix1 ->
  Ix1 ->
  e ->
  IntSet ->
  PQ.HashPSQ Ix1 e Ix1 ->
  DendroAccM m e ->
  m (IntSet, PQ.HashPSQ Ix1 e Ix1, DendroAccM m e)
joinClusters a b d s pq acc = do
  clA <- readM acc a
  let treeA = unDendro clA
      -- Builds the joined dendrogram from the (fixed) tree of A and the current tree at b.
      merge clB =
        let treeB = unDendro clB
            joined =
              DendroNode
                { distance = d,
                  cluster = cluster (root treeA) <> cluster (root treeB)
                }
         in Dendrogram (Node joined treeA treeB)
  modifyM_ acc (pure . merge) b
  pure (IntSet.delete a s, PQ.deleteMin pq, acc)

-- | Update the distance matrix with a Lance-Williams update in the rows and columns of cluster b.
--
-- For every index @ix@ of @s@ other than @b@, the entries @(ix, b)@ and @(b, ix)@ of the distance
-- matrix are replaced in place by 'lanceWilliams' applied to the cluster sizes of @a@, @b@ and @ix@
-- and the distances @d(a,b)@, @d(a,ix)@ and the old entry. The mutated matrix is returned.
--
-- Throws a 'SizeMismatchException' if the number of columns of the distance matrix differs from the
-- size of the dendrogram accumulator, or if the distance matrix is not square.
-- L 25-27
{-# SCC updateDistMat #-}
updateDistMat ::
  ( MonadThrow m,
    PrimMonad m,
    MonadUnliftIO m,
    Manifest r e,
    Fractional e
  ) =>
  JoinStrat e ->
  Ix1 ->
  Ix1 ->
  IntSet ->
  MArray (PrimState m) r Ix2 e ->
  DendroAccM m e ->
  m (MArray (PrimState m) r Ix2 e)
updateDistMat js a b s distMat dendroAcc
  -- The distance matrix must describe exactly as many clusters as the accumulator holds.
  | nDM /= nCl = throwM $ SizeMismatchException (Sz nDM) (Sz nCl)
  | mDM /= nDM = throwM $ SizeMismatchException (Sz mDM) (Sz nDM)
  | otherwise = do
    dAB <- distMat `readM` (a :. b)
    nA <- clSize a
    nB <- clSize b
    forIO_ ixV $ \ix -> do
      dAX <- distMat `readM` (a :. ix)
      nX <- clSize ix
      -- The same update is applied to both the (ix, b) and the (b, ix) entry.
      let update dBX = return $ lanceWilliams js nA nB nX dAB dAX dBX
      modifyM_ distMat update (ix :. b)
      modifyM_ distMat update (b :. ix)
    return distMat
  where
    Sz (mDM :. nDM) = sizeOfMArray distMat
    Sz nCl = sizeOfMArray dendroAcc
    -- All indices of s except b itself.
    ixV = Massiv.fromList @U Par . IntSet.toAscList . IntSet.delete b $ s
    -- Number of members of the cluster stored at index i of the accumulator.
    clSize i = IntSet.size . cluster . root . unDendro <$> dendroAcc `readM` i

-- | Redirects entries of the neighbour list after a merge. Every index of @s@ that is smaller than
-- @a@ and currently records @a@ as its nearest neighbour is pointed to @b@ instead (where the union
-- of a and b now lives), together with the distance read from the distance matrix at @(ix, b)@.
-- All other entries are kept as they are. The redirection is blind: the new distance is not
-- compared with any alternative.
-- L 28-32
{-# SCC redirectNeighbours #-}
redirectNeighbours ::
  ( MonadThrow m,
    PrimMonad m,
    MonadUnliftIO m,
    Manifest r (e, Ix1),
    Manifest r e
  ) =>
  Ix1 ->
  Ix1 ->
  IntSet ->
  MArray (PrimState m) r Ix2 e ->
  MArray (PrimState m) r Ix1 (e, Ix1) ->
  m (MArray (PrimState m) r Ix1 (e, Ix1))
redirectNeighbours a b s distMat nNghbr = do
  forIO_ lowerIxs $ \x -> modifyM_ nNghbr (redirect x) x
  pure nNghbr
  where
    -- Indices of s that are strictly smaller than a.
    lowerIxs = compute @U . sfilter (< a) . Massiv.fromList @U Par $ IntSet.toAscList s
    -- New neighbour entry for index x, given its old entry.
    redirect x old@(_, nghbrX)
      | nghbrX == a = (\dXB -> (dXB, b)) <$> readM distMat (x :. b)
      | otherwise = pure old

-- | Re-checks the nearest neighbours of all indices below @b@ against the joined cluster AB that
-- now lives at index @b@. For every index @x@ of @s@ with @x < b@ the distance at @(x, b)@ is
-- compared with the priority currently stored for @x@ in the queue. If the new distance is smaller,
-- @(distance, b)@ is written to the neighbour list at @x@ and the priority of @x@ is adjusted to
-- that distance with 'pqAdjust'. Throws an 'IndexException' if @x@ has no entry in the queue.
-- L 33-38
{-# SCC updateWithNewBDists #-}
updateWithNewBDists ::
  ( MonadThrow m,
    MonadUnliftIO m,
    PrimMonad m,
    Manifest r e,
    Manifest r (e, Ix1),
    Ord e
  ) =>
  Ix1 ->
  IntSet ->
  MArray (PrimState m) r Ix2 e ->
  MArray (PrimState m) r Ix1 (e, Ix1) ->
  PQ.HashPSQ Ix1 e Ix1 ->
  m (MArray (PrimState m) r Ix1 (e, Ix1), PQ.HashPSQ Ix1 e Ix1)
updateWithNewBDists b s distMat nNghbr pq = do
  -- The queue is threaded through the loop in a TVar.
  pqRef <- newTVarIO pq
  forIO_ lowerIxs $ \x -> do
    dBX <- readM distMat (x :. b)
    queue <- readTVarIO pqRef
    (minDistX, _) <-
      maybe (throwM $ IndexException "Empty priority queue.") pure $ PQ.lookup x queue
    let closer = dBX < minDistX
        queue'
          | closer = pqAdjust (const dBX) x queue
          | otherwise = queue
    when closer $ writeM nNghbr x (dBX, b)
    -- The queue is written back in either case, changed or not.
    atomically $ writeTVar pqRef queue'
  finalPQ <- readTVarIO pqRef
  pure (nNghbr, finalPQ)
  where
    -- Indices of s that are strictly smaller than b.
    lowerIxs = compute @U . Massiv.sfilter (< b) . Massiv.fromList @U Par $ IntSet.toAscList s

-- | Updates the list of nearest neighbours and the priority queue at key b.
--
-- If @b@ lies outside the neighbour list (@b >= size of nNghbr@) nothing is changed. Otherwise the
-- candidate row for @b@ is built with 'searchRow', its minimum @(distance, index)@ pair is written
-- to the neighbour list at @b@ and the priority of key @b@ is set to that minimal distance with
-- 'pqAdjust'. The queue is keyed by the cluster whose neighbour changed (as 'recalculateNghbr' does
-- for its row index), hence the adjustment happens at @b@ and not at the index of the neighbour.
--
-- NOTE(review): 'minimumM' presumably fails through 'MonadThrow' if the candidate row is empty,
-- confirm against the massiv documentation.
-- L 39-40
{-# SCC updateBNeighbour #-}
updateBNeighbour ::
  ( MonadThrow m,
    PrimMonad m,
    RealWorld ~ PrimState m,
    MonadUnliftIO m,
    Manifest r (e, Ix1),
    Manifest r e,
    Shape r Ix1,
    Ord e
  ) =>
  Ix1 ->
  IntSet ->
  MArray (PrimState m) r Ix2 e ->
  MArray (PrimState m) r Ix1 (e, Ix1) ->
  PQ.HashPSQ Ix1 e Ix1 ->
  m (MArray (PrimState m) r Ix1 (e, Ix1), PQ.HashPSQ Ix1 e Ix1)
updateBNeighbour b s distMat nNghbr pq =
  if b >= nNeighbours
    then return (nNghbr, pq)
    else do
      rowAB <- searchRow b s distMat >>= unsafeFreeze Par
      newNeighbourB@(distB, _) <- minimumM rowAB
      writeM nNghbr b newNeighbourB
      -- Key b carries the new minimal distance of cluster b.
      let newPQ = pqAdjust (const distB) b pq
      return (nNghbr, newPQ)
  where
    Sz nNeighbours = sizeOfMArray nNghbr

-- | Builds the initial nearest neighbour list from a distance matrix. All rows except the last one
-- are passed, together with their row index, to 'minDistAtVec'; the resulting pairs of distance
-- and neighbour index form the returned vector, which therefore has one element less than the
-- matrix has rows. The original description states that the neighbour is the nearest point at a
-- higher index; that restriction is implemented by 'minDistAtVec'.
--
-- Throws an 'IndexException' if the matrix is not square or if it is empty.
--
-- NOTE(review): 'minDistAtVec' is run through 'unsafePerformIO' inside a pure map, so a failure it
-- signals would presumably surface as an exception when the result is computed instead of through
-- the 'MonadThrow' of this function. Confirm that this is intended.
{-# SCC nearestNeighbours #-}
nearestNeighbours ::
  ( MonadThrow m,
    Manifest r e,
    Manifest r (e, Ix1),
    Load r Ix1 e,
    -- OuterSlice r Ix2 e,
    -- Source (R r) Ix1 e,
    Ord e,
    Unbox e
  ) =>
  Matrix r e ->
  m (Vector r (e, Ix1))
nearestNeighbours distMat
  | m /= n = throwM $ IndexException "Distance matrix is not square"
  | m == 0 = throwM $ IndexException "Distance matrix is empty"
  | otherwise = pure (compute minDistIx)
  where
    Sz (m :. n) = size distMat
    -- The rows of the distance matrix as a boxed vector of vectors.
    rows = compute @B (outerSlices distMat)
    -- Minimum distance and neighbour index for every row but the last.
    minDistIx =
      Massiv.imap
        (\i v -> unsafePerformIO (minDistAtVec i (compute @U v)))
        (init rows)

-- | Builds a search row for distances. From the index set @s@ only the indices larger than @x@ are
-- kept as candidate columns. For each of them the distance at @(x, column)@ is read from the
-- distance matrix and paired with the column index. A minimum or maximum search on the resulting
-- vector thus yields a valid pair of distance and index.
searchRow ::
  ( PrimMonad m,
    RealWorld ~ PrimState m,
    MonadThrow m,
    MonadUnliftIO m,
    Manifest r e,
    Manifest r (e, Ix1)
  ) =>
  Ix1 ->
  IntSet ->
  MArray (PrimState m) r Ix2 e ->
  m (MArray (PrimState m) r Ix1 (e, Ix1))
searchRow x s dm =
  makeMArray Par (size candidates) $ \pos -> do
    col <- candidates !? pos
    dist <- readM dm (x :. col)
    pure (dist, col)
  where
    -- Column indices of s that are strictly larger than x, in ascending order.
    candidates :: Vector U Ix1
    candidates = compute @U . sfilter (> x) . Massiv.fromList @U Par $ IntSet.toAscList s