module FRP.Rhine.Bayes where

-- log-domain
import Numeric.Log hiding (sum)

-- monad-bayes
import Control.Monad.Bayes.Class
import Control.Monad.Bayes.Population

-- dunai
import qualified Control.Monad.Trans.MSF.Reader as DunaiReader

-- dunai-bayes
import qualified Data.MonadicStreamFunction.Bayes as DunaiBayes

-- rhine
import FRP.Rhine

-- * Inference methods

-- | Run the Sequential Monte Carlo algorithm continuously on a 'ClSF'.
--
--   The reader layer of the 'ClSF' is first turned into an explicit input
--   with 'DunaiReader.runReaderS', the resulting plain stream function is
--   handed to 'DunaiBayes.runPopulationS', and the reader layer is restored
--   afterwards with 'DunaiReader.readerS'.
--   At every step, the output is the current list of particles,
--   each paired with its weight.
runPopulationCl ::
  forall m cl a b.
  Monad m =>
  -- | Number of particles
  Int ->
  -- | Resampler (see "Control.Monad.Bayes.Population" for some standard choices)
  (forall x. Population m x -> Population m x) ->
  -- | A signal function modelling the stochastic process on which to perform inference.
  --   @a@ represents observations upon which the model should condition, using e.g. 'score'.
  --   It can also additionally contain hyperparameters.
  --   @b@ is the type of estimated current state.
  ClSF (Population m) cl a b ->
  ClSF m cl a [(b, Log Double)]
runPopulationCl nParticles resampler =
  DunaiReader.readerS
    . DunaiBayes.runPopulationS nParticles resampler
    . DunaiReader.runReaderS

-- * Short standard library of stochastic processes

-- | White noise, that is, an independent normal distribution at every time step.
--
--   Every tick draws a fresh sample from @'normal' 0 sigma@;
--   the input and the time information are ignored.
whiteNoise ::
  MonadDistribution m =>
  -- | Standard deviation of the normal distribution sampled at every step.
  Double ->
  Behaviour m td Double
whiteNoise sigma = constMCl $ normal 0 sigma

-- | Construct a Lévy process from the increment between time steps.
--
--   At every tick, the time difference reported by 'sinceLastS' is passed to
--   the increment function, and the sampled increments are accumulated with 'sumS'.
levy ::
  (MonadDistribution m, VectorSpace v (Diff td)) =>
  -- | The increment function at every time step. The argument is the difference between times.
  (Diff td -> m v) ->
  Behaviour m td v
levy incrementor =
  sinceLastS -- time difference for this step
    >>> arrMCl incrementor -- sample the increment for that time difference
    >>> sumS -- accumulate the increments

-- | The Wiener process, also known as Brownian motion.
--
--   It is a 'levy' process whose increment over a time difference @dt@ is
--   normally distributed with mean 0 and standard deviation @sqrt (dt / timescale)@.
--   Whenever that standard deviation is not strictly positive
--   (for example when no time has passed), the increment is exactly 0.
wiener, brownianMotion ::
  (MonadDistribution m, Diff td ~ Double) =>
  -- | Time scale of variance.
  Diff td ->
  Behaviour m td Double
wiener timescale = levy increment
  where
    increment diffTime
      -- Same guard as in 'wienerVarying': a normal distribution needs a
      -- strictly positive standard deviation, so a degenerate step
      -- contributes no increment instead of sampling.
      | stdDev > 0 = normal 0 stdDev
      | otherwise = pure 0
      where
        stdDev = sqrt $ diffTime / timescale

brownianMotion = wiener

-- | The Wiener process, also known as Brownian motion, with varying variance parameter.
--
--   The input signal is the time scale of variance, which may change at every tick.
--   The increment of a step is normally distributed with mean 0 and standard
--   deviation @sqrt (dt / timeScale)@, where @dt@ is the time difference reported
--   by 'sinceLastS'. The increments are accumulated with 'sumS'.
wienerVarying, brownianMotionVarying ::
  (MonadDistribution m, Diff td ~ Double) =>
  BehaviourF m td (Diff td) Double
wienerVarying = proc timeScale -> do
  diffTime <- sinceLastS -< ()
  let stdDev = sqrt $ diffTime / timeScale
  -- Only sample when the standard deviation is strictly positive;
  -- otherwise this step contributes no increment.
  increment <-
    if stdDev > 0
      then arrM $ normal 0 -< stdDev
      else returnA -< 0
  sumS -< increment

brownianMotionVarying = wienerVarying

-- | The 'wiener' process transformed to the Log domain, also called the geometric Wiener process.
--
--   Every sample of 'wiener' is wrapped in the 'Exp' constructor of 'Log'.
wienerLogDomain ::
  (MonadDistribution m, Diff td ~ Double) =>
  -- | Time scale of variance
  Diff td ->
  Behaviour m td (Log Double)
wienerLogDomain timescale = wiener timescale >>> arr Exp

-- | See 'wienerLogDomain' and 'wienerVarying'.
--
--   The input signal is the time scale of variance;
--   every sample of 'wienerVarying' is wrapped in the 'Exp' constructor of 'Log'.
wienerVaryingLogDomain ::
  (MonadDistribution m, Diff td ~ Double) =>
  BehaviourF m td (Diff td) (Log Double)
wienerVaryingLogDomain = wienerVarying >>> arr Exp