-- |
-- Module      :  Statistics.Covariance.Internal.Tools
-- Description :  Common functions
-- Copyright   :  (c) 2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- Creation date: Fri Sep 10 09:47:00 2021.
module Statistics.Covariance.Internal.Tools
  ( centerWith,
    shrinkWith,
    trace,
  )
where

import qualified Data.Vector.Storable as VS
import qualified Numeric.LinearAlgebra as L
import qualified Numeric.LinearAlgebra.Devel as L

-- | Subtract a vector of column means from every row of a data matrix.
--
-- Element @(i, j)@ of the result is @x_ij - ms ! j@.
--
-- Raises an error when the length of the mean vector differs from the number
-- of columns of the data matrix. Without this check, a too-short mean vector
-- fails with an opaque index error and a too-long one is silently accepted.
centerWith ::
  -- | Mean vector of dimension P.
  L.Vector Double ->
  -- | Data matrix of dimension N x P.
  L.Matrix Double ->
  -- | Data matrix with means 0.
  L.Matrix Double
centerWith ms xs
  | nMeans /= nCols =
      error $
        "centerWith: Bug! Dimension of mean vector ("
          <> show nMeans
          <> ") does not match number of columns of data matrix ("
          <> show nCols
          <> ")."
  | otherwise = L.mapMatrixWithIndex (\(_, j) x -> x - ms VS.! j) xs
  where
    nMeans = VS.length ms
    nCols = L.cols xs

-- | Shrink a covariance matrix towards a scaled identity matrix.
--
-- Computes @(1 - rho) * sigma + rho * mu * im@.
--
-- Raises an error if the shrinkage factor @rho@ is NaN or outside of the
-- closed interval [0, 1], or if the scale @mu@ is NaN or negative. NaN is
-- checked explicitly because every ordering comparison with NaN is 'False',
-- so the range checks alone would let NaN through.
shrinkWith ::
  -- | Shrinkage factor.
  Double ->
  -- | Sample covariance matrix.
  L.Herm Double ->
  -- | Scale of identity matrix (trace of sample covariance matrix divided by
  -- dimension). See Chen2010b, Equation 3.
  Double ->
  -- | Identity matrix.
  L.Herm Double ->
  L.Herm Double
shrinkWith rho sigma mu im
  | isNaN rho = error "shrinkWith: Bug! Shrinkage factor is NaN."
  | rho < 0.0 = error "shrinkWith: Bug! Shrinkage factor is negative."
  | rho > 1.0 = error "shrinkWith: Bug! Shrinkage factor is larger than 1.0."
  | isNaN mu = error "shrinkWith: Bug! Scaling factor of identity matrix is NaN."
  | mu < 0.0 = error "shrinkWith: Bug! Scaling factor of identity matrix is negative."
  -- Full shrinkage: the result is exactly mu * im. This branch does not read
  -- sigma, so non-finite entries of sigma cannot leak in (0 * Inf is NaN).
  | rho == 1.0 = L.trustSym $ L.scale mu (L.unSym im)
  -- A linear combination of two symmetric matrices is symmetric, so
  -- 'L.trustSym' is justified here.
  | otherwise =
      L.trustSym $
        L.scale (1.0 - rho) (L.unSym sigma)
          + L.scale (rho * mu) (L.unSym im)

-- | Trace of a matrix: the sum of the diagonal elements extracted by
-- 'L.takeDiag'.
--
-- The trace is mathematically defined for square matrices, such as the
-- covariance matrices handled in this module.
trace :: L.Matrix Double -> Double
trace = L.sumElements . L.takeDiag