{-# OPTIONS_GHC -fplugin=GHC.TypeLits.KnownNat.Solver -fplugin=GHC.TypeLits.Normalise -fconstraint-solver-iterations=10 #-}
{-# LANGUAGE Arrows #-}
-- | A collection of algorithms for optimizing harmoniums.

module Goal.Graphical.Learning
    ( -- * Expectation Maximization
      expectationMaximization
    , expectationMaximizationAscent
    , gibbsExpectationMaximization
    , latentProcessExpectationMaximization
    , latentProcessExpectationMaximizationAscent
    -- * Differentials
    , harmoniumInformationProjectionDifferential
    , contrastiveDivergence
    ) where


--- Imports ---


-- Goal --

import Goal.Core
import Goal.Geometry
import Goal.Probability

import Goal.Graphical.Models
import Goal.Graphical.Models.Harmonium
import Goal.Graphical.Models.Dynamic


--- Differentials ---


-- | The differential of the dual relative entropy. Minimizing this results in
-- the information projection of the model against the marginal distribution of
-- the given harmonium. This is more efficient than the generic version.
harmoniumInformationProjectionDifferential
    :: ( Map Natural f y x, LegendreExponentialFamily z
       , SamplePoint w ~ SamplePoint x, Translation z y
       , ExponentialFamily x, ExponentialFamily w, Generative Natural w )
    => Int -- ^ Number of samples drawn from the model distribution
    -> Natural # AffineHarmonium f y x z w -- ^ Harmonium
    -> Natural # w -- ^ Model Distribution
    -> Random (Mean # w) -- ^ Differential Estimate
harmoniumInformationProjectionDifferential n hrm px = do
    -- Monte Carlo sample from the model distribution.
    xs <- sample n px
    let -- Split the harmonium into its likelihood and its bias on w.
        (lklhd, wbias) = split hrm
        -- Sufficient statistics of every sample.
        stats = sufficientStatistic <$> xs
        -- Per-sample score: <s(x), px - wbias> - potential (lklhd >.> x).
        -- NOTE(review): presumably the x-dependent part of the log-ratio
        -- between the model and the harmonium marginal.
        scores = zipWith
            (\mw my0 -> mw <.> (px - wbias) - potential my0)
            stats
            (lklhd >$>* xs)
        -- Empirical means used to centre both quantities.
        statMean = average stats
        scoreMean = sum scores / fromIntegral (length xs)
        -- Centred products; their sum is the unnormalised covariance.
        prods = zipWith
            (\mw sc -> (sc - scoreMean) .> (mw - statMean))
            stats
            scores
        -- Bessel-corrected normaliser: number of products minus one.
        dnm = fromIntegral (length prods) - 1
    -- The products are summed right-to-left (a 'foldr'), then normalised.
    return $ dnm /> foldr (+) 0 prods

-- | Contrastive divergence on harmoniums (<https://doi.org/10.1162/089976602760128018 Hinton, 2002>).
contrastiveDivergence
    :: ( Generative Natural z, ExponentialFamily z, Translation w x
       , Generative Natural w, ExponentialFamily y, Translation z y
       , LegendreExponentialFamily w, Bilinear f y x, Map Natural f x y
       , Map Natural f y x, SamplePoint y ~ SamplePoint z
       , SamplePoint x ~ SamplePoint w, ExponentialFamily x )
    => Int -- ^ The number of contrastive divergence steps
    -> Sample z -- ^ The initial states of the Gibbs chains
    -> Natural # AffineHarmonium f y x z w -- ^ The harmonium
    -> Random (Mean # AffineHarmonium f y x z w) -- ^ The gradient estimate
contrastiveDivergence cdn zs hrm =
    -- Seed the chains from the given states with 'initialPass', advance a
    -- copy of them by @cdn@ Gibbs passes, and compare the two samples.
    initialPass hrm zs >>= \xzs0 ->
        stochasticRelativeEntropyDifferential xzs0
            <$> iterateM' cdn (gibbsPass hrm) xzs0


--- Expectation Maximization ---


-- | A single iteration of EM for 'Harmonium' based models.
expectationMaximization
    :: ( DuallyFlatExponentialFamily (AffineHarmonium f y x z w)
       , ExponentialFamily z, Map Natural f x y, Bilinear f y x
       , Translation z y, Translation w x, LegendreExponentialFamily w )
    => Sample z -- ^ Observations
    -> Natural # AffineHarmonium f y x z w -- ^ Current harmonium
    -> Natural # AffineHarmonium f y x z w -- ^ Updated harmonium
expectationMaximization zs hrm = transition mhrm
  where
    -- Expectation step: mean coordinates computed by 'expectationStep' from
    -- the observations and the current harmonium. The maximization step is
    -- the 'transition' of these mean coordinates back to natural ones.
    mhrm = expectationStep zs hrm

-- | Ascent of the EM objective on harmoniums for when the expectation
-- step can't be computed in closed-form. The convergent harmonium distribution
-- of the output harmonium-list is the result of 1 iteration of the EM
-- algorithm.
expectationMaximizationAscent
    :: ( LegendreExponentialFamily (AffineHarmonium f y x z w)
       , ExponentialFamily z, Map Natural f x y, Bilinear f y x
       , Translation z y, Translation w x, LegendreExponentialFamily w )
    => Double -- ^ Step size (negated before being passed to 'vanillaGradientSequence')
    -> GradientPursuit -- ^ Gradient pursuit algorithm
    -> Sample z -- ^ Observations
    -> Natural # AffineHarmonium f y x z w -- ^ Current harmonium
    -> [Natural # AffineHarmonium f y x z w] -- ^ Sequence of iterates
expectationMaximizationAscent eps gp zs nhrm =
    vanillaGradientSequence (relativeEntropyDifferential mtarget) (-eps) gp nhrm
  where
    -- Target mean coordinates from the expectation step. They are computed
    -- once from the starting harmonium and held fixed along the sequence.
    mtarget = expectationStep zs nhrm

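-- | Stochastic gradient ascent of the EM objective on harmoniums, for when the
-- model expectations of the maximization step can't be computed in
-- closed-form. The expectation step is computed once, from all the
-- observations and the initial harmonium. At every step the model
-- expectations are estimated by seeding Gibbs chains with a minibatch of
-- observations and running a fixed number of Gibbs passes (as in
-- 'contrastiveDivergence'). The convergent harmonium distribution of the
-- output chain approximates the result of 1 iteration of the EM algorithm.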
gibbsExpectationMaximization
    :: ( ExponentialFamily z, Map Natural f x y, Manifold w, Map Natural f y x
       , Translation z y, Translation w x, SamplePoint y ~ SamplePoint z
       , SamplePoint w ~ SamplePoint x
       , ExponentialFamily y, Generative Natural w, ExponentialFamily x
       , Generative Natural z, Manifold (AffineHarmonium f y x z w)
       , Bilinear f y x, LegendreExponentialFamily w )
    => Double -- ^ Step size passed to 'gradientCircuit'
    -> Int -- ^ Number of Gibbs passes per step
    -> Int -- ^ Passed to 'minibatcher' (presumably the minibatch size)
    -> GradientPursuit -- ^ Gradient pursuit algorithm
    -> Sample z -- ^ Observations
    -> Natural # AffineHarmonium f y x z w -- ^ Current Harmonium
    -> Chain Random (Natural # AffineHarmonium f y x z w) -- ^ Harmonium Chain
gibbsExpectationMaximization eps cdn nbtch gp zs0 nhrm0 =
    chainCircuit nhrm0 $ proc nhrm -> do
        -- Draw a minibatch of observations to seed the Gibbs chains.
        zs <- minibatcher nbtch zs0 -< ()
        -- Joint samples from the current harmonium after cdn Gibbs passes.
        xzs <- arrM gibbsSample -< (nhrm, zs)
        -- Fixed E-step target minus the sampled model statistics.
        let dff = mhrm0 - averageSufficientStatistic xzs
        gradientCircuit eps gp -< (nhrm, vanillaGradient dff)
  where
    -- Expectation step, computed once from all observations and the initial
    -- harmonium.
    mhrm0 = expectationStep zs0 nhrm0
    -- Seed chains with 'initialPass', then advance them by cdn Gibbs passes.
    gibbsSample (hrm, zs) = initialPass hrm zs >>= iterateM' cdn (gibbsPass hrm)

-- | Expectation step for 'LatentProcess'es. Runs 'conjugatedSmoothing0' on
-- each observation sequence, and returns averaged mean coordinates for:
--
--   * the prior, from the first smoothed distribution of each sequence,
--   * the emission harmonium, from the pooled observations and smoothed
--     latent distributions,
--   * the transition harmonium, from the smoothed transition harmoniums.
--
-- Sequences whose smoothing yields no distributions contribute nothing to
-- the prior estimate; previously they crashed the computation via 'head'.
latentProcessExpectationStep
    :: ( ConjugatedLikelihood g x x w w, ConjugatedLikelihood f y x z w
       , Transition Natural Mean w, Transition Natural Mean (AffineHarmonium g x x w w)
       , Manifold (AffineHarmonium g x x w w)
       , Bilinear g x x, Map Natural f x y, Bilinear f y x
       , SamplePoint y ~ SamplePoint z )
    => Observations (LatentProcess f g y x z w) -- ^ Observation sequences
    -> Natural # LatentProcess f g y x z w -- ^ Current latent process
    -> (Mean # w, Mean # AffineHarmonium f y x z w, Mean # AffineHarmonium g x x w w)
latentProcessExpectationStep zss ltnt =
    let (prr,emsn,trns) = splitLatentProcess ltnt
        -- Per-sequence smoothed latent distributions and transition harmoniums.
        (smthss,hrmss) = unzip $ conjugatedSmoothing0 prr emsn trns <$> zss
        -- Prior: average of the first smoothed distribution of each sequence.
        -- The pattern skips empty smoothing lists instead of calling the
        -- partial 'head'.
        mprr = average [ toMean smth0 | (smth0:_) <- smthss ]
        mtrns = average $ toMean <$> concat hrmss
        -- Pooled latent and observable statistics over all sequences.
        mws = toMean <$> concat smthss
        mzs = sufficientStatistic <$> concat zss
        mys = anchor <$> mzs
        mxs = anchor <$> mws
        -- Emission statistics: biases from the pooled averages, interaction
        -- assembled with '>$<' from the anchored statistics.
        memsn = joinHarmonium (average mzs) (mys >$< mxs) (average mws)
     in (mprr,memsn,mtrns)

-- | Direct expectation maximization for 'LatentProcess'es.
latentProcessExpectationMaximization
    :: ( ConjugatedLikelihood g x x w w, ConjugatedLikelihood f y x z w
       , Transition Natural Mean w, Transition Natural Mean (AffineHarmonium g x x w w)
       , Transition Mean Natural w
       , Transition Mean Natural (AffineHarmonium f y x z w)
       , Transition Mean Natural (AffineHarmonium g x x w w)
       , Manifold (AffineHarmonium g x x w w)
       , Bilinear g x x, Map Natural f x y, Bilinear f y x
       , SamplePoint y ~ SamplePoint z )
    => Observations (LatentProcess f g y x z w) -- ^ Observation sequences
    -> Natural # LatentProcess f g y x z w -- ^ Current latent process
    -> Natural # LatentProcess f g y x z w -- ^ Updated latent process
latentProcessExpectationMaximization zss ltnt =
    joinLatentProcess prr' emsn' trns'
  where
    -- Expectation step: averaged mean coordinates of the prior, emission
    -- harmonium and transition harmonium.
    (mprr, memsn, mtrns) = latentProcessExpectationStep zss ltnt
    -- Maximization step: map back to natural coordinates. For the two
    -- harmoniums only the likelihood (first) component is kept.
    prr' = toNatural mprr
    (emsn', _) = split $ toNatural memsn
    (trns', _) = split $ toNatural mtrns

-- | Expectation maximization for 'LatentProcess'es approximated through
-- gradient ascent.
latentProcessExpectationMaximizationAscent
    :: ( ConjugatedLikelihood g x x w w, ConjugatedLikelihood f y x z w
       , DuallyFlatExponentialFamily w
       , LegendreExponentialFamily (AffineHarmonium f y x z w)
       , LegendreExponentialFamily (AffineHarmonium g x x w w)
       , Bilinear g x x, Map Natural f x y, Bilinear f y x
       , SamplePoint y ~ SamplePoint z )
    => Double
    -> Int
    -> GradientPursuit
    -> [Sample z]
    -> Natural # LatentProcess f g y x z w
    -> Natural # LatentProcess f g y x z w
latentProcessExpectationMaximizationAscent :: Double
-> Int
-> GradientPursuit
-> [Sample z]
-> (Natural # LatentProcess f g y x z w)
-> Natural # LatentProcess f g y x z w
latentProcessExpectationMaximizationAscent Double
eps Int
nstps GradientPursuit
gp [Sample z]
zss Natural # LatentProcess f g y x z w
ltnt =
    let (Mean # w
mprr,Mean # AffineHarmonium f y x z w
mehrm,Mean # AffineHarmonium g x x w w
mthrm) = Observations (LatentProcess f g y x z w)
-> (Natural # LatentProcess f g y x z w)
-> (Mean # w, Mean # AffineHarmonium f y x z w,
    Mean # AffineHarmonium g x x w w)
forall (g :: Type -> Type -> Type) x w (f :: Type -> Type -> Type)
       y z.
(ConjugatedLikelihood g x x w w, ConjugatedLikelihood f y x z w,
 Transition Natural Mean w,
 Transition Natural Mean (AffineHarmonium g x x w w),
 Manifold (AffineHarmonium g x x w w), Bilinear g x x,
 Map Natural f x y, Bilinear f y x,
 SamplePoint y ~ SamplePoint z) =>
Observations (LatentProcess f g y x z w)
-> (Natural # LatentProcess f g y x z w)
-> (Mean # w, Mean # AffineHarmonium f y x z w,
    Mean # AffineHarmonium g x x w w)
latentProcessExpectationStep [Sample z]
Observations (LatentProcess f g y x z w)
zss Natural # LatentProcess f g y x z w
ltnt
        (Natural # w
nprr,Natural # Affine f y z x
nemsn,Natural # Affine g x w x
ntrns) = (Natural # LatentProcess f g y x z w)
-> (Natural # w, Natural # Affine f y z x,
    Natural # Affine g x w x)
forall z w (f :: Type -> Type -> Type) y x
       (g :: Type -> Type -> Type) c.
(Manifold z, Manifold w, Manifold (f y x), Manifold (g x x)) =>
(c # LatentProcess f g y x z w)
-> (c # w, c # Affine f y z x, c # Affine g x w x)
splitLatentProcess Natural # LatentProcess f g y x z w
ltnt
        neql0 :: Natural # w
neql0 = (Mean # w) -> Natural # w
forall c x. Transition c Natural x => (c # x) -> Natural # x
toNatural ((Mean # w) -> Natural # w)
-> ((Mean # Affine f y z x, Mean # w) -> Mean # w)
-> (Mean # Affine f y z x, Mean # w)
-> Natural # w
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Mean # Affine f y z x, Mean # w) -> Mean # w
forall a b. (a, b) -> b
snd ((Mean # Affine f y z x, Mean # w) -> Natural # w)
-> (Mean # Affine f y z x, Mean # w) -> Natural # w
forall a b. (a -> b) -> a -> b
$ (Mean # AffineHarmonium f y x z w)
-> (Mean # First (AffineHarmonium f y x z w),
    Mean # Second (AffineHarmonium f y x z w))
forall z c. Product z => (c # z) -> (c # First z, c # Second z)
split Mean # AffineHarmonium f y x z w
mehrm
        neql1 :: Natural # w
neql1 = (Mean # w) -> Natural # w
forall c x. Transition c Natural x => (c # x) -> Natural # x
toNatural ((Mean # w) -> Natural # w)
-> ((Mean # Affine g x w x, Mean # w) -> Mean # w)
-> (Mean # Affine g x w x, Mean # w)
-> Natural # w
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Mean # Affine g x w x, Mean # w) -> Mean # w
forall a b. (a, b) -> b
snd ((Mean # Affine g x w x, Mean # w) -> Natural # w)
-> (Mean # Affine g x w x, Mean # w) -> Natural # w
forall a b. (a -> b) -> a -> b
$ (Mean # AffineHarmonium g x x w w)
-> (Mean # First (AffineHarmonium g x x w w),
    Mean # Second (AffineHarmonium g x x w w))
forall z c. Product z => (c # z) -> (c # First z, c # Second z)
split Mean # AffineHarmonium g x x w w
mthrm
        nehrm :: Natural # AffineHarmonium f y x z w
nehrm = (Natural # Affine f y z x)
-> (Natural # w) -> Natural # AffineHarmonium f y x z w
forall (f :: Type -> Type -> Type) y x z w.
ConjugatedLikelihood f y x z w =>
(Natural # Affine f y z x)
-> (Natural # w) -> Natural # AffineHarmonium f y x z w
joinConjugatedHarmonium Natural # Affine f y z x
nemsn Natural # w
neql0
        nthrm :: Natural # AffineHarmonium g x x w w
nthrm = (Natural # Affine g x w x)
-> (Natural # w) -> Natural # AffineHarmonium g x x w w
forall (f :: Type -> Type -> Type) y x z w.
ConjugatedLikelihood f y x z w =>
(Natural # Affine f y z x)
-> (Natural # w) -> Natural # AffineHarmonium f y x z w
joinConjugatedHarmonium Natural # Affine g x w x
ntrns Natural # w
neql1
        nprr' :: Natural # w
nprr' = ([Natural # w] -> Int -> Natural # w
forall a. [a] -> Int -> a
!! Int
nstps)
            ([Natural # w] -> Natural # w) -> [Natural # w] -> Natural # w
forall a b. (a -> b) -> a -> b
$ ((Natural # w) -> Natural #* w)
-> Double -> GradientPursuit -> (Natural # w) -> [Natural # w]
forall x c.
Manifold x =>
((c # x) -> c #* x)
-> Double -> GradientPursuit -> (c # x) -> [c # x]
vanillaGradientSequence ((Mean # w) -> (Natural # w) -> Mean # w
forall x.
LegendreExponentialFamily x =>
(Mean # x) -> (Natural # x) -> Mean # x
relativeEntropyDifferential Mean # w
mprr) (-Double
eps) GradientPursuit
gp Natural # w
nprr
        nemsn' :: Natural # Affine f y z x
nemsn' = (Natural # Affine f y z x, Natural # w) -> Natural # Affine f y z x
forall a b. (a, b) -> a
fst ((Natural # Affine f y z x, Natural # w)
 -> Natural # Affine f y z x)
-> ([Natural # AffineHarmonium f y x z w]
    -> (Natural # Affine f y z x, Natural # w))
-> [Natural # AffineHarmonium f y x z w]
-> Natural # Affine f y z x
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Natural # AffineHarmonium f y x z w)
-> (Natural # Affine f y z x, Natural # w)
forall z c. Product z => (c # z) -> (c # First z, c # Second z)
split ((Natural # AffineHarmonium f y x z w)
 -> (Natural # Affine f y z x, Natural # w))
-> ([Natural # AffineHarmonium f y x z w]
    -> Natural # AffineHarmonium f y x z w)
-> [Natural # AffineHarmonium f y x z w]
-> (Natural # Affine f y z x, Natural # w)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([Natural # AffineHarmonium f y x z w]
-> Int -> Natural # AffineHarmonium f y x z w
forall a. [a] -> Int -> a
!! Int
nstps)
            ([Natural # AffineHarmonium f y x z w] -> Natural # Affine f y z x)
-> [Natural # AffineHarmonium f y x z w]
-> Natural # Affine f y z x
forall a b. (a -> b) -> a -> b
$ ((Natural # AffineHarmonium f y x z w)
 -> Natural #* AffineHarmonium f y x z w)
-> Double
-> GradientPursuit
-> (Natural # AffineHarmonium f y x z w)
-> [Natural # AffineHarmonium f y x z w]
forall x c.
Manifold x =>
((c # x) -> c #* x)
-> Double -> GradientPursuit -> (c # x) -> [c # x]
vanillaGradientSequence ((Mean # AffineHarmonium f y x z w)
-> (Natural # AffineHarmonium f y x z w)
-> Mean # AffineHarmonium f y x z w
forall x.
LegendreExponentialFamily x =>
(Mean # x) -> (Natural # x) -> Mean # x
relativeEntropyDifferential Mean # AffineHarmonium f y x z w
mehrm) (-Double
eps) GradientPursuit
gp Natural # AffineHarmonium f y x z w
nehrm
        ntrns' :: Natural # Affine g x w x
ntrns' = (Natural # Affine g x w x, Natural # w) -> Natural # Affine g x w x
forall a b. (a, b) -> a
fst ((Natural # Affine g x w x, Natural # w)
 -> Natural # Affine g x w x)
-> ([Natural # AffineHarmonium g x x w w]
    -> (Natural # Affine g x w x, Natural # w))
-> [Natural # AffineHarmonium g x x w w]
-> Natural # Affine g x w x
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Natural # AffineHarmonium g x x w w)
-> (Natural # Affine g x w x, Natural # w)
forall z c. Product z => (c # z) -> (c # First z, c # Second z)
split ((Natural # AffineHarmonium g x x w w)
 -> (Natural # Affine g x w x, Natural # w))
-> ([Natural # AffineHarmonium g x x w w]
    -> Natural # AffineHarmonium g x x w w)
-> [Natural # AffineHarmonium g x x w w]
-> (Natural # Affine g x w x, Natural # w)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([Natural # AffineHarmonium g x x w w]
-> Int -> Natural # AffineHarmonium g x x w w
forall a. [a] -> Int -> a
!! Int
nstps)
            ([Natural # AffineHarmonium g x x w w] -> Natural # Affine g x w x)
-> [Natural # AffineHarmonium g x x w w]
-> Natural # Affine g x w x
forall a b. (a -> b) -> a -> b
$ ((Natural # AffineHarmonium g x x w w)
 -> Natural #* AffineHarmonium g x x w w)
-> Double
-> GradientPursuit
-> (Natural # AffineHarmonium g x x w w)
-> [Natural # AffineHarmonium g x x w w]
forall x c.
Manifold x =>
((c # x) -> c #* x)
-> Double -> GradientPursuit -> (c # x) -> [c # x]
vanillaGradientSequence ((Mean # AffineHarmonium g x x w w)
-> (Natural # AffineHarmonium g x x w w)
-> Mean # AffineHarmonium g x x w w
forall x.
LegendreExponentialFamily x =>
(Mean # x) -> (Natural # x) -> Mean # x
relativeEntropyDifferential Mean # AffineHarmonium g x x w w
mthrm) (-Double
eps) GradientPursuit
gp Natural # AffineHarmonium g x x w w
nthrm
     in (Natural # w)
-> (Natural # Affine f y z x)
-> (Natural # Affine g x w x)
-> Natural # LatentProcess f g y x z w
forall z w (f :: Type -> Type -> Type) y x
       (g :: Type -> Type -> Type) c.
(Manifold z, Manifold w, Manifold (f y x), Manifold (g x x)) =>
(c # w)
-> (c # Affine f y z x)
-> (c # Affine g x w x)
-> c # LatentProcess f g y x z w
joinLatentProcess Natural # w
nprr' Natural # Affine f y z x
nemsn' Natural # Affine g x w x
ntrns'