{-# LANGUAGE BangPatterns #-}
module Math.Clustering.Spectral.Dense
    ( spectralClusterKNorm
    , spectralClusterNorm
    , spectralNorm
    , getDegreeMatrix
    , AdjacencyMatrix (..)
    , LabelVector (..)
    , B (..)
    , B1 (..)
    , B2 (..)
    , spectral
    , spectralCluster
    , spectralClusterK
    , getB
    , b1ToB2
    , getSimilarityFromB2
    ) where
import Data.Bool (bool)
import Data.Function (on)
import Data.List (sortBy, maximumBy, transpose)
import Data.Maybe (fromMaybe)
import Safe (headMay)
import qualified AI.Clustering.KMeans as K
import qualified Data.Map.Strict as Map
import qualified Data.Vector as V
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Unboxed as U
import qualified Numeric.LinearAlgebra as H
import qualified Numeric.LinearAlgebra.Devel as H
import qualified Statistics.Quantile as S
import qualified Numeric.LinearAlgebra.SVD.SVDLIBC as SVD
-- | Cluster labels, one entry per observation; label ids are stored as
-- 'Double' values.
type LabelVector     = H.Vector Double
-- | Matrix of pairwise edge weights. 'spectralNorm' subtracts it (after
-- scaling) from an identity of size @rows@, so it is expected to be square.
-- Degrees are taken from the absolute values of each row.
type AdjacencyMatrix = H.Matrix Double
-- | Raw input matrix, before the column weighting applied by 'b1ToB2'.
newtype B1 = B1 { unB1 :: H.Matrix Double } deriving (Show)
-- | Column-weighted matrix (output of 'b1ToB2', or given directly to 'getB'
-- when weighting is disabled). Rows are compared as observations.
newtype B2 = B2 { unB2 :: H.Matrix Double } deriving (Show)
-- | Degree vector produced by 'bToD' (row sums of @|B| |B|^T@).
newtype D  = D { unD :: H.Vector Double } deriving (Show)
-- | @D^(-1/2) B@ as produced by 'bdToC'; its left singular vectors form the
-- spectral embedding used by 'spectral'.
newtype C  = C { unC :: H.Matrix Double } deriving (Show)
-- | Row-normalized matrix produced by 'b2ToB' (each row divided by its L2
-- norm).
newtype B  = B { unB :: H.Matrix Double } deriving (Show)
-- | The diagonal of a diagonal matrix, stored as a vector.
newtype Diag  = Diag { unDiag :: H.Vector Double } deriving (Show)
-- | Snap a value to exactly zero when its magnitude is below @1e-12@;
-- larger values (and NaN) are returned unchanged. Used to scrub numerical
-- noise from singular/eigen vectors.
epsilonZero :: Double -> Double
epsilonZero x
  | abs x < 1e-12 = 0
  | otherwise     = x
-- | Map over every entry of a matrix with access to its position.
--
-- @f i j x@ receives the zero-based row @i@, column @j@ and entry @x@. The
-- result has the same dimensions as the input. It is built through lists
-- and 'H.assoc', so it is convenient rather than fast.
cimap :: (Int -> Int -> Double -> Double) -> H.Matrix Double -> H.Matrix Double
cimap f mat = H.assoc (H.size mat) 0 entries
  where
    entries =
      [ ((i, j), f i j x)
      | (!i, row) <- zip [0 ..] (H.toLists mat)
      , (!j, !x)  <- zip [0 ..] row
      ]
-- | Apply IDF-style column weighting to a raw matrix.
--
-- Each entry of column @j@ is multiplied by @log (n / d_j)@. Here @n@ is the
-- number of rows and @d_j@ is the number of strictly positive entries in
-- column @j@.
--
-- A column with no positive entries gets weight 0. Using @log (n / 0)@ would
-- give Infinity, and @Infinity * 0@ is NaN, which would then spread through
-- every row norm in 'b2ToB'.
b1ToB2 :: B1 -> B2
b1ToB2 (B1 b1) = B2 . cimap (\ _ !j !x -> (idf VS.! j) * x) $ b1
  where
    n :: Int
    n = H.rows b1
    -- Count of strictly positive entries per column ('H.step' is 1 for x > 0).
    dVec :: H.Vector Double
    dVec = H.fromList
         . fmap (H.sumElements . H.step)
         . H.toColumns
         $ b1
    -- Per-column weight, computed once; j always indexes a valid column.
    idf :: H.Vector Double
    idf = VS.map weight dVec
    weight :: Double -> Double
    weight 0 = 0
    weight d = log (fromIntegral n / d)
-- | Scale every row to unit L2 norm.
--
-- An all-zero row has norm 0. Dividing by it would give a row of NaN
-- (@0/0@), so such rows are divided by 1 and stay all-zero.
b2ToB :: B2 -> B
b2ToB (B2 b2) = B . cimap (\ !i _ !x -> x / (rowNorms VS.! i)) $ b2
  where
    -- One norm per row; i always indexes a valid row.
    rowNorms :: H.Vector Double
    rowNorms = H.fromList . fmap (nonZero . H.norm_2) . H.toRows $ b2
    nonZero :: Double -> Double
    nonZero 0 = 1
    nonZero v = v
-- | Degree vector of the graph implied by @B@: @d = |B| (|B|^T 1)@.
--
-- Entry @i@ is row sum @i@ of @|B| |B|^T@. It is computed as two products
-- with a column of ones, so the @n x n@ product is never formed.
bToD :: B -> D
bToD (B b) = D (H.flatten (absB H.<> (absBT H.<> ones)))
  where
    absB  = H.cmap abs b
    absBT = H.cmap abs (H.tr b)
    ones  = (H.rows b H.>< 1) (repeat 1)
-- | Compute @C = D^(-1/2) B@ by scaling row @i@ of @B@ by @d_i^(-1/2)@.
--
-- A zero degree maps to 0 instead of @0 ** (-1/2) = Infinity@. Infinity
-- would turn that row into NaN and break the SVD in 'spectral'. This is the
-- same guard 'spectralNorm' applies to its inverse-degree matrix.
bdToC :: B -> D -> C
bdToC (B b) (D d) = C $ diagMatMult (Diag $ H.cmap invSqrt d) b
  where
    invSqrt :: Double -> Double
    invSqrt x = if x == 0 then x else x ** (- 1 / 2)
-- | Left singular vectors number @n .. n + e - 1@ (1-based) of a matrix.
--
-- Vectors follow the column order of @U@ from 'H.svd' (by decreasing
-- singular value). Fewer than @e@ vectors are returned if @U@ runs out of
-- columns.
secondLeft :: Int -> Int -> H.Matrix Double -> [H.Vector Double]
secondLeft n e m = take e (drop (n - 1) (H.toColumns u))
  where
    (u, _, _) = H.svd m
-- | Build the row-normalized matrix @B@ from a raw observation matrix.
--
-- When the flag is 'True', the IDF column weighting of 'b1ToB2' is applied
-- first. Otherwise the input is row-normalized as-is.
getB :: Bool -> H.Matrix Double -> B
getB useWeighting mat
  | useWeighting = b2ToB (b1ToB2 (B1 mat))
  | otherwise    = b2ToB (B2 mat)
-- | Spectral embedding of @B@.
--
-- Returns @e@ left singular vectors of @C = D^(-1/2) B@, starting at the
-- @n@-th (1-based). Entries below @1e-12@ in magnitude are snapped to zero.
-- Each returned vector is one singular vector, with one entry per row
-- (observation) of @B@.
spectral :: Int -> Int -> B -> [H.Vector Double]
spectral n e b
    | e < 1     = error "Less than 1 eigenvector chosen for clustering."
    | n < 1     = error "N < 1, cannot go before first eigenvector."
    | otherwise = map (H.cmap epsilonZero) (secondLeft n e c)
  where
    C c = bdToC b (bToD b)
-- | Two-way clustering of the rows of @B@ by the sign of the second left
-- singular vector.
--
-- Non-negative entries are labelled 1 and the rest 0. An empty input gives
-- an empty vector; a single row gets label 0.
spectralCluster :: B -> LabelVector
spectralCluster (B b)
  | nObs < 1  = H.fromList []
  | nObs == 1 = H.fromList [0]
  | otherwise = H.cmap (bool 0 1 . (>= 0)) secondVec
  where
    nObs      = H.rows b
    secondVec = mconcat (spectral 2 1 (B b))
-- | Cluster the rows of @B@ into @k@ groups with k-means on an @e@-dimensional
-- spectral embedding (singular vectors 2 .. e+1).
--
-- 'spectral' returns one vector per singular vector, each with one entry per
-- observation. 'kmeansVec' expects one vector per observation, which is the
-- layout 'spectralNorm' produces. So the embedding is transposed first.
-- Without the transpose, k-means clustered the @e@ singular vectors
-- themselves and returned @e@ labels instead of one per row.
spectralClusterK :: Int -> Int -> B -> LabelVector
spectralClusterK e k (B b)
  | H.rows b < 1  = H.fromList []
  | H.rows b == 1 = H.fromList [0]
  | otherwise     = kmeansVec k
                  . H.toRows      -- one vector per observation
                  . H.fromColumns -- nObs x e embedding matrix
                  . spectral 2 e
                  $ B b
-- | Cluster observations into @k@ groups with k-means.
--
-- The input holds one vector per observation (one row of a spectral
-- embedding). Each feature column (eigenvector) is first scaled to unit L2
-- norm.
--
-- For @k == 2@, the 100-run consensus of 'consensusKmeans' is used.
-- Consensus relies on 0/1 label alignment, which only works for two
-- clusters. For any other @k@, a single Forgy k-means run with seed 1 and
-- @k@ clusters is used, so the requested cluster count is honoured.
kmeansVec :: Int -> [H.Vector Double] -> LabelVector
kmeansVec k observations
  | k == 2    = consensusKmeans 100 points
  | otherwise = H.fromList
              . fmap fromIntegral
              . U.toList
              . K.membership
              . K.kmeansBy k points id
              $ K.defaultKMeansOpts
                  { K.kmeansMethod   = K.Forgy
                  , K.kmeansClusters = False
                  , K.kmeansSeed     = U.fromList [1]
                  }
  where
    -- Normalize within eigenvectors (columns), then return to one point per
    -- observation for K.kmeansBy.
    points :: V.Vector (U.Vector Double)
    points = V.fromList
           . fmap U.convert
           . H.toRows
           . H.fromColumns
           . fmap H.normalize
           . H.toColumns
           . H.fromRows
           $ observations
-- | Consensus two-cluster k-means over @x@ independently seeded runs.
--
-- Run @r@ (for @r@ in @1 .. x@) uses Forgy initialization with seed @r@. If
-- a run puts the first observation in cluster 1, its labels are flipped so
-- that the arbitrary 0/1 labels line up across runs. Each observation then
-- gets the label it received most often ('mostCommon').
consensusKmeans :: Int -> V.Vector (U.Vector Double) -> LabelVector
consensusKmeans x vs =
  H.fromList [ fromIntegral (mostCommon votes) | votes <- transpose perRun ]
  where
    perRun = map runWithSeed [1 .. fromIntegral x]
    runWithSeed seed =
      alignToFirst . U.toList . K.membership $ K.kmeansBy 2 vs id (optsFor seed)
    optsFor seed = K.defaultKMeansOpts
      { K.kmeansMethod   = K.Forgy
      , K.kmeansClusters = False
      , K.kmeansSeed     = U.fromList [seed]
      }
    alignToFirst labels
      | headMay labels == Just 1 = map (\l -> if l == 0 then 1 else 0) labels
      | otherwise                = labels
-- | The element occurring most often in a non-empty list.
--
-- Ties go to the largest tied element, because 'maximumBy' keeps the last
-- maximum of the ascending tally. Calls 'error' on the empty list.
mostCommon :: (Ord a) => [a] -> a
mostCommon [] = error "Cannot find most common element of empty list."
mostCommon [x] = x
mostCommon xs = fst (maximumBy (compare `on` snd) (Map.toAscList tally))
  where
    tally = Map.fromListWith (+) [ (item, 1 :: Int) | item <- xs ]
-- | Cosine similarity between rows @i@ and @j@ of a 'B2' matrix.
--
-- If either row is all zeros the result is NaN (@0 / 0@). Out-of-range
-- indices fail inside hmatrix's row extraction.
getSimilarityFromB2 :: B2 -> Int -> Int -> Double
getSimilarityFromB2 (B2 b2) i j = H.dot rowI rowJ / (H.norm_2 rowI * H.norm_2 rowJ)
  where
    rowI = H.flatten (b2 H.? [i])
    rowJ = H.flatten (b2 H.? [j])
-- | Cluster the nodes of an adjacency matrix into @k@ groups with k-means on
-- the @e@-dimensional normalized-Laplacian embedding from 'spectralNorm'.
--
-- An empty input gives an empty vector; a single node gets label 0.
spectralClusterKNorm :: Int -> Int -> AdjacencyMatrix -> LabelVector
spectralClusterKNorm e k mat
  | nNodes < 1  = H.fromList []
  | nNodes == 1 = H.fromList [0]
  | otherwise   = kmeansVec k (spectralNorm 2 e mat)
  where
    nNodes = H.rows mat
-- | Two-way clustering of the nodes of an adjacency matrix by the sign of
-- the one-dimensional embedding from 'spectralNorm'.
--
-- Non-negative entries are labelled 1 and the rest 0. An empty input gives
-- an empty vector; a single node gets label 0.
spectralClusterNorm :: AdjacencyMatrix -> LabelVector
spectralClusterNorm mat
  | nNodes < 1  = H.fromList []
  | nNodes == 1 = H.fromList [0]
  | otherwise   = H.cmap (bool 0 1 . (>= 0)) embedding
  where
    nNodes    = H.rows mat
    embedding = mconcat (spectralNorm 2 1 mat)
-- | Spectral embedding from the symmetric normalized Laplacian
-- @L = I - D^(-1/2) A D^(-1/2)@ of an adjacency matrix @A@.
--
-- 'H.eigSH' returns eigenvectors as columns, ordered by decreasing
-- eigenvalue, so the smallest eigenvalues sit in the last columns. The last
-- @n - 1@ columns are dropped and the last @e@ of the remainder are kept.
--
-- The result has one vector per node (a row of the kept eigenvector block).
-- Entries below @1e-12@ in magnitude are snapped to zero. Nodes with zero
-- degree get 0 in @D^(-1/2)@ instead of a division by zero.
spectralNorm :: Int -> Int -> AdjacencyMatrix -> [H.Vector Double]
spectralNorm n e mat
    | e < 1 = error "Less than 1 eigenvector chosen for clustering."
    | n < 1 = error "N < 1, cannot go before first eigenvector."
    | otherwise = H.toRows (H.cmap epsilonZero kept)
  where
    (_, eigVecs) = H.eigSH lNorm
    withoutSmallest = eigVecs H.?? (H.All, H.DropLast (n - 1))
    kept = withoutSmallest H.?? (H.All, H.TakeLast e)
    lNorm = H.trustSym (H.ident (H.rows mat) - scaledAdj)
    scaledAdj = matDiagMult (diagMatMult invSqrtDeg mat) invSqrtDeg
    invSqrtDeg = Diag (H.cmap invSqrt (getDegreeVector mat))
    invSqrt x = if x == 0 then x else x ** (- 1 / 2)
-- | Reorder the columns of a matrix by the values of a key vector, largest
-- key first.
--
-- The order is the reverse of a stable ascending sort, so columns with equal
-- keys end up in reverse of their original relative order.
sortMatrixByVec :: H.Vector Double -> H.Matrix Double -> H.Matrix Double
sortMatrixByVec xs mat = mat H.¿ columnOrder
  where
    keyed       = zip [0 ..] (H.toList xs)
    ascending   = [ idx | (idx, _) <- sortBy (compare `on` snd) keyed ]
    columnOrder = reverse ascending
-- | Degree matrix of an adjacency matrix, held as its diagonal (the
-- absolute row sums from 'getDegreeVector').
getDegreeMatrix :: AdjacencyMatrix -> Diag
getDegreeMatrix mat = Diag (getDegreeVector mat)
-- | Degree of every node: the sum of absolute edge weights in its row.
getDegreeVector :: AdjacencyMatrix -> H.Vector Double
getDegreeVector mat =
  H.fromList [ H.sumElements (H.cmap abs row) | row <- H.toRows mat ]
-- | Left-multiply by a diagonal matrix: @diag(d) * m@, i.e. scale row @i@ of
-- @m@ by @d_i@.
diagMatMult :: Diag -> H.Matrix Double -> H.Matrix Double
diagMatMult (Diag d) m = cimap (\i _ v -> v * (d H.! i)) m
-- | Right-multiply by a diagonal matrix: @m * diag(d)@, i.e. scale column
-- @j@ of @m@ by @d_j@.
matDiagMult :: H.Matrix Double -> Diag -> H.Matrix Double
matDiagMult m (Diag d) = cimap (\_ j v -> v * (d H.! j)) m