{-# OPTIONS_GHC -fplugin=GHC.TypeLits.KnownNat.Solver -fplugin=GHC.TypeLits.Normalise -fconstraint-solver-iterations=10 #-}
{-# LANGUAGE
    RankNTypes,
    PolyKinds,
    DataKinds,
    TypeOperators,
    FlexibleContexts,
    FlexibleInstances,
    TypeApplications,
    ScopedTypeVariables,
    TypeFamilies
#-}
-- | Inferring latent variables in graphical models.
module Goal.Graphical.Inference
    ( -- * Inference
      conjugatedBayesRule
    -- * Recursive
    , conjugatedRecursiveBayesianInference
    -- * Dynamic
    , conjugatedPredictionStep
    , conjugatedForwardStep
    -- * Conjugation
    , regressConjugationParameters
    , conjugationCurve
    ) where

--- Imports ---


-- Goal --

import Goal.Core
import Goal.Geometry
import Goal.Probability

import Goal.Graphical.Models.Harmonium

import qualified Goal.Core.Vector.Storable as S

import Data.List


--- Inference ---


-- | Bayes' rule for a conjugated likelihood: the posterior over the latent
-- variables given a likelihood, a prior, and a single observation.
--
-- The likelihood and prior are joined into a harmonium. That harmonium is
-- then transposed, so the roles of the observation side (@z@) and the
-- belief side (@w@) are swapped. The first component of the transposed
-- harmonium is an affine function of the observation's sufficient
-- statistics. Evaluating it at the observation yields the natural parameters
-- of the posterior.
conjugatedBayesRule
    :: forall f y x z w
    . ( Map Natural f x y, Bilinear f y x, ConjugatedLikelihood f y x z w )
    => Natural # Affine f y z x -- ^ Likelihood
    -> Natural # w -- ^ Prior
    -> SamplePoint z -- ^ Observation
    -> Natural # w -- ^ Posterior
conjugatedBayesRule lkl prr z = pstrMap >.+> stats
  where
    -- Affine map from observation statistics to posterior natural parameters.
    pstrMap :: Natural # Affine f x w y
    pstrMap = fst . split . transposeHarmonium $ joinConjugatedHarmonium lkl prr
    -- The annotation is required: 'SamplePoint' alone does not determine @z@.
    stats :: Mean # z
    stats = sufficientStatistic z


--- Recursive ---


-- | Recursive Bayesian inference with a conjugated likelihood. The
-- observations are processed in order, and each one updates the current
-- beliefs with 'conjugatedBayesRule'.
--
-- The result begins with the unmodified prior, followed by the posterior
-- after each successive observation. It therefore has one more element than
-- there are observations, and its last element is the posterior given all of
-- them. 'scanl'' forces each intermediate belief to weak head normal form, so
-- no chain of unevaluated updates builds up.
conjugatedRecursiveBayesianInference
    :: ( Map Natural f x y, Bilinear f y x, ConjugatedLikelihood f y x z w )
    => Natural # Affine f y z x -- ^ Likelihood
    -> Natural # w -- ^ Prior
    -> Sample z -- ^ Observations
    -> [Natural # w] -- ^ Prior followed by the successive posteriors
conjugatedRecursiveBayesianInference lkl = scanl' (conjugatedBayesRule lkl)


--- Dynamical ---


-- | Prediction step for a (doubly) conjugated transition. The current
-- beliefs and the transition are joined into a harmonium, and the harmonium
-- is transposed. The transposed harmonium is split with
-- 'splitConjugatedHarmonium', which is why the transition must be conjugated
-- in both directions. The prior component of that split is returned as the
-- predicted beliefs.
conjugatedPredictionStep
    :: (ConjugatedLikelihood f x x w w, Bilinear f x x)
    => Natural # Affine f x w x -- ^ Transition
    -> Natural # w -- ^ Current beliefs
    -> Natural # w -- ^ Predicted beliefs
conjugatedPredictionStep trns prr =
    let flipped = transposeHarmonium $ joinConjugatedHarmonium trns prr
        (_, predicted) = splitConjugatedHarmonium flipped
     in predicted

-- | One step of forward filtering with conjugated models. The beliefs from
-- the previous time step are first propagated through the transition with
-- 'conjugatedPredictionStep'. The predicted beliefs are then conditioned on
-- the current observation with 'conjugatedBayesRule'.
conjugatedForwardStep
    :: ( ConjugatedLikelihood g x x w w, Bilinear g x x
       , ConjugatedLikelihood f y x z w, Bilinear f y x
       , Map Natural g x x, Map Natural f x y )
    => Natural # Affine g x w x -- ^ Transition Distribution
    -> Natural # Affine f y z x -- ^ Emission Distribution
    -> Natural # w -- ^ Beliefs at time $t-1$
    -> SamplePoint z -- ^ Observation at time $t$
    -> Natural # w -- ^ Beliefs at time $t$
conjugatedForwardStep trns emsn prr z =
    let predicted = conjugatedPredictionStep trns prr
     in conjugatedBayesRule emsn predicted z


--- Approximate Conjugation ---


-- | Evaluates the affine conjugation curve at each sample point. The value at
-- a point is the inner product of the conjugation parameters with the point's
-- sufficient statistics, plus the conjugation shift.
conjugationCurve
    :: ExponentialFamily x
    => Double -- ^ Conjugation shift
    -> Natural # x -- ^ Conjugation parameters
    -> Sample x -- ^ Sample points
    -> [Double] -- ^ Conjugation curve at sample points
conjugationCurve rho0 rprms mus = map curveAt mus
  where
    curveAt smp = (rprms <.> sufficientStatistic smp) + rho0

-- Linear Least Squares

-- | Fits approximate conjugation parameters for a population code by linear
-- least squares. At each sample point, the potential (log-partition
-- function) of the likelihood's output is regressed on an intercept plus the
-- point's sufficient statistics.
--
-- Returns the fitted intercept (the conjugation shift) and the fitted slope
-- coefficients, packaged as natural parameters (the conjugation parameters).
regressConjugationParameters
    :: (Map Natural f z x, LegendreExponentialFamily z, ExponentialFamily x)
    => Natural # f z x -- ^ PPC
    -> Sample x -- ^ Sample points
    -> (Double, Natural # x) -- ^ Approximate conjugation parameters
regressConjugationParameters lkl mus = (S.head shift0, Point slope)
  where
    -- Dependent variable: potential of the likelihood's output at each point.
    targets = potential <$> (lkl >$>* mus)
    -- Regressors: a constant 1 followed by the sufficient statistics.
    regressors = independentVariables0 lkl mus
    -- The first coefficient is the intercept; the rest are the slopes.
    (shift0, slope) = S.splitAt $ S.linearLeastSquares regressors targets

--- Internal ---

-- | Builds the regressor rows used by 'regressConjugationParameters'. Each
-- sample point contributes one row: a constant 1 (the intercept column)
-- followed by the coordinates of the point's sufficient statistics. The first
-- argument is ignored.
independentVariables0
    :: forall f x z . ExponentialFamily x
    => Natural # f z x
    -> Sample x
    -> [S.Vector (Dimension x + 1) Double]
independentVariables0 _ mus =
    [ S.singleton 1 S.++ coordinates (sufficientStatistic mu :: Mean # x)
    | mu <- mus ]


---- | The posterior distribution given a prior and likelihood, where the
---- posterior is normalized via numerical integration.
--numericalRecursiveBayesianInference
--    :: forall f z x .
--        ( Map Natural f x z, Map Natural f z x, Bilinear f z x
--        , LegendreExponentialFamily z, ExponentialFamily x, SamplePoint x ~ Double)
--    => Double -- ^ Integral error bound
--    -> Double -- ^ Sample space lower bound
--    -> Double -- ^ Sample space upper bound
--    -> Sample x -- ^ Centralization samples
--    -> [Natural # Affine f z x] -- ^ Likelihoods
--    -> Sample z -- ^ Observations
--    -> (Double -> Double) -- ^ Prior
--    -> (Double -> Double, Double) -- ^ Posterior Density and Log-Partition Function
--numericalRecursiveBayesianInference errbnd mnx mxx xsmps lkls zs prr =
--    let logbm = logBaseMeasure (Proxy @ x)
--        logupst0 x lkl z =
--            (z *<.< snd (splitAffine lkl)) <.> sufficientStatistic x - potential (lkl >.>* x)
--        logupst x = sum $ logbm x : log (prr x) : zipWith (logupst0 x) lkls zs
--        logprt = logIntegralExp errbnd logupst mnx mxx xsmps
--        dns x = exp $ logupst x - logprt
--     in (dns,logprt)