{-# OPTIONS_GHC -fplugin=GHC.TypeLits.KnownNat.Solver -fplugin=GHC.TypeLits.Normalise -fconstraint-solver-iterations=10 #-}
{-# LANGUAGE
    TypeApplications,
    UndecidableInstances,
    NoStarIsType,
    GeneralizedNewtypeDeriving,
    StandaloneDeriving,
    ScopedTypeVariables,
    ExplicitNamespaces,
    TypeOperators,
    KindSignatures,
    DataKinds,
    RankNTypes,
    TypeFamilies,
    FlexibleContexts,
    MultiParamTypeClasses,
    ConstraintKinds,
    FlexibleInstances
#-}
-- | An Exponential Family 'Harmonium' is a product exponential family with a
-- particular bilinear structure (<https://papers.nips.cc/paper/2672-exponential-family-harmoniums-with-an-application-to-information-retrieval Welling, et al., 2005>).
-- A 'Mixture' model is a special case of harmonium.
module Goal.Graphical.Models.Harmonium
    (
    -- * Harmoniums
      AffineHarmonium (AffineHarmonium)
    , Harmonium
    -- ** Construction
    , splitHarmonium
    , joinHarmonium
    -- ** Manipulation
    , transposeHarmonium
    -- ** Evaluation
    , expectationStep
    -- ** Sampling
    , initialPass
    , gibbsPass
    -- ** Mixture Models
    , Mixture
    , AffineMixture
    , joinNaturalMixture
    , splitNaturalMixture
    , joinMeanMixture
    , splitMeanMixture
    , joinSourceMixture
    , splitSourceMixture
    -- ** Linear Gaussian Harmoniums
    , LinearGaussianHarmonium
    -- ** Conjugated Harmoniums
    , ConjugatedLikelihood (conjugationParameters)
    , joinConjugatedHarmonium
    , splitConjugatedHarmonium
    ) where

--- Imports ---


import Goal.Core
import Goal.Geometry
import Goal.Probability

import Goal.Graphical.Models

import qualified Goal.Core.Vector.Storable as S


--- Types ---


-- | A 2-layer harmonium with visible layer @z@ and latent layer @w@.
--
-- The parameters are stored as a pair of:
--
-- * the likelihood @'Affine' f y z x@, made of the visible biases @z@ and the
--   interaction parameters @f y x@;
-- * the latent biases @w@.
--
-- See 'joinHarmonium' and 'splitHarmonium'. The interaction acts between @y@
-- and @x@, which are related to @z@ and @w@ respectively (callers typically
-- require @'Translation' z y@ and @'Translation' w x@).
newtype AffineHarmonium f y x z w = AffineHarmonium (Affine f y z x, w)

deriving instance (Manifold z, Manifold (f y x), Manifold w)
  => Manifold (AffineHarmonium f y x z w)
deriving instance (Manifold z, Manifold (f y x), Manifold w)
  => Product (AffineHarmonium f y x z w)

-- | An 'AffineHarmonium' whose interaction acts on the entire visible and
-- latent layers.
type Harmonium f z w = AffineHarmonium f z w z w

-- | Observations of a harmonium are sample points of its visible layer @z@.
type instance Observation (AffineHarmonium f y x z w) = SamplePoint z

-- | A 'Mixture' model is simply a 'Harmonium' where the latent variable is
-- 'Categorical'. A @Mixture z k@ has @k+1@ components (cf. 'joinSourceMixture').
type Mixture z k = Harmonium Tensor z (Categorical k)

-- | A 'Mixture' where only a subset of the component parameters are mixed:
-- the interaction acts on @y@ rather than on the full component family @z@.
type AffineMixture y z k =
    AffineHarmonium Tensor y (Categorical k) z (Categorical k)

-- | A harmonium with visible layer @'MultivariateNormal' n@ and latent layer
-- @'MultivariateNormal' k@. Its 'Tensor' interaction couples only the
-- 'MVNMean' parts of the two layers.
type LinearGaussianHarmonium n k =
    AffineHarmonium Tensor (MVNMean n) (MVNMean k) (MultivariateNormal n) (MultivariateNormal k)


--- Classes ---


-- | Likelihoods (the 'Affine' part of a harmonium, see 'AffineHarmonium')
-- that are conjugate to the latent family @w@.
--
-- 'conjugationParameters' returns a scalar together with a point in
-- @'Natural' # w@. 'joinNaturalMixture' subtracts that point from the latent
-- biases, and 'splitNaturalMixture' adds it back.
--
-- NOTE(review): presumably the pair @(rho0, rho)@ such that the likelihood's
-- log-partition function at latent point @x@ equals
-- @rho0 + rho . sufficientStatistic x@; confirm against the instances.
class ( ExponentialFamily z, ExponentialFamily w, Map Natural f y x
      , Translation z y , Translation w x
      , SamplePoint y ~ SamplePoint z, SamplePoint x ~ SamplePoint w )
  => ConjugatedLikelihood f y x z w where
    conjugationParameters
        :: Natural # Affine f y z x -- ^ Likelihood (visible biases and interactions)
        -> (Double, Natural # w) -- ^ Conjugation parameters


--- Functions ---


-- Construction --

-- | Creates a 'Harmonium' from component parameters.
--
-- The visible biases and interaction parameters are first joined into the
-- likelihood @'Affine' f y z x@. That likelihood is then paired with the
-- hidden-layer biases. 'splitHarmonium' reverses this.
joinHarmonium
    :: (Manifold w, Manifold z, Manifold (f y x))
    => c # z -- ^ Visible layer biases
    -> c # f y x -- ^ Interaction parameters
    -> c # w -- ^ Hidden layer biases
    -> c # AffineHarmonium f y x z w -- ^ Harmonium
joinHarmonium nz nyx nw =
    -- Inner 'join': likelihood (visible biases, interactions).
    -- Outer 'join': (likelihood, hidden biases).
    join (join nz nyx) nw

-- | Splits a 'Harmonium' into component parameters; this reverses
-- 'joinHarmonium'.
splitHarmonium
    :: (Manifold z, Manifold (f y x), Manifold w)
    => c # AffineHarmonium f y x z w -- ^ Harmonium
    -> (c # z, c # f y x, c # w) -- ^ Visible biases, interaction parameters, hidden biases
splitHarmonium hrm =
    let (fzx, nw) = split hrm -- likelihood and hidden biases
        (nz, nyx) = split fzx -- visible biases and interactions
     in (nz, nyx, nw)

-- | Build a mixture model in source coordinates.
--
-- The first of the @k+1@ components becomes the visible-layer parameters. The
-- remaining @k@ components become the columns of the interaction 'Tensor'.
-- The weights are stored directly as the latent-layer parameters.
joinSourceMixture
    :: (KnownNat k, Manifold z)
    => S.Vector (k+1) (Source # z) -- ^ Mixture components
    -> Source # Categorical k -- ^ Weights
    -> Source # Mixture z k
joinSourceMixture szs sx =
    let (sz, szs') = S.splitAt szs
        aff = join (S.head sz) (fromColumns szs')
     in join aff sx

-- | Split a mixture model in source coordinates into its @k+1@ components and
-- its weights; this reverses 'joinSourceMixture'.
splitSourceMixture
    :: (KnownNat k, Manifold z)
    => Source # Mixture z k
    -> (S.Vector (k+1) (Source # z), Source # Categorical k)
splitSourceMixture mxmdl =
    let (aff, sx) = split mxmdl
        (sz0, szs0') = split aff
        -- The first component comes from the visible parameters; the rest
        -- are the columns of the interaction tensor.
     in (S.cons sz0 $ toColumns szs0', sx)

-- | Build a mixture model in mean coordinates.
--
-- Each component is scaled ('.>') by its categorical weight. The visible-layer
-- mean is the sum of all weighted components. The interaction columns are the
-- weighted components @1..k@ (every component except the first). The weights
-- are stored as the latent-layer means.
joinMeanMixture
    :: (KnownNat k, Manifold z)
    => S.Vector (k+1) (Mean # z) -- ^ Mixture components
    -> Mean # Categorical k -- ^ Weights
    -> Mean # Mixture z k
joinMeanMixture mzs mx =
    let wghts = categoricalWeights mx
        wmzs = S.zipWith (.>) wghts mzs
        mz = S.foldr1 (+) wmzs
        twmzs = S.tail wmzs
        mzx = transpose . fromRows $ twmzs
     in joinHarmonium mz mzx mx

-- | Split a mixture model in mean coordinates into its @k+1@ components and
-- its weights; this reverses 'joinMeanMixture'.
--
-- The weighted first component is recovered by subtracting the other weighted
-- components (the interaction rows of the transposed tensor) from the
-- visible-layer mean. Every weighted component is then combined with its
-- weight via '/>' (presumably division).
--
-- NOTE(review): a component with zero weight presumably yields non-finite
-- coordinates; confirm whether callers guard against this.
splitMeanMixture
    :: ( KnownNat k, DuallyFlatExponentialFamily z )
    => Mean # Mixture z k
    -> (S.Vector (k+1) (Mean # z), Mean # Categorical k)
splitMeanMixture hrm =
    let (mz, mzx, mx) = splitHarmonium hrm
        twmzs = toRows $ transpose mzx
        wmzs = S.cons (mz - S.foldr (+) 0 twmzs) twmzs
        wghts = categoricalWeights mx
        mzs = S.zipWith (/>) wghts wmzs
     in (mzs, mx)

-- | A convenience function for building a categorical harmonium/mixture model.
--
-- The first component provides the visible biases. The interaction columns are
-- the differences between each remaining component and the first one. The
-- latent biases are the given weights minus the conjugation parameters of the
-- resulting likelihood; 'splitNaturalMixture' adds them back.
joinNaturalMixture
    :: forall k z . ( KnownNat k, LegendreExponentialFamily z )
    => S.Vector (k+1) (Natural # z) -- ^ Mixture components
    -> Natural # Categorical k -- ^ Weights
    -> Natural # Mixture z k -- ^ Mixture Model
joinNaturalMixture nzs0 nx0 =
    let nz0 :: S.Vector 1 (Natural # z)
        (nz0, nzs0') = S.splitAt nzs0
        nz = S.head nz0
        -- Remaining components, expressed relative to the first.
        nzs = S.map (subtract nz) nzs0'
        nzx = fromMatrix . S.fromColumns $ S.map coordinates nzs
        affzx :: Natural # Affine Tensor z z (Categorical k)
        affzx = join nz nzx
        rprms = snd $ conjugationParameters affzx
        -- Remove the conjugation correction from the weights.
        nx = nx0 - rprms
     in joinHarmonium nz nzx nx

-- | A convenience function for deconstructing a categorical harmonium/mixture
-- model; this reverses 'joinNaturalMixture'.
--
-- The first component is the visible biases. Each further component is the
-- visible biases plus one column of the interaction matrix. The weights are
-- the latent biases plus the conjugation parameters of the likelihood.
splitNaturalMixture
    :: forall k z . ( KnownNat k, LegendreExponentialFamily z )
    => Natural # Mixture z k -- ^ Categorical harmonium
    -> (S.Vector (k+1) (Natural # z), Natural # Categorical k) -- ^ (components, weights)
splitNaturalMixture hrm =
    let (nz, nzx, nx) = splitHarmonium hrm
        affzx :: Natural # Affine Tensor z z (Categorical k)
        affzx = join nz nzx
        rprms = snd $ conjugationParameters affzx
        -- Restore the conjugation correction removed when joining.
        nx0 = nx + rprms
        nzs = S.map Point . S.toColumns $ toMatrix nzx
        -- Interaction columns are offsets from the first component.
        nzs0' = S.map (+ nz) nzs
     in (S.cons nz nzs0', nx0)


-- Manipulation --

-- | Swap the biases and 'transpose' the interaction parameters of the given
-- 'Harmonium', so that the visible and latent layers exchange roles.
transposeHarmonium
    :: (Bilinear f y x, Manifold z, Manifold w)
    => c # AffineHarmonium f y x z w
    -> c # AffineHarmonium f x y w z
transposeHarmonium hrm =
    let (nz, nyx, nw) = splitHarmonium hrm
     in joinHarmonium nw (transpose nyx) nz

-- Evaluation --

-- | Computes the joint expectations of a harmonium based on a sample from the
-- observable layer.
--
-- The three parts of the result are computed as follows:
--
-- * Visible part: the average sufficient statistic of the sample.
-- * Latent part: the average of the mean parameters of the latent layer. For
--   each observation, the latent natural parameters come from the affine part
--   of the transposed harmonium.
-- * Interaction part: the anchored visible and latent statistics combined with
--   '>$<'.
--
-- NOTE(review): an empty sample presumably makes 'average' non-finite.
expectationStep
    :: ( ExponentialFamily z, Map Natural f x y, Bilinear f y x
       , Translation z y, Translation w x, LegendreExponentialFamily w )
    => Sample z -- ^ Model Samples
    -> Natural # AffineHarmonium f y x z w -- ^ Harmonium
    -> Mean # AffineHarmonium f y x z w -- ^ Harmonium expected sufficient statistics
expectationStep zs hrm =
    let mzs = sufficientStatistic <$> zs
        mys = anchor <$> mzs
        -- Affine map from visible statistics to latent natural parameters.
        pstr = fst . split $ transposeHarmonium hrm
        mws = transition <$> (pstr >$> mys)
        mxs = anchor <$> mws
        myx = mys >$< mxs
     in joinHarmonium (average mzs) myx $ average mws

-- Sampling --

-- | Initialize a Gibbs chain from a set of observations.
--
-- Each observation is paired with a latent sample drawn from the latent
-- distribution given that observation. That distribution is obtained from the
-- affine part of the transposed harmonium.
initialPass
    :: forall f x y z w
    . ( ExponentialFamily z, Map Natural f x y, Manifold w
      , SamplePoint y ~ SamplePoint z, Translation w x, Generative Natural w
      , ExponentialFamily y, Bilinear f y x, LegendreExponentialFamily w )
    => Natural # AffineHarmonium f y x z w -- ^ Harmonium
    -> Sample z -- ^ Model Samples
    -> Random (Sample (z, w))
initialPass hrm zs = do
    let pstr = fst . split $ transposeHarmonium hrm
    ws <- mapM samplePoint $ pstr >$>* zs
    return $ zip zs ws

-- | Update a 'Sample' with one sweep of Gibbs sampling.
--
-- New visible points are drawn from the likelihood given the current latent
-- points. New latent points are then drawn given those new visible points.
-- The incoming visible points are discarded.
gibbsPass
    :: ( ExponentialFamily z, Map Natural f x y, Translation z y
       , Translation w x, SamplePoint z ~ SamplePoint y, Generative Natural w
       , ExponentialFamily y, SamplePoint x ~ SamplePoint w, Bilinear f y x
       , Map Natural f y x, ExponentialFamily x, Generative Natural z )
    => Natural # AffineHarmonium f y x z w -- ^ Harmonium
    -> Sample (z, w)
    -> Random (Sample (z, w))
gibbsPass hrm zws = do
    let ws = snd <$> zws
        -- Latent-given-visible and visible-given-latent affine maps.
        pstr = fst . split $ transposeHarmonium hrm
        lkl = fst $ split hrm
    zs' <- mapM samplePoint $ lkl >$>* ws
    ws' <- mapM samplePoint $ pstr >$>* zs'
    return $ zip zs' ws'

-- Conjugation --

-- | The conjugation parameters of a conjugated 'AffineHarmonium'. These are
-- the 'conjugationParameters' of the harmonium's likelihood (its first
-- component); the latent biases do not enter the result.
harmoniumConjugationParameters
    :: ConjugatedLikelihood f y x z w
    => Natural # AffineHarmonium f y x z w -- ^ Conjugated harmonium
    -> (Double, Natural # w) -- ^ Conjugation parameters
harmoniumConjugationParameters hrm =
    let (lkl, _) = split hrm
     in conjugationParameters lkl

-- | Split a conjugated harmonium into its likelihood and the natural
-- parameters of the prior over the latent variables. The prior parameters are
-- the harmonium's latent biases shifted by the vector conjugation parameters
-- of the likelihood. 'joinConjugatedHarmonium' is the inverse operation.
splitConjugatedHarmonium
    :: ConjugatedLikelihood f y x z w
    => Natural # AffineHarmonium f y x z w -- ^ Conjugated harmonium
    -> (Natural # Affine f y z x, Natural # w) -- ^ Likelihood and latent prior
splitConjugatedHarmonium hrm =
    let (lkl, nw) = split hrm
        (_, rprms) = conjugationParameters lkl
     in (lkl, nw + rprms)

-- | Build a conjugated harmonium from a likelihood and the natural parameters
-- of the prior over the latent variables. This is the inverse of
-- 'splitConjugatedHarmonium': the vector conjugation parameters of the
-- likelihood are subtracted from the prior to recover the latent biases.
joinConjugatedHarmonium
    :: ConjugatedLikelihood f y x z w
    => Natural # Affine f y z x -- ^ Likelihood
    -> Natural # w -- ^ Latent prior
    -> Natural # AffineHarmonium f y x z w -- ^ Conjugated harmonium
joinConjugatedHarmonium lkl nw =
    let (_, rprms) = conjugationParameters lkl
     in join lkl $ nw - rprms

-- | Draw @n@ samples from the joint distribution of a conjugated harmonium.
-- Latent samples are drawn from the latent prior (latent biases plus the
-- vector conjugation parameters). Each observation is then drawn from the
-- likelihood evaluated at the sufficient statistics of its latent sample.
-- The result pairs each observation with the latent sample that produced it.
sampleConjugated
    :: forall f y x z w
     . ( ConjugatedLikelihood f y x z w, Generative Natural w
       , Generative Natural z, Map Natural f y x )
    => Int -- ^ Number of samples
    -> Natural # AffineHarmonium f y x z w -- ^ Conjugated harmonium
    -> Random (Sample (z,w)) -- ^ Joint (observable, latent) samples
sampleConjugated n hrm = do
    let (lkl, nw) = split hrm
        -- natural parameters of the latent prior
        nw' = nw + snd (conjugationParameters lkl)
    ws <- sample n nw'
    let mws :: [Mean # w]
        mws = sufficientStatistic <$> ws
    zs <- mapM samplePoint $ lkl >$+> mws
    return $ zip zs ws

-- | The log-partition function of a conjugated harmonium, computed in closed
-- form. It is the potential of the latent prior (latent biases plus vector
-- conjugation parameters) plus the scalar conjugation parameter.
conjugatedPotential
    :: ( LegendreExponentialFamily w, ConjugatedLikelihood f y x z w )
    => Natural # AffineHarmonium f y x z w -- ^ Conjugated harmonium
    -> Double -- ^ Log-partition function
conjugatedPotential hrm =
    let (lkl, nw) = split hrm
        (rho0, rprms) = conjugationParameters lkl
     in potential (nw + rprms) + rho0


--- Internal ---


-- Conjugation --

-- | Unnormalized log-densities of a sample of observations under the
-- observable marginal of a harmonium. For each observation the result is the
-- sum of three terms:
--
-- * the observable biases dotted with the sufficient statistics,
-- * the potential of the latent biases shifted by the interactions with the
--   observation (obtained from the transposed harmonium), and
-- * the log-base measure of the observation.
--
-- The harmonium's log-partition function is not subtracted.
unnormalizedHarmoniumObservableLogDensity
    :: forall f y x z w
    . ( ExponentialFamily z, ExponentialFamily y
      , LegendreExponentialFamily w, Translation w x, Translation z y
      , Map Natural f x y, Bilinear f y x )
    => Natural # AffineHarmonium f y x z w
    -> Sample z
    -> [Double]
unnormalizedHarmoniumObservableLogDensity hrm zs =
    let (pstr, nz) = split $ transposeHarmonium hrm
        mzs :: [Mean # z]
        mzs = sufficientStatistic <$> zs
        biasTerms = dotMap nz mzs
        latentTerms = potential <$> (pstr >$+> mzs)
        baseTerms = logBaseMeasure (Proxy @z) <$> zs
        combine bias lat base = bias + lat + base
     in zipWith3 combine biasTerms latentTerms baseTerms

-- | Log-densities of a sample of observations under the observable marginal
-- of a conjugated harmonium, given its conjugation parameters. Each value is
-- the unnormalized log-density minus the harmonium's log-partition function.
-- That function is the potential of the latent prior (latent biases plus the
-- vector conjugation parameters) plus the scalar conjugation parameter.
logConjugatedDensities
    :: forall f x y z w
    . ( Bilinear f y x, Translation z y
      , LegendreExponentialFamily z, ExponentialFamily y
      , LegendreExponentialFamily w, Translation w x, Map Natural f x y)
    => (Double, Natural # w) -- ^ Conjugation parameters
    -> Natural # AffineHarmonium f y x z w -- ^ Conjugated harmonium
    -> Sample z -- ^ Observations
    -> [Double] -- ^ Log-densities
logConjugatedDensities (rho0, rprms) hrm zs =
    let udns = unnormalizedHarmoniumObservableLogDensity hrm zs
        nw = snd $ split hrm
        logPartition = potential (nw + rprms) + rho0
     in subtract logPartition <$> udns

-- Mixtures --

-- | Conjugation parameters of a mixture likelihood, i.e. an affine map from a
-- categorical latent variable to an observable family @z@.
--
-- The scalar parameter is the potential of the observable biases. The @i@th
-- categorical parameter is the potential of those biases translated by the
-- @i@th column of the interaction matrix, minus the scalar parameter.
mixtureLikelihoodConjugationParameters
    :: (KnownNat k, LegendreExponentialFamily z, Translation z y)
    => Natural # Affine Tensor y z (Categorical k) -- ^ Categorical likelihood
    -> (Double, Natural # Categorical k) -- ^ Conjugation parameters
mixtureLikelihoodConjugationParameters aff =
    let (nz, nyx) = split aff
        rho0 = potential nz
        columnShift nyxi = potential (nz >+> nyxi) - rho0
     in (rho0, Point . S.map columnShift $ toColumns nyx)

-- | Convert a natural-coordinate mixture with an affine likelihood (whose
-- interaction columns live on @y@) into a 'Mixture' over @z@. Each
-- interaction column is lifted into @z@ by translating the zero point of @z@
-- by it. The observable biases and categorical weights are kept unchanged.
affineMixtureToMixture
    :: (KnownNat k, Manifold z, Manifold y, Translation z y)
    => Natural # AffineMixture y z k
    -> Natural # Mixture z k
affineMixtureToMixture lmxmdl = join (join nls nlsk) nk
  where
    (flsk, nk) = split lmxmdl
    (nls, nlk) = split flsk
    nlsk = fromColumns . S.map (0 >+>) $ toColumns nlk

-- | Convert a mean-coordinate 'Mixture' over @z@ into a mixture with an
-- affine likelihood whose interaction columns live on @y@. Each interaction
-- column is mapped with 'anchor'. The observable part and categorical weights
-- are kept unchanged.
mixtureToAffineMixture
    :: (KnownNat k, Manifold y, Manifold z, Translation z y)
    => Mean # Mixture z k
    -> Mean # AffineMixture y z k
mixtureToAffineMixture mxmdl = join (join mls mlk) mk
  where
    (flsk, mk) = split mxmdl
    (mls, mlsk) = split flsk
    mlk = fromColumns . S.map anchor $ toColumns mlsk


-- Linear Gaussian Harmoniums --

-- | Conjugation parameters of a linear Gaussian likelihood: an affine map from
-- a @k@-dimensional Gaussian mean to an @n@-dimensional multivariate normal.
--
-- Write the natural parameters of the observable biases as @(theta1, theta2)@,
-- the interaction matrix as @theta3@, and the pseudo-inverse of @theta2@ as
-- @pinv(theta2)@. The conjugation parameters are then
--
-- > rho0 = -1/4 theta1' pinv(theta2) theta1 - 1/2 log det(-2 theta2)
-- > rho1 = -1/2 theta3' pinv(theta2) theta1
-- > rho2 = -1/4 theta3' pinv(theta2) theta3
--
-- @(rho1, rho2)@ are returned as natural parameters of a @k@-dimensional
-- multivariate normal.
linearGaussianHarmoniumConjugationParameters
    :: (KnownNat n, KnownNat k)
    => Natural # Affine Tensor (MVNMean n) (MultivariateNormal n) (MVNMean k)
    -> (Double, Natural # MultivariateNormal k) -- ^ Conjugation parameters
linearGaussianHarmoniumConjugationParameters aff =
    (rho0, joinNaturalMultivariateNormal rho1 rho2)
  where
    (thts, tht30) = split aff
    (tht1, tht2) = splitNaturalMultivariateNormal thts
    tht3 = toMatrix tht30
    ttht3 = S.transpose tht3
    itht2 = S.pseudoInverse tht2
    -- pinv(theta2) theta1, shared by rho0 and rho1
    itht2tht1 = itht2 `S.matrixVectorMultiply` tht1
    rho0 = -0.25 * (tht1 `S.dotProduct` itht2tht1)
        - 0.5 * log (S.determinant (negate (2 * tht2)))
    rho1 = -0.5 * (ttht3 `S.matrixVectorMultiply` itht2tht1)
    rho2 = -0.25 * (ttht3 `S.matrixMatrixMultiply` (itht2 `S.matrixMatrixMultiply` tht3))

-- | Reinterpret a harmonium between two univariate normals as a
-- 'LinearGaussianHarmonium' of dimensions 1 and 1. Each of the three
-- components (biases, interactions, biases) is reinterpreted
-- coordinate-for-coordinate with 'breakPoint'.
univariateToLinearGaussianHarmonium
    :: c # AffineHarmonium Tensor NormalMean NormalMean Normal Normal
    -> c # LinearGaussianHarmonium 1 1
univariateToLinearGaussianHarmonium hrm =
    joinHarmonium (breakPoint obs) (breakPoint intr) (breakPoint lat)
  where
    (obs, intr, lat) = splitHarmonium hrm

-- | Reinterpret a 'LinearGaussianHarmonium' of dimensions 1 and 1 as a
-- harmonium between two univariate normals. This is the inverse of
-- 'univariateToLinearGaussianHarmonium': each component is reinterpreted
-- coordinate-for-coordinate with 'breakPoint'.
linearGaussianHarmoniumToUnivariate
    :: c # LinearGaussianHarmonium 1 1
    -> c # AffineHarmonium Tensor NormalMean NormalMean Normal Normal
linearGaussianHarmoniumToUnivariate hrm =
    joinHarmonium (breakPoint obs) (breakPoint intr) (breakPoint lat)
  where
    (obs, intr, lat) = splitHarmonium hrm

-- | Reinterpret an affine map from a univariate normal mean to a univariate
-- normal as the corresponding one-dimensional multivariate linear model. The
-- bias and interaction components are converted with 'breakPoint'.
univariateToLinearModel
    :: Natural # Affine Tensor NormalMean Normal NormalMean
    -> Natural # Affine Tensor (MVNMean 1) (MultivariateNormal 1) (MVNMean 1)
univariateToLinearModel aff = join (breakPoint bias) (breakPoint intr)
  where
    (bias, intr) = split aff

-- | Convert a linear Gaussian harmonium in 'Natural' coordinates into a
-- single 'MultivariateNormal' over the concatenated @n+k@ variables.
--
-- The two layer distributions become the diagonal blocks of the joint
-- parameters. The interaction matrix, divided by 2, becomes the off-diagonal
-- blocks (see 'fromLinearGaussianHarmonium0').
naturalLinearGaussianHarmoniumToJoint
    :: (KnownNat n, KnownNat k)
    => Natural # LinearGaussianHarmonium n k
    -> Natural # MultivariateNormal (n+k)
naturalLinearGaussianHarmoniumToJoint hrm =
    let (z,zx,x) = splitHarmonium hrm
        -- NOTE(review): presumably halved because the symmetric joint matrix
        -- contributes the z/x cross term through both off-diagonal blocks;
        -- confirm against the joinNaturalMultivariateNormal convention.
        zxmtx = toMatrix zx/2
        mvnz = splitNaturalMultivariateNormal z
        mvnx = splitNaturalMultivariateNormal x
        -- In natural coordinates these are natural parameters, not a
        -- literal mean and covariance; the names follow the mean version.
        (mu,cvr) = fromLinearGaussianHarmonium0 mvnz zxmtx mvnx
     in joinNaturalMultivariateNormal mu cvr

-- | Split a joint 'MultivariateNormal' over @n+k@ variables (in 'Natural'
-- coordinates) into a linear Gaussian harmonium with an @n@-dimensional and
-- a @k@-dimensional layer.
--
-- The diagonal blocks become the two layer distributions. The upper-right
-- block, multiplied by 2, becomes the interaction tensor.
naturalJointToLinearGaussianHarmonium
    :: (KnownNat n, KnownNat k)
    => Natural # MultivariateNormal (n+k)
    -> Natural # LinearGaussianHarmonium n k
naturalJointToLinearGaussianHarmonium mvn =
    let (mu,cvr) = splitNaturalMultivariateNormal mvn
        ((muz,cvrz),zxmtx,(mux,cvrx)) = toLinearGaussianHarmonium0 mu cvr
        -- Undo the halving applied when the harmonium interaction is
        -- written into the joint matrix's off-diagonal blocks.
        zx = 2*fromMatrix zxmtx
        z = joinNaturalMultivariateNormal muz cvrz
        x = joinNaturalMultivariateNormal mux cvrx
     in joinHarmonium z zx x

-- | Convert a linear Gaussian harmonium in 'Mean' coordinates into a single
-- 'MultivariateNormal' over the concatenated @n+k@ variables.
--
-- Unlike the natural-coordinate version, the interaction matrix is used
-- unscaled as the off-diagonal block.
naturalOrMeanNote :: ()
naturalOrMeanNote = ()

-- | Split a joint 'MultivariateNormal' over @n+k@ variables (in 'Mean'
-- coordinates) into a linear Gaussian harmonium with an @n@-dimensional and
-- a @k@-dimensional layer.
--
-- The diagonal blocks become the two layer distributions. The upper-right
-- block becomes the interaction tensor unscaled.
meanJointToLinearGaussianHarmonium
    :: (KnownNat n, KnownNat k)
    => Mean # MultivariateNormal (n+k)
    -> Mean # LinearGaussianHarmonium n k
meanJointToLinearGaussianHarmonium mvn =
    let (mu,cvr) = splitMeanMultivariateNormal mvn
        ((muz,cvrz),zxmtx,(mux,cvrx)) = toLinearGaussianHarmonium0 mu cvr
        zx = fromMatrix zxmtx
        z = joinMeanMultivariateNormal muz cvrz
        x = joinMeanMultivariateNormal mux cvrx
     in joinHarmonium z zx x

-- | Assemble a joint vector and symmetric block matrix from the parameters
-- of two blocks and their cross block:
--
-- > mu  = [ muz ; mux ]
-- > cvr = [ cvrz             zxmtx ]
-- >       [ transpose zxmtx  cvrx  ]
--
-- This is the inverse of 'toLinearGaussianHarmonium0' for symmetric input.
fromLinearGaussianHarmonium0
    :: (KnownNat n, KnownNat k)
    => (S.Vector n Double, S.Matrix n n Double)
    -> S.Matrix n k Double
    -> (S.Vector k Double, S.Matrix k k Double)
    -> (S.Vector (n+k) Double, S.Matrix (n+k) (n+k) Double)
fromLinearGaussianHarmonium0 (muz,cvrz) zxmtx (mux,cvrx) =
    let mu = muz S.++ mux
        -- First n rows: [cvrz | zxmtx]
        top = S.horizontalConcat cvrz zxmtx
        -- Last k rows: [zxmtx^T | cvrx], so the result is symmetric.
        btm = S.horizontalConcat (S.transpose zxmtx) cvrx
     in (mu, S.verticalConcat top btm)

-- | Split a joint vector and @(n+k)@-square matrix into
-- @((muz,cvrz), zxmtx, (mux,cvrx))@:
--
-- * @muz@ / @mux@: the first @n@ and last @k@ entries of the vector;
-- * @cvrz@: the upper-left @n x n@ block;
-- * @zxmtx@: the upper-right @n x k@ block;
-- * @cvrx@: the lower-right @k x k@ block.
--
-- The lower-left block is never read, so the input matrix is implicitly
-- assumed to be symmetric.
toLinearGaussianHarmonium0
    :: (KnownNat n, KnownNat k)
    => S.Vector (n+k) Double
    -> S.Matrix (n+k) (n+k) Double
    -> ( (S.Vector n Double, S.Matrix n n Double)
       , S.Matrix n k Double
       , (S.Vector k Double, S.Matrix k k Double) )
toLinearGaussianHarmonium0 mu cvr =
    let (muz,mux) = S.splitAt mu
        -- Split the rows into the first n (tops) and the last k (btms).
        (tops,btms) = S.splitAt $ S.toRows cvr
        -- Split the columns of the top rows into the first n and last k.
        (cvrzs,zxmtxs) = S.splitAt . S.toColumns $ S.fromRows tops
        cvrz = S.fromColumns cvrzs
        zxmtx = S.fromColumns zxmtxs
        -- Keep only the last k columns of the bottom rows.
        cvrx = S.fromColumns . S.drop . S.toColumns $ S.fromRows btms
     in ((muz,cvrz),zxmtx,(mux,cvrx))

-- | Log base measure of a harmonium: the sum of the log base measures of its
-- two component families, each evaluated on its half of the sample point.
-- The 'Proxy' argument only fixes the component types.
harmoniumLogBaseMeasure
    :: forall f y x z w . (ExponentialFamily z, ExponentialFamily w)
    => Proxy (AffineHarmonium f y x z w)
    -> SamplePoint (z,w)
    -> Double
harmoniumLogBaseMeasure _ (z,w) =
    -- Prefix '@' (no space) is the type-application form accepted by every
    -- GHC that supports TypeApplications; GHC >= 9.0 rejects 'Proxy @ z'.
    logBaseMeasure (Proxy @z) z + logBaseMeasure (Proxy @w) w


--- Instances ---


-- | A harmonium is sampled jointly over its two component families, so a
-- sample point is the sample point of the pair @(z,w)@.
instance Manifold (AffineHarmonium f y x z w)
  => Statistical (AffineHarmonium f y x z w) where
    type SamplePoint (AffineHarmonium f y x z w) = SamplePoint (z, w)

-- | A harmonium is an exponential family. Its sufficient statistic combines
-- the statistics of the two component families with the bilinear product
-- ('>.<') of their anchored versions.
instance ( ExponentialFamily z, ExponentialFamily x, Translation z y
         , Translation w x
         , ExponentialFamily w, ExponentialFamily y, Bilinear f y x )
  => ExponentialFamily (AffineHarmonium f y x z w) where
    -- Sufficient statistic of a single joint observation (z,w).
    sufficientStatistic (z,w) =
        let mz = sufficientStatistic z
            mw = sufficientStatistic w
            -- 'anchor' (from the 'Translation' instances) maps each full
            -- statistic onto the component that enters the interaction.
            my = anchor mz
            mx = anchor mw
         in joinHarmonium mz (my >.< mx) mw
    -- Average sufficient statistic of a sample. The interaction block is
    -- built with '>$<' from the lists of anchored statistics.
    -- NOTE(review): presumably '>$<' averages the pairwise '>.<' products;
    -- confirm against the 'Bilinear' class.
    averageSufficientStatistic zws =
        let (zs,ws) = unzip zws
            mzs = sufficientStatistic <$> zs
            mws = sufficientStatistic <$> ws
            mys = anchor <$> mzs
            mxs = anchor <$> mws
         in joinHarmonium (average mzs) (mys >$< mxs) (average mws)
    logBaseMeasure = harmoniumLogBaseMeasure

-- | Mixture likelihoods (a 'Categorical' latent variable) are conjugated by
-- 'mixtureLikelihoodConjugationParameters'.
instance ( KnownNat k, LegendreExponentialFamily z
         , Translation y z, LegendreExponentialFamily y, SamplePoint z ~ SamplePoint y)
  => ConjugatedLikelihood Tensor z (Categorical k) y (Categorical k) where
    conjugationParameters = mixtureLikelihoodConjugationParameters

-- | Univariate linear Gaussian likelihoods reuse the multivariate instance.
-- The model is embedded as a one-dimensional 'MultivariateNormal' model, its
-- conjugation parameters are computed there, and the result is re-tagged as
-- a 'Normal'.
instance ConjugatedLikelihood Tensor NormalMean NormalMean Normal Normal where
    conjugationParameters aff =
        let rprms :: Natural # MultivariateNormal 1
            (rho0,rprms) = conjugationParameters $ univariateToLinearModel aff
         in (rho0,breakPoint rprms)

-- | Multivariate linear Gaussian likelihoods are conjugated by
-- 'linearGaussianHarmoniumConjugationParameters'.
instance (KnownNat n, KnownNat k) => ConjugatedLikelihood Tensor (MVNMean n) (MVNMean k)
    (MultivariateNormal n) (MultivariateNormal k) where
        conjugationParameters = linearGaussianHarmoniumConjugationParameters

--instance ( KnownNat k, LegendreExponentialFamily z
--         , Generative Natural z, Manifold (Mixture z k) )
--         => Generative Natural (Mixture z k) where
--    sample = sampleConjugated

--instance (KnownNat k, LegendreExponentialFamily z)
--  => Transition Natural Mean (Mixture z k) where
--    transition nhrm =
--        let (nzs,nx) = splitNaturalMixture nhrm
--            mx = toMean nx
--            mzs = S.map transition nzs
--         in joinMeanMixture mzs mx

-- | Natural-to-mean transition for affine mixtures. The model is converted
-- to a plain 'Mixture', the component parameters and the 'Categorical'
-- weights are mapped to mean coordinates separately, and the result is
-- converted back to an affine mixture.
instance ( KnownNat k, Manifold y, Manifold z
         , LegendreExponentialFamily z, Translation z y )
  => Transition Natural Mean (AffineMixture y z k) where
    transition mxmdl0 =
        let mxmdl = affineMixtureToMixture mxmdl0
            (nzs,nx) = splitNaturalMixture mxmdl
            mx = toMean nx
            mzs = S.map transition nzs
         in mixtureToAffineMixture $ joinMeanMixture mzs mx

-- | Sampling from an affine mixture: convert it to a plain 'Mixture' and
-- draw @n@ joint samples with 'sampleConjugated'.
instance ( KnownNat k, Manifold y, Manifold z, LegendreExponentialFamily z
         , Generative Natural z, Translation z y )
  => Generative Natural (AffineMixture y z k) where
    sample n = sampleConjugated n . affineMixtureToMixture

-- | Mean-to-natural transition for mixtures. The components and the
-- 'Categorical' weights are transformed independently and then rejoined.
instance (KnownNat k, DuallyFlatExponentialFamily z)
  => Transition Mean Natural (Mixture z k) where
    transition mhrm =
        let (mzs,mx) = splitMeanMixture mhrm
            nx = transition mx
            nzs = S.map transition mzs
         in joinNaturalMixture nzs nx

-- | Natural-to-source transition for mixtures. The components and the
-- 'Categorical' weights are transformed independently and then rejoined.
instance (KnownNat k, LegendreExponentialFamily z, Transition Natural Source z)
  => Transition Natural Source (Mixture z k) where
    transition nhrm =
        let (nzs,nx) = splitNaturalMixture nhrm
            sx = transition nx
            szs = S.map transition nzs
         in joinSourceMixture szs sx

-- | Source-to-natural conversion for a 'Mixture'.
--
-- The mixture is split with 'splitSourceMixture'. Each component point and
-- the 'Categorical' point are converted to natural coordinates, and the
-- mixture is rebuilt with 'joinNaturalMixture'.
instance (KnownNat k, LegendreExponentialFamily z, Transition Source Natural z)
  => Transition Source Natural (Mixture z k) where
    transition smxt =
        let (scmps, swghts) = splitSourceMixture smxt
         in joinNaturalMixture (S.map transition scmps) (transition swghts)

-- | Natural-to-mean conversion for the univariate linear-Gaussian harmonium.
--
-- The point is re-expressed as a @'LinearGaussianHarmonium' 1 1@, converted
-- there, and mapped back to the univariate representation.
instance Transition Natural Mean
  (AffineHarmonium Tensor NormalMean NormalMean Normal Normal) where
      transition nhrm =
          linearGaussianHarmoniumToUnivariate
              (transition (univariateToLinearGaussianHarmonium nhrm))

-- | Mean-to-natural conversion for the univariate linear-Gaussian harmonium.
--
-- The conversion goes through the @'LinearGaussianHarmonium' 1 1@
-- representation: convert to it, transition there, and convert back.
instance Transition Mean Natural
  (AffineHarmonium Tensor NormalMean NormalMean Normal Normal) where
      transition mhrm =
          linearGaussianHarmoniumToUnivariate $ transition $
              univariateToLinearGaussianHarmonium mhrm

-- | Natural-to-mean conversion for a multivariate linear-Gaussian harmonium.
--
-- The natural harmonium is flattened into a joint
-- @'MultivariateNormal' (n + k)@ and converted in that joint space. The
-- resulting mean parameters are split back into harmonium form.
instance (KnownNat n, KnownNat k) => Transition Natural Mean
  (AffineHarmonium Tensor (MVNMean n) (MVNMean k)
    (MultivariateNormal n) (MultivariateNormal k)) where
      transition nhrm =
          let njnt = naturalLinearGaussianHarmoniumToJoint nhrm
              mjnt = transition njnt
           in meanJointToLinearGaussianHarmonium mjnt

-- | Mean-to-natural conversion for a multivariate linear-Gaussian harmonium.
--
-- The mean harmonium is viewed as a joint @'MultivariateNormal' (n + k)@
-- and converted in that joint space. The natural parameters are then split
-- back into harmonium form.
instance (KnownNat n, KnownNat k) => Transition Mean Natural
  (AffineHarmonium Tensor (MVNMean n) (MVNMean k)
    (MultivariateNormal n) (MultivariateNormal k)) where
      transition mhrm = naturalJointToLinearGaussianHarmonium njnt
        where
          mjnt = meanLinearGaussianHarmoniumToJoint mhrm
          njnt = transition mjnt

--type instance PotentialCoordinates (Mixture z k) = Natural
--
--instance (KnownNat k, LegendreExponentialFamily z) => Legendre (Mixture z k) where
--      potential = conjugatedPotential

-- | Harmonium potentials are evaluated at natural parameters.
type instance PotentialCoordinates (AffineHarmonium f y x z w) = Natural

-- | The 'potential' of a harmonium with a conjugated likelihood is delegated
-- to 'conjugatedPotential'.
instance ( Manifold (f y x)
         , LegendreExponentialFamily w
         , ConjugatedLikelihood f y x z w )
  => Legendre (AffineHarmonium f y x z w) where
      potential = conjugatedPotential

-- | The dual potential of a harmonium, evaluated at mean parameters.
--
-- The mean point is first converted to natural parameters with 'toNatural'.
-- The result is the inner product of the two points minus the 'potential'
-- at the natural point.
instance ( Manifold (f y x), LegendreExponentialFamily w
         , Transition Mean Natural (AffineHarmonium f y x z w)
         , ConjugatedLikelihood f y x z w )
  => DuallyFlat (AffineHarmonium f y x z w) where
    dualPotential mhrm = (mhrm <.> nhrm) - potential nhrm
      where
        nhrm = toNatural mhrm

-- | Joint log-densities of a harmonium in natural coordinates.
--
-- Computed by the generic 'exponentialFamilyLogDensities'.
instance ( Bilinear f y x
         , ExponentialFamily y
         , ExponentialFamily x
         , LegendreExponentialFamily w
         , ConjugatedLikelihood f y x z w )
  => AbsolutelyContinuous Natural (AffineHarmonium f y x z w) where
    logDensities = exponentialFamilyLogDensities

-- | Log-densities of observations under a harmonium.
--
-- The conjugation parameters (a 'Double' together with a natural point on
-- @w@) come from 'harmoniumConjugationParameters'. They are passed to
-- 'logConjugatedDensities' along with the harmonium and the observations.
instance ( ConjugatedLikelihood f y x z w, LegendreExponentialFamily z
         , ExponentialFamily y, LegendreExponentialFamily w
         , Map Natural f x y, Bilinear f x y )
  => ObservablyContinuous Natural (AffineHarmonium f y x z w) where
    logObservableDensities hrm zs =
        logConjugatedDensities (harmoniumConjugationParameters hrm) hrm zs

-- | Average log-likelihood of observations under a harmonium, together with
-- its differential with respect to the natural parameters.
instance ( LegendreExponentialFamily z, LegendreExponentialFamily w
         , ConjugatedLikelihood f y x z w, Map Natural f x y, Bilinear f x y
         , LegendreExponentialFamily (AffineHarmonium f y x z w)
         , Manifold (f y x), SamplePoint z ~ t, ExponentialFamily y)
  => LogLikelihood Natural (AffineHarmonium f y x z w) t where
    -- Mean of the per-observation log-densities.
    -- NOTE(review): the result for an empty observation list depends on
    -- 'average' (presumably NaN); confirm.
    logLikelihood zs hrm = average (logObservableDensities hrm zs)
    -- The mean point from 'expectationStep' on the observations, minus the
    -- model's own mean point.
    logLikelihoodDifferential zs hrm = expectationStep zs hrm - transition hrm

-- | A harmonium is translated along @y@ through its @z@ component.
--
-- Only the @z@ part is shifted, using the @'Translation' z y@ instance. The
-- interaction part and the @w@ part are passed through unchanged. The anchor
-- of the harmonium is the anchor of its @z@ part.
instance ( Translation z y, Manifold w, Manifold (f y x) )
  => Translation (AffineHarmonium f y x z w) y where
      hrm >+> dy = joinHarmonium (pz >+> dy) pyx pw
        where
          (pz, pyx, pw) = splitHarmonium hrm
      anchor hrm = anchor pz
        where
          (pz, _, _) = splitHarmonium hrm