{-# LANGUAGE FlexibleContexts #-}

-- |
-- Description :  Rate matrix helper functions
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPLv3
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  non-portable (not tested)
--
-- Some helper functions that come handy when working with rate matrices of
-- continuous-time discrete-state Markov processes.
--
-- * Changelog
--
-- To be imported qualified.
module ELynx.MarkovProcess.RateMatrix
  ( RateMatrix,
    ExchangeabilityMatrix,
    StationaryDistribution,
    isValid,
    normalizeSD,
    totalRate,
    totalRateWith,
    normalize,
    normalizeWith,
    setDiagonal,
    toExchangeabilityMatrix,
    fromExchangeabilityMatrix,
    getStationaryDistribution,
    exchFromListLower,
    exchFromListUpper,
  )
where

import qualified Data.Vector.Storable as V
import Numeric.LinearAlgebra hiding (normalize)
import Numeric.SpecFunctions
import Prelude hiding ((<>))

-- | A rate matrix is just a real matrix.
type RateMatrix = Matrix R

-- | A matrix of exchangeabilities, we have q = e * pi, where q is a rate
-- matrix, e is the exchangeability matrix and pi is the diagonal matrix
-- containing the stationary frequency distribution.
type ExchangeabilityMatrix = Matrix R

-- | Stationary distribution of a rate matrix.
type StationaryDistribution = Vector R

-- Relaxed absolute tolerance; 'isValid' uses it when comparing the 1-norm of a
-- distribution against 1.0.
epsRelaxed :: Double
epsRelaxed = 1e-5

-- | True if the 1-norm of the distribution (the sum of the absolute values of
-- its elements) differs from 1.0 by less than 'epsRelaxed'.
--
-- Note that the sign of the individual entries is not inspected; for a vector
-- with non-negative entries the 1-norm coincides with the plain sum.
isValid :: StationaryDistribution -> Bool
isValid d = abs (norm_1 d - 1.0) < epsRelaxed

-- | Normalize a stationary distribution by dividing every element by the
-- 1-norm of the vector, so that the absolute values of the elements sum to
-- 1.0. For non-negative entries the elements themselves then sum to 1.0.
normalizeSD :: StationaryDistribution -> StationaryDistribution
normalizeSD d = d / scalar (norm_1 d)

-- Zero out the diagonal of a matrix by subtracting the diagonal matrix built
-- from its own diagonal entries. Assumes a square matrix (TODO confirm: 'diag'
-- always builds a square matrix, so the subtraction needs matching shapes).
matrixSetDiagToZero :: Matrix R -> Matrix R
matrixSetDiagToZero m = m - diag (takeDiag m)
{-# INLINE matrixSetDiagToZero #-}

-- | Get average number of substitutions per unit time for a given stationary
-- distribution.
--
-- The diagonal of the rate matrix is set to zero, the distribution is
-- multiplied onto the result from the left (as a row vector), and the 1-norm
-- of the resulting vector is returned.
totalRateWith :: StationaryDistribution -> RateMatrix -> Double
totalRateWith d m = norm_1 (d <# matrixSetDiagToZero m)

-- | Get average number of substitutions per unit time. The stationary
-- distribution is computed from the rate matrix with
-- 'getStationaryDistribution'; use 'totalRateWith' if it is already known.
totalRate :: RateMatrix -> Double
totalRate m = totalRateWith (getStationaryDistribution m) m

-- | Normalizes a Markov process generator such that one event happens per unit
-- time. The stationary distribution is calculated from the rate matrix with
-- 'getStationaryDistribution' and then passed on to 'normalizeWith'.
normalize :: RateMatrix -> RateMatrix
normalize m = normalizeWith (getStationaryDistribution m) m

-- | Normalizes a Markov process generator such that one event happens per unit
-- time. Faster than 'normalize', but the stationary distribution has to be
-- given.
--
-- Every entry of the matrix is scaled by the reciprocal of the total rate
-- obtained from 'totalRateWith'.
normalizeWith :: StationaryDistribution -> RateMatrix -> RateMatrix
normalizeWith d m = scale (1.0 / totalRateWith d m) m

-- | Set the diagonal entries of a matrix such that the rows sum to 0.
--
-- The existing diagonal is discarded. Each new diagonal entry is the negated
-- 1-norm (sum of absolute values) of the off-diagonal entries of its row, so
-- the rows sum to 0 when all off-diagonal entries are non-negative.
setDiagonal :: RateMatrix -> RateMatrix
setDiagonal m = offDiag - diag (fromList rowTotals)
  where
    -- The input with its diagonal zeroed out.
    offDiag = matrixSetDiagToZero m
    -- One total per row, computed over the off-diagonal entries only.
    rowTotals = map norm_1 (toRows offDiag)

-- | Extract the exchangeability matrix from a rate matrix.
--
-- The rate matrix is multiplied from the right by the diagonal matrix holding
-- the element-wise reciprocals of the stationary distribution, that is, column
-- @j@ is divided by the @j@-th stationary frequency.
toExchangeabilityMatrix ::
  RateMatrix -> StationaryDistribution -> ExchangeabilityMatrix
toExchangeabilityMatrix m f = m <> diag reciprocalFreqs
  where
    reciprocalFreqs = cmap (1.0 /) f

-- | Convert exchangeability matrix to rate matrix.
--
-- The exchangeability matrix is multiplied from the right by the diagonal
-- matrix of the stationary distribution; the diagonal of the product is then
-- reset with 'setDiagonal'.
fromExchangeabilityMatrix ::
  ExchangeabilityMatrix -> StationaryDistribution -> RateMatrix
fromExchangeabilityMatrix em d = setDiagonal (em <> diag d)

-- Strict tolerance; 'getStationaryDistribution' treats an eigenvalue whose
-- magnitude is below this value as zero.
eps :: Double
eps = 1e-12

-- Divide every element of the vector by the sum of all elements. In contrast
-- to 'normalizeSD', the plain sum is used and not the sum of absolute values.
normalizeSumVec :: V.Vector Double -> V.Vector Double
normalizeSumVec v = V.map (/ total) v
  where
    total = V.sum v
{-# INLINE normalizeSumVec #-}

-- | Get stationary distribution from 'RateMatrix'. Involves eigendecomposition
-- of the transposed rate matrix.
--
-- The eigenvalue at the index reported by 'minIndex' is selected (presumably
-- the one closest to zero; verify against hmatrix's ordering of complex
-- vectors). If its magnitude is below 'eps', the real parts of the
-- corresponding eigenvector are normalized to sum to 1.0 and returned.
--
-- If the given matrix does not satisfy the required properties of transition
-- rate matrices and no eigenvector with an eigenvalue nearly equal to 0 is
-- found, an error is thrown.
--
-- Open question from the original author: is there an easier way to calculate
-- the stationary distribution, or a better way to handle errors? Returning a
-- 'Maybe' would only delay the error report to the calling function.
getStationaryDistribution :: RateMatrix -> StationaryDistribution
getStationaryDistribution m
  -- 'magnitude' is never negative, so no additional 'abs' is needed.
  | magnitude (eVals ! i) < eps = normalizeSumVec distReal
  | otherwise =
      error
        "getStationaryDistribution: Could not retrieve stationary distribution."
  where
    -- Eigenvectors of the transpose are the left eigenvectors of the matrix.
    (eVals, eVecs) = eig (tr m)
    i = minIndex eVals
    -- Eigenvectors are stored column-wise; keep the real parts only.
    distReal = cmap realPart (toColumns eVecs !! i)

-- The next functions tackle the somewhat trivial, but not easily solvable
-- problem of converting a triangular matrix (excluding the diagonal) given as a
-- list into a symmetric matrix. The diagonal entries are set to zero.

-- Lower triangular matrix. This is how the exchangeabilities are specified in
-- PAML. Conversion from matrix indices (i,j) to list index k.
--
-- (i,j) k
--
-- (0,0) -
-- (1,0) 0  (1,1) -
-- (2,0) 1  (2,1) 2  (2,2) -
-- (3,0) 3  (3,1) 4  (3,2) 5 (3,3) -
-- (4,0) 6  (4,1) 7  (4,2) 8 (4,3) 9 (4,4) -
--   .
--   .
--   .
--
-- k = (i choose 2) + j.
--
-- Only defined strictly below the diagonal (i > j); any other index pair
-- raises an error.
ijToKLower :: Int -> Int -> Int
ijToKLower i j
  | i > j = iChoose2 + j
  | otherwise = error "ijToKLower: not defined for upper triangular matrix."
  where
    -- (i choose 2) in exact integer arithmetic, avoiding the round trip through
    -- Double; defined as 0 for i < 2.
    iChoose2
      | i < 2 = 0
      | otherwise = i * (i - 1) `div` 2

-- Upper triangular matrix. Conversion from matrix indices (i,j) to list index
-- k. Matrix is square of size n.
--
-- (i,j) k
--
-- (0,0) -  (0,1) 0  (0,2) 1    (0,3) 2     (0,4) 3     ...
--          (1,1) -  (1,2) n-1  (1,3) n     (1,4) n+1
--                   (2,2) -    (2,3) 2n-3  (2,4) 2n-2
--                              (3,3) -     (3,4) 3n-6
--                                          (4,4) -
--                                                      ...
--
-- k = i*(n-2) - (i choose 2) + (j - 1)
--
-- Only defined strictly above the diagonal (i < j); any other index pair
-- raises an error.
ijToKUpper :: Int -> Int -> Int -> Int
ijToKUpper n i j
  | i < j = i * (n - 2) - iChoose2 + j - 1
  | otherwise = error "ijToKUpper: not defined for lower triangular matrix."
  where
    -- (i choose 2) in exact integer arithmetic, avoiding the round trip through
    -- Double; defined as 0 for i < 2.
    iChoose2
      | i < 2 = 0
      | otherwise = i * (i - 1) `div` 2

-- Element builder for a symmetric matrix with zero diagonal whose entries are
-- given as the list of the strictly lower triangular part (see 'ijToKLower').
--
-- The function is a little weird because HMatrix uses Double indices for Matrix
-- Double builders: the indices arrive as fractional values and are rounded to
-- 'Int' before the list lookup. Entries above the diagonal are looked up with
-- swapped indices. The final branch is only reached when none of the
-- comparisons hold (e.g., for NaN indices).
fromListBuilderLower :: RealFrac a => [a] -> a -> a -> a
fromListBuilderLower es i j
  | i > j = es !! ijToKLower row col
  | i == j = 0.0
  | i < j = es !! ijToKLower col row
  | otherwise =
      error
        "Float indices could not be compared during matrix creation."
  where
    row = round i :: Int
    col = round j :: Int

-- Element builder for a symmetric matrix of size n with zero diagonal whose
-- entries are given as the list of the strictly upper triangular part (see
-- 'ijToKUpper').
--
-- The function is a little weird because HMatrix uses Double indices for Matrix
-- Double builders: the indices arrive as fractional values and are rounded to
-- 'Int' before the list lookup. Entries below the diagonal are looked up with
-- swapped indices. The final branch is only reached when none of the
-- comparisons hold (e.g., for NaN indices).
fromListBuilderUpper :: RealFrac a => Int -> [a] -> a -> a -> a
fromListBuilderUpper n es i j
  | i < j = es !! ijToKUpper n row col
  | i == j = 0.0
  | i > j = es !! ijToKUpper n col row
  | otherwise =
      error
        "Float indices could not be compared during matrix creation."
  where
    row = round i :: Int
    col = round j :: Int

-- Check that the list of exchangeabilities has exactly (n choose 2) elements,
-- the number of entries in one triangle of an n x n matrix excluding the
-- diagonal. Returns the list unchanged on success and throws an error
-- reporting the expected and received lengths otherwise.
--
-- NOTE(review): the error message is prefixed with "exchFromListlower" even
-- when the check is reached through 'exchFromListUpper'.
checkEs :: RealFrac a => Int -> [a] -> [a]
checkEs n es
  | nReceived == nExpected = es
  | otherwise = error message
  where
    nReceived = length es
    nExpected = round (n `choose` 2) :: Int
    message =
      unlines
        [ "exchFromListlower: the number of exchangeabilities does not match the matrix size",
          "matrix size: " ++ show n,
          "expected number of exchangeabilities: " ++ show nExpected,
          "received number of exchangeabilities: " ++ show nReceived
        ]

-- | Build exchangeability matrix from list denoting lower triangular matrix,
-- and excluding diagonal. This is how the exchangeabilities are specified in
-- PAML.
--
-- The result is a symmetric matrix of the given size with zeroes on the
-- diagonal. An error is thrown if the length of the list does not match the
-- matrix size.
exchFromListLower :: (RealFrac a, Container Vector a) => Int -> [a] -> Matrix a
exchFromListLower n es = build (n, n) entryAt
  where
    -- Validate the list length before any element is looked up.
    entryAt = fromListBuilderLower (checkEs n es)

-- | Build exchangeability matrix from list denoting upper triangular matrix,
-- and excluding diagonal.
--
-- The result is a symmetric matrix of the given size with zeroes on the
-- diagonal. An error is thrown if the length of the list does not match the
-- matrix size.
exchFromListUpper :: (RealFrac a, Container Vector a) => Int -> [a] -> Matrix a
exchFromListUpper n es = build (n, n) entryAt
  where
    -- Validate the list length before any element is looked up.
    entryAt = fromListBuilderUpper n (checkEs n es)