-- |
-- Module      :  Statistics.Covariance.RaoBlackwellLedoitWolf
-- Description :  Improved shrinkage based covariance estimator
-- Copyright   :  (c) 2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- Creation date: Fri Sep 10 09:26:58 2021.
module Statistics.Covariance.RaoBlackwellLedoitWolf
  ( raoBlackwellLedoitWolf,
  )
where

import qualified Numeric.LinearAlgebra as L
import Statistics.Covariance.Internal.Tools

-- | Improved shrinkage based covariance estimator by Ledoit and Wolf using the
-- Rao-Blackwell theorem.
--
-- See Chen, Y., Wiesel, A., Eldar, Y. C., & Hero, A. O., Shrinkage algorithms
-- for mmse covariance estimation, IEEE Transactions on Signal Processing,
-- 58(10), 5016–5029 (2010). http://dx.doi.org/10.1109/tsp.2010.2053029.
--
-- The sample covariance matrix \(S\) is shrunk towards the scaled identity
-- \(\mu I\) with \(\mu = \mathrm{tr}(S) / p\) (Equation 3), using the
-- Rao-Blackwellized shrinkage factor (Equations 17 and 19)
--
-- \[
-- \rho = \min\left(
--   \frac{\frac{n-2}{n} \mathrm{tr}(S^2) + \mathrm{tr}^2(S)}
--        {(n+2) \left(\mathrm{tr}(S^2) - \frac{\mathrm{tr}^2(S)}{p}\right)},
--   1 \right).
-- \]
--
-- The two are combined by 'shrinkWith' (Equation 16).
--
-- Return 'Left' if
--
-- - fewer than two samples (rows) are available.
--
-- - no parameters (columns) are available.
--
-- NOTE: This function may call 'error' due to partial library functions.
raoBlackwellLedoitWolf ::
  -- | Sample data matrix of dimension \(n \times p\), where \(n\) is the number
  -- of samples (rows), and \(p\) is the number of parameters (columns).
  L.Matrix Double ->
  Either String (L.Herm Double)
raoBlackwellLedoitWolf xs
  | n < 2 = Left "raoBlackwellLedoitWolf: Need more than one sample."
  | p < 1 = Left "raoBlackwellLedoitWolf: Need at least one parameter."
  -- Rao-Blackwell Ledoit and Wolf shrinkage estimator of the covariance matrix
  -- (Equation 16).
  | otherwise = Right $ shrinkWith rho sigma mu im
  where
    n :: Int
    n = L.rows xs
    p :: Int
    p = L.cols xs
    -- Only the sample covariance matrix is needed; the mean is discarded.
    sigma :: L.Herm Double
    (_, sigma) = L.meanCov xs
    -- Shrinkage target direction: the p x p identity matrix.
    im :: L.Herm Double
    im = L.trustSym $ L.ident p
    -- Trace and squared trace of sigma.
    trS :: Double
    trS = trace $ L.unSym sigma
    tr2S :: Double
    tr2S = trS * trS
    -- Trace of (sigma squared).
    trS2 :: Double
    trS2 = let s = L.unSym sigma in trace (s L.<> s)
    -- Shrinkage factor (Equations 17 and 19).
    n' :: Double
    n' = fromIntegral n
    p' :: Double
    p' = fromIntegral p
    rhoNumerator :: Double
    rhoNumerator = ((n' - 2) / n') * trS2 + tr2S
    rhoDenominator :: Double
    rhoDenominator = (n' + 2) * (trS2 - recip p' * tr2S)
    rho :: Double
    rho
      -- In exact arithmetic the denominator is non-negative (Cauchy-Schwarz on
      -- the eigenvalues of sigma gives tr(S^2) >= tr^2(S) / p) and vanishes
      -- only when sigma is already a multiple of the identity. Rounding can
      -- push it to zero or slightly below zero; the ratio is then unbounded
      -- (or 0/0), so use full shrinkage instead of letting a spurious
      -- negative factor blow up the estimate.
      | rhoDenominator <= 0 = 1.0
      | otherwise = min (rhoNumerator / rhoDenominator) 1.0
    -- Scaling factor of the identity matrix (Equation 3).
    mu :: Double
    mu = trS / p'