{-# OPTIONS_GHC -fplugin=GHC.TypeLits.KnownNat.Solver -fplugin=GHC.TypeLits.Normalise -fconstraint-solver-iterations=10 #-}
{-# LANGUAGE UndecidableInstances,TypeApplications #-}

-- | Univariate and multivariate normal (Gaussian) distributions, and linear
-- models with additive Gaussian noise, with a focus on their
-- exponential-family representations. In the documentation we use \(X\) to
-- indicate a random variable with the distribution being documented.
module Goal.Probability.Distributions.Gaussian
    ( -- * Univariate
      Normal
    , NormalMean
    , NormalVariance
    -- * Multivariate
    , MVNMean
    , MVNCovariance
    , MultivariateNormal
    , multivariateNormalCorrelations
    , bivariateNormalConfidenceEllipse
    , splitMultivariateNormal
    , splitMeanMultivariateNormal
    , splitNaturalMultivariateNormal
    , joinMultivariateNormal
    , joinMeanMultivariateNormal
    , joinNaturalMultivariateNormal
    -- * Linear Models
    , SimpleLinearModel
    , LinearModel
    ) where

-- Package --

import Goal.Core
import Goal.Probability.Statistical
import Goal.Probability.ExponentialFamily
import Goal.Probability.Distributions

import Goal.Geometry

import qualified Goal.Core.Vector.Storable as S
import qualified Goal.Core.Vector.Generic as G

import qualified System.Random.MWC.Distributions as R

-- Normal Distribution --

-- | The mean of a normal distribution. When used as a distribution itself, it
-- is a normal distribution with unit variance.
data NormalMean

-- | The variance of a normal distribution.
data NormalVariance

-- | The 'Manifold' of 'Normal' distributions: a 'LocationShape' pairing a
-- 'NormalMean' location with a 'NormalVariance' shape. The 'Source'
-- coordinates are the mean and the variance.
type Normal = LocationShape NormalMean NormalVariance

-- | The mean of an @n@-dimensional multivariate normal distribution. When used
-- as a distribution itself, it is a multivariate normal distribution with
-- identity covariance.
data MVNMean (n :: Nat)

-- | The covariance matrix of an @n@-dimensional multivariate normal
-- distribution. Only its lower triangle is stored as coordinates.
data MVNCovariance (n :: Nat)

-- | Multivariate linear models: linear (affine) functions with additive
-- Gaussian noise, relating an 'MVNMean' @k@ to the 'MVNMean' @n@ part of a
-- 'MultivariateNormal' @n@. NOTE(review): presumably @k@ is the input
-- dimension and @n@ the output dimension — confirm against 'Affine'.
type LinearModel n k = Affine Tensor (MVNMean n) (MultivariateNormal n) (MVNMean k)

-- | Univariate linear models: linear (affine) functions with additive
-- Gaussian noise, relating a 'NormalMean' to the mean of a 'Normal'.
type SimpleLinearModel = Affine Tensor NormalMean Normal NormalMean

-- Multivariate Normal --

-- | The 'Manifold' of 'MultivariateNormal' distributions. The 'Source'
-- coordinates are the (vector) mean and the covariance matrix. In the
-- coordinate vector, the elements of the mean come first, followed by the
-- lower-triangular elements of the covariance matrix, in the order produced
-- by 'S.lowerTriangular'.
--
-- Storing only the lower triangle makes the coordinate dimension reflect the
-- true dimension of a MultivariateNormal manifold. Because of this, be
-- careful when using 'join' and 'split' directly to access the covariance
-- values; prefer the MVN-specific functions such as
-- 'splitMultivariateNormal' and 'joinMultivariateNormal'.
type MultivariateNormal (n :: Nat) = LocationShape (MVNMean n) (MVNCovariance n)

-- | Split a 'Source' 'MultivariateNormal' into its mean vector and its
-- covariance matrix. The stored lower-triangular covariance coordinates are
-- expanded to an @n@ by @n@ matrix with 'S.fromLowerTriangular'.
splitMultivariateNormal
    :: KnownNat n
    => Source # MultivariateNormal n
    -> (S.Vector n Double, S.Matrix n n Double)
splitMultivariateNormal mvn = (coordinates meanPart, covariance)
  where
    (meanPart, covariancePart) = split mvn
    covariance = S.fromLowerTriangular (coordinates covariancePart)

-- | Join a mean vector and a covariance matrix into a 'Source'
-- 'MultivariateNormal'. Only the entries extracted by 'S.lowerTriangular'
-- are stored, so the covariance matrix should be symmetric.
joinMultivariateNormal
    :: KnownNat n
    => S.Vector n Double
    -> S.Matrix n n Double
    -> Source # MultivariateNormal n
joinMultivariateNormal mus sgma =
    let meanPart = Point mus
        covariancePart = Point (S.lowerTriangular sgma)
     in join meanPart covariancePart

-- | Split a 'Mean'-coordinate 'MultivariateNormal' into its vector block and
-- its matrix block. The matrix block is expanded from lower-triangular
-- storage with 'S.fromLowerTriangular'.
--
-- NOTE(review): in mean coordinates the matrix block is presumably the
-- (non-central) second moment rather than the covariance — confirm.
splitMeanMultivariateNormal
    :: KnownNat n
    => Mean # MultivariateNormal n
    -> (S.Vector n Double, S.Matrix n n Double)
splitMeanMultivariateNormal mvn =
    let (muPoint, cvrPoint) = split mvn
        cvrTriangle = coordinates cvrPoint
     in (coordinates muPoint, S.fromLowerTriangular cvrTriangle)

-- | Join a vector block and a matrix block into a 'Mean'-coordinate
-- 'MultivariateNormal'. Only the entries extracted by 'S.lowerTriangular'
-- are stored, so the matrix should be symmetric.
joinMeanMultivariateNormal
    :: KnownNat n
    => S.Vector n Double
    -> S.Matrix n n Double
    -> Mean # MultivariateNormal n
joinMeanMultivariateNormal mus sgma = join meanPart secondPart
  where
    meanPart = Point mus
    secondPart = Point (S.lowerTriangular sgma)

-- | Split a MultivariateNormal into the precision weighted means and (-0.5*)
-- Precision matrix.
--
-- Note the easy-to-miss conversion from the reduced natural representation to
-- a full matrix. The stored lower triangle is expanded and halved. Its
-- diagonal is then added back, so diagonal entries keep their stored value
-- while off-diagonal entries are halved.
splitNaturalMultivariateNormal
    :: KnownNat n
    => Natural # MultivariateNormal n
    -> (S.Vector n Double, S.Matrix n n Double)
splitNaturalMultivariateNormal np =
    let (nmu, cvrs) = split np
        halved = S.fromLowerTriangular (coordinates cvrs) / 2
        -- Halving is only meant for off-diagonal terms; restore the diagonal.
        precisionPart = halved + S.diagonalMatrix (S.takeDiagonal halved)
     in (coordinates nmu, precisionPart)

-- | Joins a MultivariateNormal out of the precision weighted means and (-0.5)
-- Precision matrix.
--
-- Note the easy-to-miss conversion from the full matrix to the reduced
-- exponential-family representation. The matrix is doubled and its diagonal
-- is subtracted once, so off-diagonal entries are doubled and diagonal entries
-- are unchanged. Only the lower triangle is then stored. This is the inverse
-- of the conversion in 'splitNaturalMultivariateNormal'.
joinNaturalMultivariateNormal
    :: KnownNat n
    => S.Vector n Double
    -> S.Matrix n n Double
    -> Natural # MultivariateNormal n
joinNaturalMultivariateNormal nmu0 nsgma0 =
    join (Point nmu0) (Point $ S.lowerTriangular reduced)
  where
    reduced = 2*nsgma0 - S.diagonalMatrix (S.takeDiagonal nsgma0)

-- | Points on a confidence ellipse of a bivariate normal distribution.
--
-- The steps are:
--
-- * take @nstps@ angles from 0 to 2π (via 'range');
-- * map the corresponding unit-circle points through the Cholesky factor of
--   the covariance ('S.unsafeCholesky'), scaled by @prcnt@;
-- * shift the results by the mean.
--
-- NOTE(review): despite its name, @prcnt@ is applied as a plain linear scale
-- factor. It is not converted from a confidence percentage.
bivariateNormalConfidenceEllipse
    :: Int -- ^ number of points on the ellipse
    -> Double -- ^ scale factor applied to the Cholesky factor
    -> Source # MultivariateNormal 2
    -> [(Double,Double)]
bivariateNormalConfidenceEllipse nstps prcnt nrm =
    [ S.toPair (mu + pnt) | pnt <- S.matrixMap scaledFactor circle ]
  where
    (mu, cvr) = splitMultivariateNormal nrm
    scaledFactor = S.withMatrix (S.scale prcnt) (S.unsafeCholesky cvr)
    circle = [ S.fromTuple (cos ang, sin ang) | ang <- range 0 (2*pi) nstps ]

-- | Computes the correlation matrix of a 'MultivariateNormal' distribution.
-- Each covariance entry is divided elementwise by the product of the
-- corresponding standard deviations (square roots of the diagonal).
multivariateNormalCorrelations
    :: KnownNat k
    => Source # MultivariateNormal k
    -> S.Matrix k k Double
multivariateNormalCorrelations mnrm =
    G.Matrix (S.zipWith (/) (G.toVector covariance) (G.toVector sdProducts))
  where
    covariance = snd (splitMultivariateNormal mnrm)
    stdDevs = S.map sqrt (S.takeDiagonal covariance)
    sdProducts = S.outerProduct stdDevs stdDevs

-- | Log base measure of a 'MultivariateNormal' @n@: the constant
-- \(-\frac{n}{2}\log(2\pi)\), independent of the sample.
multivariateNormalLogBaseMeasure
    :: forall n . (KnownNat n)
    => Proxy (MultivariateNormal n)
    -> S.Vector n Double
    -> Double
multivariateNormalLogBaseMeasure _ _ = negate (dim / 2 * log (2*pi))
  where
    dim = fromIntegral (natValInt (Proxy :: Proxy n))

-- | Log base measure of 'MVNMean' @n@, i.e. of an @n@-dimensional normal
-- distribution with identity covariance:
--
-- \[ \log h(x) = -\frac{n}{2}\log(2\pi) - \frac{x \cdot x}{2}. \]
--
-- The normalising constant uses \(\log(2\pi)\), matching
-- 'multivariateNormalLogBaseMeasure' and the univariate 'Normal' instance.
-- The previous \(\log\pi\) left the density unnormalised.
mvnMeanLogBaseMeasure
    :: forall n . (KnownNat n)
    => Proxy (MVNMean n)
    -> S.Vector n Double
    -> Double
mvnMeanLogBaseMeasure _ x =
    let n = natValInt (Proxy :: Proxy n)
     in -fromIntegral n/2 * log (2*pi) - S.dotProduct x x / 2

-- | Samples a 'MultivariateNormal' by way of its covariance matrix. A vector
-- of independent standard normal draws is multiplied by the matrix square
-- root of the covariance ('S.matrixRoot') and shifted by the mean.
sampleMultivariateNormal
    :: KnownNat n
    => Source # MultivariateNormal n
    -> Random (S.Vector n Double)
sampleMultivariateNormal p = do
    standardDraws <- S.replicateM (Random (R.normal 0 1))
    let (mus, sgma) = splitMultivariateNormal p
        sqrtCovariance = S.matrixRoot sgma
    pure $ mus + S.matrixVectorMultiply sqrtCovariance standardDraws


--- Internal ---


--- Instances ---


-- NormalMean Distribution --

-- | 'NormalMean' is a one-dimensional manifold (a single coordinate).
instance Manifold NormalMean where
    type Dimension NormalMean = 1

-- | Samples of a 'NormalMean' distribution are scalar 'Double's.
instance Statistical NormalMean where
    type SamplePoint NormalMean = Double

-- | 'NormalMean' as an exponential family: a univariate normal with unit
-- variance, whose sufficient statistic is the sample itself.
--
-- The log base measure is \(-x^2/2 - \tfrac{1}{2}\log(2\pi)\). Together with
-- the potential \(\theta^2/2\) of the 'Legendre' instance, this gives a
-- normalised density.
instance ExponentialFamily NormalMean where
    sufficientStatistic x = singleton x
    -- The normalising constant is log (sqrt (2*pi)), not sqrt (2*pi).
    logBaseMeasure _ x = -square x/2 - log (2*pi) / 2

-- The 'potential' of 'NormalMean' is evaluated on 'Natural' coordinates.
type instance PotentialCoordinates NormalMean = Natural

-- All six coordinate transitions of 'NormalMean' are implemented by
-- 'breakPoint', which relabels the coordinate system of a point of the same
-- dimension. NOTE(review): this presumes the mean, natural and source
-- coordinates of a unit-variance normal coincide — confirm.
instance Transition Mean Natural NormalMean where
    transition p = breakPoint p

instance Transition Mean Source NormalMean where
    transition p = breakPoint p

instance Transition Source Natural NormalMean where
    transition p = breakPoint p

instance Transition Source Mean NormalMean where
    transition p = breakPoint p

instance Transition Natural Mean NormalMean where
    transition p = breakPoint p

instance Transition Natural Source NormalMean where
    transition p = breakPoint p

-- | Potential of 'NormalMean' in natural coordinates:
-- \(\psi(\theta) = \theta^2 / 2\).
instance Legendre NormalMean where
    potential (Point cs) = square (S.head cs) / 2

-- | Log-likelihood of scalar samples under a natural-coordinate
-- 'NormalMean', and its differential. Both delegate to the generic
-- exponential-family implementations.
instance LogLikelihood Natural NormalMean Double where
    logLikelihood xs p = exponentialFamilyLogLikelihood xs p
    logLikelihoodDifferential xs p =
        exponentialFamilyLogLikelihoodDifferential xs p


-- Normal Shape --


-- | 'NormalVariance' is a one-dimensional manifold (a single coordinate).
instance Manifold NormalVariance where
    type Dimension NormalVariance = 1


-- Normal Distribution --

-- | 'Normal' as an exponential family. The sufficient statistics are
-- \((x, x^2)\) and the log base measure is the constant
-- \(-\tfrac{1}{2}\log(2\pi)\).
instance ExponentialFamily Normal where
    sufficientStatistic x = Point (S.doubleton x (x**2))
    logBaseMeasure _ _ = -1/2 * log (2 * pi)

type instance PotentialCoordinates Normal = Natural

-- | Log-partition function in natural coordinates
-- \(\theta = (\mu/\sigma^2, -1/(2\sigma^2))\):
-- \(\psi(\theta) = -\theta_0^2/(4\theta_1) - \frac{1}{2}\log(-2\theta_1)\).
instance Legendre Normal where
    potential p =
        let (th0, th1) = S.toPair (coordinates p)
            quadTerm = square th0 / (4 * th1)
            logTerm = log (negate (2 * th1))
         in negate quadTerm - 0.5 * logTerm

-- | Natural-to-mean map: \(\eta_0 = -\theta_0/(2\theta_1)\) and
-- \(\eta_1 = \theta_0^2/(4\theta_1^2) - 1/(2\theta_1)\).
instance Transition Natural Mean Normal where
    transition p = Point $ S.doubleton eta0 eta1
      where
        (th0, th1) = S.toPair (coordinates p)
        ratio = th0 / th1
        eta0 = -0.5 * ratio
        eta1 = 0.25 * square ratio - 0.5 / th1

-- | Dual potential in mean coordinates:
-- \(\phi(\eta) = -\frac{1}{2}\log(\eta_1 - \eta_0^2) - \frac{1}{2}\),
-- where \(\eta_1 - \eta_0^2\) is the variance.
instance DuallyFlat Normal where
    dualPotential p =
        let (eta0, eta1) = S.toPair (coordinates p)
            variance = eta1 - square eta0
         in -0.5 * log variance - 1 / 2

-- | Mean-to-natural map: with \(v = \eta_1 - \eta_0^2\),
-- \(\theta = (\eta_0/v, -1/(2v))\).
instance Transition Mean Natural Normal where
    transition p = Point $ S.doubleton (eta0 / vr) (-0.5 / vr)
      where
        (eta0, eta1) = S.toPair (coordinates p)
        vr = eta1 - square eta0

-- | Fisher metric in natural coordinates, i.e. the Hessian of 'potential':
--
-- * \(g_{00} = -1/(2\theta_1)\)
-- * \(g_{01} = \theta_0/(2\theta_1^2)\)
-- * \(g_{11} = \frac{1}{2}(1/\theta_1^2 - \theta_0^2/\theta_1^3)\)
instance Riemannian Natural Normal where
    metric p =
        let (th0, th1) = S.toPair (coordinates p)
            g00 = -1 / (2 * th1)
            g01 = th0 / (2 * square th1)
            g11 = 0.5 * (1 / square th1 - square th0 / th1 ^ (3 :: Int))
            row0 = S.doubleton g00 g01
            row1 = S.doubleton g01 g11
         in Point (row0 S.++ row1)

-- | Fisher metric in mean coordinates, i.e. the Hessian of 'dualPotential'
-- \(\phi(\eta) = -\frac{1}{2}\log(\eta_1 - \eta_0^2) - \frac{1}{2}\).
-- Writing \(D = \eta_1 - \eta_0^2\) (the variance):
--
-- * \(g_{00} = (D + 2\eta_0^2)/D^2\)
-- * \(g_{01} = -\eta_0/D^2\)
-- * \(g_{11} = 1/(2D^2)\)
--
-- This is the matrix inverse of the natural-coordinate metric at the
-- corresponding point.
instance Riemannian Mean Normal where
    metric p =
        let (eta0,eta1) = S.toPair $ coordinates p
            eta02 = square eta0
            dff = eta1 - eta02
            dff2 = square dff
            -- d/deta0 (eta0 / D) = (D + 2*eta0^2) / D^2: the numerator uses D,
            -- not D^2 (using D^2 is only correct when the variance is 1).
            d00 = (dff + 2 * eta02) / dff2
            d01 = -eta0 / dff2
            d11 = 0.5 / dff2
         in Point $ S.doubleton d00 d01 S.++ S.doubleton d01 d11

-- instance Riemannian Source Normal where
--     metric p =
--         let (_,vr) = S.toPair $ coordinates p
--          in Point $ S.doubleton (recip vr) 0 S.++ S.doubleton 0 (recip $ 2*square vr)

-- | Source-to-mean map: \((\mu, \sigma^2) \mapsto (\mu, \sigma^2 + \mu^2)\).
instance Transition Source Mean Normal where
    transition p =
        let (mu, vr) = S.toPair (coordinates p)
            secondMoment = vr + square mu
         in Point (S.doubleton mu secondMoment)

-- | Mean-to-source map: \((\eta_0, \eta_1) \mapsto (\eta_0, \eta_1 - \eta_0^2)\).
instance Transition Mean Source Normal where
    transition p = Point $ S.doubleton eta0 vr
      where
        (eta0, eta1) = S.toPair (coordinates p)
        vr = eta1 - square eta0

-- | Source-to-natural map: \((\mu, \sigma^2) \mapsto (\mu/\sigma^2, -1/(2\sigma^2))\).
instance Transition Source Natural Normal where
    transition p = Point $ S.doubleton (mu / vr) (negate (recip (2 * vr)))
      where
        (mu, vr) = S.toPair (coordinates p)

-- | Natural-to-source map: \(\mu = -\theta_0/(2\theta_1)\) and
-- \(\sigma^2 = -1/(2\theta_1)\).
instance Transition Natural Source Normal where
    transition p = Point $ S.doubleton mu vr
      where
        (th0, th1) = S.toPair (coordinates p)
        mu = -0.5 * th0 / th1
        vr = negate (recip (2 * th1))

-- | Draws a sample by converting to source coordinates \((\mu, \sigma^2)\) and
-- calling 'R.normal' with the mean and the standard deviation \(\sqrt{\sigma^2}\).
instance (Transition c Source Normal) => Generative c Normal where
    samplePoint p = Random (R.normal mu sd)
      where
        (mu, vr) = S.toPair . coordinates $ toSource p
        sd = sqrt vr

-- | Gaussian density
-- \(p(x) = (2\pi\sigma^2)^{-1/2}\exp(-(x-\mu)^2/(2\sigma^2))\),
-- evaluated at every sample point. The normalising constant is computed once.
instance AbsolutelyContinuous Source Normal where
    densities p xs = map evalAt xs
      where
        (mu, vr) = S.toPair (coordinates p)
        nrm = recip . sqrt $ vr * 2 * pi
        evalAt x = nrm * exp (negate ((x - mu) ** 2 / (2 * vr)))

-- | Densities for mean coordinates: convert to source coordinates first.
instance AbsolutelyContinuous Mean Normal where
    densities p xs = densities (toSource p) xs

-- | Log-densities for natural coordinates, via the generic exponential-family formula.
instance AbsolutelyContinuous Natural Normal where
    logDensities p xs = exponentialFamilyLogDensities p xs

-- | Maximum-likelihood fit: average the sufficient statistics (a point in mean
-- coordinates), then transition to the requested coordinate system @c@.
instance Transition Mean c Normal => MaximumLikelihood c Normal where
    mle xs = transition (averageSufficientStatistic xs)

-- | Log-likelihood of 'Normal' in natural coordinates, delegated to the generic
-- exponential-family implementations.
instance LogLikelihood Natural Normal Double where
    logLikelihood xs p = exponentialFamilyLogLikelihood xs p
    logLikelihoodDifferential xs p = exponentialFamilyLogLikelihoodDifferential xs p


-- MVNMean --

-- | A mean vector in @n@ dimensions has exactly @n@ coordinates.
instance KnownNat n => Manifold (MVNMean n) where
    type Dimension (MVNMean n) = n

-- | Sample points of 'MVNMean' are length-@n@ storable vectors of 'Double'.
instance (KnownNat n) => Statistical (MVNMean n) where
    type SamplePoint (MVNMean n) = S.Vector n Double

-- | 'MVNMean' as an exponential family: the sufficient statistic is the
-- sample vector itself, and the base measure is delegated to
-- 'mvnMeanLogBaseMeasure'.
instance KnownNat n => ExponentialFamily (MVNMean n) where
    sufficientStatistic = Point
    logBaseMeasure prxy x = mvnMeanLogBaseMeasure prxy x

type instance PotentialCoordinates (MVNMean n) = Natural

-- MVNCovariance --

-- | A covariance manifold has 'Triangular' @n@ coordinates — presumably the
-- independent (lower-triangular) entries of a symmetric @n@×@n@ matrix.
instance (KnownNat n, KnownNat (Triangular n)) => Manifold (MVNCovariance n) where
    type Dimension (MVNCovariance n) = Triangular n

-- Multivariate Normal --

-- | Multivariate normal density
-- \(p(x) = (2\pi)^{-n/2}\det(\Sigma)^{-1/2}
--   \exp(-\frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu))\).
-- The normalising constant and the (pseudo-)inverse covariance are computed
-- once and shared by all samples.
instance (KnownNat n, KnownNat (Triangular n))
  => AbsolutelyContinuous Source (MultivariateNormal n) where
      densities mvn xs = [ nrm * exp (-mahalanobis x / 2) | x <- xs ]
        where
          (mu, sgma) = splitMultivariateNormal mvn
          dim = fromIntegral (natValInt (Proxy @n)) :: Double
          nrm = (2*pi)**(-dim/2) * S.determinant sgma**(-1/2)
          prec = S.pseudoInverse sgma
          mahalanobis x =
              let dff = x - mu
               in S.dotProduct dff (S.matrixVectorMultiply prec dff)

-- | Samples are drawn in source coordinates via 'sampleMultivariateNormal'.
instance (KnownNat n, KnownNat (Triangular n), Transition c Source (MultivariateNormal n))
  => Generative c (MultivariateNormal n) where
    samplePoint p = sampleMultivariateNormal (toSource p)

-- | Source-to-natural map: the linear part is \(\Sigma^{-1}\mu\) and the
-- matrix part is \(-\frac{1}{2}\Sigma^{-1}\). The inverse is a pseudo-inverse.
instance KnownNat n => Transition Source Natural (MultivariateNormal n) where
    transition p = joinNaturalMultivariateNormal nmu nsgma
      where
        (mu, sgma) = splitMultivariateNormal p
        prec = S.pseudoInverse sgma
        nmu = S.matrixVectorMultiply prec mu
        nsgma = (-0.5) * prec

-- | Natural-to-source map: \(\Sigma = -\frac{1}{2}\Theta^{-1}\) (pseudo-inverse)
-- and \(\mu = \Sigma\theta\), where \(\theta\) and \(\Theta\) are the linear
-- and matrix parts of the natural parameters.
instance KnownNat n => Transition Natural Source (MultivariateNormal n) where
    transition p = joinMultivariateNormal mu sgma
      where
        (nmu, nsgma) = splitNaturalMultivariateNormal p
        sgma = (-0.5) * S.pseudoInverse nsgma
        mu = S.matrixVectorMultiply sgma nmu

-- | Log-likelihood of a multivariate normal in natural coordinates, delegated
-- to the generic exponential-family implementations.
instance KnownNat n => LogLikelihood Natural (MultivariateNormal n) (S.Vector n Double) where
    logLikelihood xs p = exponentialFamilyLogLikelihood xs p
    logLikelihoodDifferential xs p = exponentialFamilyLogLikelihoodDifferential xs p


instance (KnownNat n, KnownNat (Triangular n)) => ExponentialFamily (MultivariateNormal n) where
    sufficientStatistic :: SamplePoint (MultivariateNormal n) -> Mean # MultivariateNormal n
sufficientStatistic SamplePoint (MultivariateNormal n)
xs = Vector (Dimension (MultivariateNormal n)) Double
-> Mean # MultivariateNormal n
forall c x. Vector (Dimension x) Double -> Point c x
Point (Vector (Dimension (MultivariateNormal n)) Double
 -> Mean # MultivariateNormal n)
-> Vector (Dimension (MultivariateNormal n)) Double
-> Mean # MultivariateNormal n
forall a b. (a -> b) -> a -> b
$ Vector n Double
SamplePoint (MultivariateNormal n)
xs Vector n Double
-> Vector (Triangular n) Double -> Vector (n + Triangular n) Double
forall (n :: Nat) (m :: Nat) a.
Storable a =>
Vector n a -> Vector m a -> Vector (n + m) a
S.++ Matrix n n Double -> Vector (Triangular n) Double
forall (n :: Nat) x.
(Storable x, Element x, KnownNat n) =>
Matrix n n x -> Vector (Triangular n) x
S.lowerTriangular (Vector n Double -> Vector n Double -> Matrix n n Double
forall (m :: Nat) (n :: Nat) x.
(KnownNat m, KnownNat n, Numeric x) =>
Vector m x -> Vector n x -> Matrix m n x
S.outerProduct Vector n Double
SamplePoint (MultivariateNormal n)
xs Vector n Double
SamplePoint (MultivariateNormal n)
xs)
    averageSufficientStatistic :: Sample (MultivariateNormal n) -> Mean # MultivariateNormal n
averageSufficientStatistic Sample (MultivariateNormal n)
xs = Vector (Dimension (MultivariateNormal n)) Double
-> Mean # MultivariateNormal n
forall c x. Vector (Dimension x) Double -> Point c x
Point (Vector (Dimension (MultivariateNormal n)) Double
 -> Mean # MultivariateNormal n)
-> Vector (Dimension (MultivariateNormal n)) Double
-> Mean # MultivariateNormal n
forall a b. (a -> b) -> a -> b
$ [Vector n Double] -> Vector n Double
forall (f :: Type -> Type) x.
(Foldable f, Fractional x) =>
f x -> x
average [Vector n Double]
Sample (MultivariateNormal n)
xs Vector n Double
-> Vector (Triangular n) Double -> Vector (n + Triangular n) Double
forall (n :: Nat) (m :: Nat) a.
Storable a =>
Vector n a -> Vector m a -> Vector (n + m) a
S.++ Matrix n n Double -> Vector (Triangular n) Double
forall (n :: Nat) x.
(Storable x, Element x, KnownNat n) =>
Matrix n n x -> Vector (Triangular n) x
S.lowerTriangular ( [(Vector n Double, Vector n Double)] -> Matrix n n Double
forall (m :: Nat) (n :: Nat) x.
(KnownNat m, KnownNat n, Fractional x, Numeric x) =>
[(Vector m x, Vector n x)] -> Matrix m n x
S.averageOuterProduct ([(Vector n Double, Vector n Double)] -> Matrix n n Double)
-> [(Vector n Double, Vector n Double)] -> Matrix n n Double
forall a b. (a -> b) -> a -> b
$ [Vector n Double]
-> [Vector n Double] -> [(Vector n Double, Vector n Double)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Vector n Double]
Sample (MultivariateNormal n)
xs [Vector n Double]
Sample (MultivariateNormal n)
xs )
    logBaseMeasure :: Proxy (MultivariateNormal n)
-> SamplePoint (MultivariateNormal n) -> Double
logBaseMeasure = Proxy (MultivariateNormal n)
-> SamplePoint (MultivariateNormal n) -> Double
forall (n :: Nat).
KnownNat n =>
Proxy (MultivariateNormal n) -> Vector n Double -> Double
multivariateNormalLogBaseMeasure

type instance PotentialCoordinates (MultivariateNormal n) = Natural

-- | The potential of a 'MultivariateNormal' is evaluated in 'Natural'
-- coordinates. Writing the parameters returned by
-- 'splitNaturalMultivariateNormal' as \(\theta_1\) (a vector) and
-- \(\theta_2\) (a matrix), it is
--
-- \( -\frac{1}{4} \theta_1^{T} \theta_2^{+} \theta_1 - \frac{1}{2} \log \det(-2 \theta_2) \)
--
-- where \(\theta_2^{+}\) is the pseudo-inverse of \(\theta_2\).
instance (KnownNat n, KnownNat (Triangular n)) => Legendre (MultivariateNormal n) where
    potential p = quadratic - 0.5 * logDet
      where
        (nmu, nsgma) = splitNaturalMultivariateNormal p
        -- NOTE: a pseudo-inverse (not 'S.inverse') is used for this term.
        nsgmaPinv = S.pseudoInverse nsgma
        quadratic = -0.25 * S.dotProduct nmu (S.matrixVectorMultiply nsgmaPinv nmu)
        logDet = log . S.determinant . negate $ 2 * nsgma

-- | Natural to mean coordinates, routed through 'Source' coordinates.
instance (KnownNat n, KnownNat (Triangular n)) => Transition Natural Mean (MultivariateNormal n) where
    transition np = toMean (toSource np)

-- | The dual potential is
-- \( -\frac{1}{2} \left( n \log(2 \pi e) + \log \det \Sigma \right) \),
-- where \(\Sigma\) is the covariance read off the 'Source' coordinates and
-- \(n\) is the dimension. Mathematically this is minus the differential
-- entropy of a Gaussian with covariance \(\Sigma\).
instance (KnownNat n, KnownNat (Triangular n)) => DuallyFlat (MultivariateNormal n) where
    dualPotential p = negate (0.5 * lndet)
      where
        (_, sgma) = splitMultivariateNormal (toSource p)
        dim :: Double
        dim = fromIntegral (natValInt (Proxy @n))
        lndet = dim * log (2 * pi * exp 1) + log (S.determinant sgma)

-- | Mean to natural coordinates, routed through 'Source' coordinates.
instance (KnownNat n, KnownNat (Triangular n)) => Transition Mean Natural (MultivariateNormal n) where
    transition mp = toNatural (toSource mp)

-- | Source (mean, covariance) to mean coordinates. The matrix part of the
-- mean coordinates is the second moment \( \Sigma + \mu \mu^{T} \).
instance KnownNat n => Transition Source Mean (MultivariateNormal n) where
    transition p = joinMeanMultivariateNormal mu secondMoment
      where
        (mu, sgma) = splitMultivariateNormal p
        secondMoment = sgma + S.outerProduct mu mu

-- | Mean to source coordinates. The covariance is recovered by subtracting
-- \( \mu \mu^{T} \) from the second-moment matrix returned by
-- 'splitMeanMultivariateNormal'.
instance KnownNat n => Transition Mean Source (MultivariateNormal n) where
    transition p = joinMultivariateNormal mu covariance
      where
        (mu, scnds) = splitMeanMultivariateNormal p
        covariance = scnds - S.outerProduct mu mu

-- | Log-densities are computed with the generic
-- 'exponentialFamilyLogDensities'.
instance (KnownNat n, KnownNat (Triangular n)) => AbsolutelyContinuous Natural (MultivariateNormal n) where
    logDensities np xs = exponentialFamilyLogDensities np xs

-- | Maximum likelihood. The sufficient statistics are averaged, giving a
-- point in 'Mean' coordinates, which is then transitioned to the requested
-- coordinate system @c@.
instance (KnownNat n, Transition Mean c (MultivariateNormal n))
  => MaximumLikelihood c (MultivariateNormal n) where
    mle xs = transition (averageSufficientStatistic xs)

--instance KnownNat n => MaximumLikelihood Source (MultivariateNormal n) where
--    mle _ xss =
--        let n = fromIntegral $ length xss
--            mus = recip (fromIntegral n) * sum xss
--            sgma = recip (fromIntegral $ n - 1)
--                * sum (map (\xs -> let xs' = xs - mus in M.outer xs' xs') xss)
--        in  joinMultivariateNormal mus sgma

-- Linear Models

-- | Natural to source coordinates for a linear Gaussian model.
--
-- Let @P = -2 * nsg@, built from the matrix natural parameter of the
-- 'MultivariateNormal' part. Its inverse @P^{-1}@ becomes the source
-- covariance. The source mean is @P^{-1} nmu@, and the source interaction
-- matrix is @P^{-1}@ times the natural one.
instance ( KnownNat n, KnownNat k)
  => Transition Natural Source (Affine Tensor (MVNMean n) (MultivariateNormal n) (MVNMean k)) where
    transition nfa = join smvn (fromMatrix smtx)
      where
        (nmvn, nmtx) = split nfa
        (nmu, nsg) = splitNaturalMultivariateNormal nmvn
        prcs = -2 * nsg
        cvr = S.inverse prcs
        smvn = joinMultivariateNormal (S.matrixVectorMultiply cvr nmu) cvr
        smtx = S.matrixMatrixMultiply cvr (toMatrix nmtx)

-- | Source to natural coordinates for a linear Gaussian model.
--
-- With @P@ the inverse of the source covariance, the natural parameters are:
--
-- * vector part @P mu@;
-- * matrix part @-0.5 * P@;
-- * interaction matrix @P@ times the source one.
instance ( KnownNat n, KnownNat k)
  => Transition Source Natural (Affine Tensor (MVNMean n) (MultivariateNormal n) (MVNMean k)) where
    transition lmdl = join nmvn (fromMatrix nmtx)
      where
        (smvn, smtx) = split lmdl
        (smu, ssg) = splitMultivariateNormal smvn
        prcs = S.inverse ssg
        nmvn = joinNaturalMultivariateNormal (S.matrixVectorMultiply prcs smu) (-0.5 * prcs)
        nmtx = S.matrixMatrixMultiply prcs (toMatrix smtx)

-- | Natural to source coordinates for a linear model whose noise is a
-- 'Replicated' 'Normal'.
--
-- 1. The replicated normals are embedded as a 'MultivariateNormal' whose
--    matrix natural parameter is diagonal.
-- 2. The result is converted with the 'LinearModel' transition.
-- 3. Only the diagonal of the resulting covariance is kept.
instance ( KnownNat n, KnownNat k)
  => Transition Natural Source (Affine Tensor (MVNMean n) (Replicated n Normal) (MVNMean k)) where
      transition nfa = join snrms smtx
        where
          (nnrms, nmtx) = split nfa
          (nmus, nvrs) = splitReplicatedProduct nnrms
          nmvn = joinNaturalMultivariateNormal (coordinates nmus) (S.diagonalMatrix (coordinates nvrs))
          nlm :: Natural # LinearModel n k
          nlm = join nmvn nmtx
          (smvn, smtx) = split (transition nlm)
          (smu, ssg) = splitMultivariateNormal smvn
          snrms = joinReplicatedProduct (Point smu) (Point $ S.takeDiagonal ssg)

-- | Source to natural coordinates for a linear model whose noise is a
-- 'Replicated' 'Normal'.
--
-- 1. The replicated normals are embedded as a 'MultivariateNormal' with a
--    diagonal covariance.
-- 2. The result is converted with the 'LinearModel' transition.
-- 3. Only the diagonal of the resulting matrix natural parameter is kept.
--
-- The replicated normals are split and rebuilt with
-- 'splitReplicatedProduct' / 'joinReplicatedProduct'. This mirrors the
-- Natural-to-Source instance for the same manifold and avoids a round trip
-- through a row/column matrix.
instance ( KnownNat n, KnownNat k)
  => Transition Source Natural (Affine Tensor (MVNMean n) (Replicated n Normal) (MVNMean k)) where
      transition sfa =
          let (snrms,smtx) = split sfa
              -- Per-component location and shape coordinates of the normals.
              (smus,svrs) = splitReplicatedProduct snrms
              smvn = joinMultivariateNormal (coordinates smus) $ S.diagonalMatrix (coordinates svrs)
              slm :: Source # LinearModel n k
              slm = join smvn smtx
              (nmvn,nmtx) = split $ transition slm
              (nmu,nsg) = splitNaturalMultivariateNormal nmvn
              -- Off-diagonal natural terms are discarded: the replicated
              -- normals only carry one shape parameter per component.
              nnrms = joinReplicatedProduct (Point nmu) (Point $ S.takeDiagonal nsg)
           in join nnrms nmtx

-- | Univariate linear model, natural to source coordinates.
--
-- The point is reinterpreted as a @LinearModel 1 1@ with 'breakPoint',
-- which requires the two manifolds to have equal dimension. It is then
-- converted with that instance's transition and reinterpreted back.
instance Transition Natural Source (Affine Tensor NormalMean Normal NormalMean) where
      transition nfa = breakPoint sfa1
        where
          nfa1 :: Natural # LinearModel 1 1
          nfa1 = breakPoint nfa
          sfa1 :: Source # LinearModel 1 1
          sfa1 = transition nfa1

-- | Univariate linear model, source to natural coordinates.
--
-- The point is reinterpreted as a @LinearModel 1 1@ with 'breakPoint',
-- which requires the two manifolds to have equal dimension. It is then
-- converted with that instance's transition and reinterpreted back.
instance Transition Source Natural (Affine Tensor NormalMean Normal NormalMean) where
      transition sfa = breakPoint nfa1
        where
          sfa1 :: Source # LinearModel 1 1
          sfa1 = breakPoint sfa
          nfa1 :: Natural # LinearModel 1 1
          nfa1 = transition sfa1



--instance ( KnownNat n, KnownNat k)
--  => Transition Natural Source (Affine Tensor (MVNMean n) (Replicated n Normal) (MVNMean k)) where
--    transition nfa =
--        let (nnrms,nmtx) = split nfa
--            (nmu,nsg) = S.toPair . S.toColumns . S.fromRows . S.map coordinates
--                $ splitReplicated nnrms
--            invsg = -2 * nsg
--            ssg = recip invsg
--            smu = nmu / invsg
--            snrms = joinReplicated $ S.zipWith (curry fromTuple) smu ssg
--            smtx = S.matrixMatrixMultiply (S.diagonalMatrix ssg) $ toMatrix nmtx
--         in join snrms $ fromMatrix smtx

--instance ( KnownNat n, KnownNat k)
--  => Transition Source Natural (Affine Tensor (MVNMean n) (Replicated n Normal) (MVNMean k)) where
--    transition sfa =
--        let (snrms,smtx) = split sfa
--            (smu,ssg) = S.toPair . S.toColumns . S.fromRows . S.map coordinates
--                $ splitReplicated snrms
--            invsg = recip ssg
--            nmu = invsg * smu
--            nsg = -0.5 * invsg
--            nmtx = S.matrixMatrixMultiply (S.diagonalMatrix invsg) $ toMatrix smtx
--            nnrms = joinReplicated $ S.zipWith (curry fromTuple) nmu nsg
--         in join nnrms $ fromMatrix nmtx