{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE Rank2Types #-} {-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -Wno-missing-export-lists #-} module TestInference where import ConjugatePriors ( betaBernoulli', betaBernoulliAnalytic, gammaNormal', gammaNormalAnalytic, normalNormal', normalNormalAnalytic, ) import Control.Monad (replicateM) import Control.Monad.Bayes.Class (MonadMeasure, posterior) import Control.Monad.Bayes.Enumerator (enumerator) import Control.Monad.Bayes.Inference.SMC import Control.Monad.Bayes.Integrator (normalize) import Control.Monad.Bayes.Integrator qualified as Integrator import Control.Monad.Bayes.Population import Control.Monad.Bayes.Population (collapse, runPopulation) import Control.Monad.Bayes.Sampler.Strict (Sampler, sampleIOfixed) import Control.Monad.Bayes.Sampler.Strict qualified as Sampler import Control.Monad.Bayes.Weighted (Weighted) import Control.Monad.Bayes.Weighted qualified as Weighted import Data.AEq (AEq ((~==))) import Numeric.Log (Log) import Sprinkler (soft) import System.Random.Stateful (IOGenM, StdGen, mkStdGen, newIOGenM) sprinkler :: MonadMeasure m => m Bool sprinkler = Sprinkler.soft -- | Count the number of particles produced by SMC checkParticles :: Int -> Int -> IO Int checkParticles observations particles = sampleIOfixed (fmap length (population $ smc SMCConfig {numSteps = observations, numParticles = particles, resampler = resampleMultinomial} Sprinkler.soft)) checkParticlesSystematic :: Int -> Int -> IO Int checkParticlesSystematic observations particles = sampleIOfixed (fmap length (population $ smc SMCConfig {numSteps = observations, numParticles = particles, resampler = resampleSystematic} Sprinkler.soft)) checkParticlesStratified :: Int -> Int -> IO Int checkParticlesStratified observations particles = sampleIOfixed (fmap length (population $ smc SMCConfig {numSteps = observations, numParticles = particles, resampler = resampleStratified} Sprinkler.soft)) checkTerminateSMC :: IO [(Bool, Log Double)] checkTerminateSMC = sampleIOfixed (population $ smc SMCConfig {numSteps = 2, numParticles = 5, resampler = resampleMultinomial} sprinkler) checkPreserveSMC :: Bool checkPreserveSMC = (enumerator . collapse . smc SMCConfig {numSteps = 2, numParticles = 2, resampler = resampleMultinomial}) sprinkler ~== enumerator sprinkler expectationNearNumeric :: Weighted Integrator.Integrator Double -> Weighted Integrator.Integrator Double -> Double expectationNearNumeric x y = let e1 = Integrator.expectation $ normalize x e2 = Integrator.expectation $ normalize y in (abs (e1 - e2)) expectationNearSampling :: Weighted (Sampler (IOGenM StdGen) IO) Double -> Weighted (Sampler (IOGenM StdGen) IO) Double -> IO Double expectationNearSampling x y = do e1 <- sampleIOfixed $ fmap Sampler.sampleMean $ replicateM 10 $ Weighted.weighted x e2 <- sampleIOfixed $ fmap Sampler.sampleMean $ replicateM 10 $ Weighted.weighted y return (abs (e1 - e2)) testNormalNormal :: [Double] -> IO Bool testNormalNormal n = do let e = expectationNearNumeric (posterior (normalNormal' 1 (1, 10)) n) (normalNormalAnalytic 1 (1, 10) n) return (e < 1e-0) testGammaNormal :: [Double] -> IO Bool testGammaNormal n = do let e = expectationNearNumeric (posterior (gammaNormal' (1, 1)) n) (gammaNormalAnalytic (1, 1) n) return (e < 1e-1) testBetaBernoulli :: [Bool] -> IO Bool testBetaBernoulli bs = do let e = expectationNearNumeric (posterior (betaBernoulli' (1, 1)) bs) (betaBernoulliAnalytic (1, 1) bs) return (e < 1e-1)