module Control.Monad.Bayes.Inference.Lazy.WIS where

import Control.Monad.Bayes.Sampler.Lazy (Sampler, weightedsamples)
import Control.Monad.Bayes.Weighted (Weighted)
import Numeric.Log (Log (Exp))
import System.Random (Random (randoms), getStdGen, newStdGen)

-- Weighted Importance Sampling (WIS)

-- | Likelihood-weighted importance sampling.
--
-- First draws @n@ weighted samples from the model. It then returns an
-- infinite stream of results resampled from those @n@ samples, regarded as an
-- empirical distribution in which each sample's probability is proportional
-- to its weight.
--
-- Resampling is an inverse-CDF lookup. The weights are turned into a running
-- (cumulative) total, and for each uniform draw @r@ the first sample whose
-- cumulative weight reaches @r * total@ is chosen.
--
-- If @n <= 0@, or the model yields no samples, there is nothing to resample
-- from and the result is the empty list.
lwis :: Int -> Weighted Sampler a -> IO [a]
lwis n m = do
  xws <- weightedsamples m
  -- Each of the first n samples paired with the cumulative weight up to and
  -- including it; the last pair carries the total weight.
  let xws' = take n $ accumulate xws 0
  _ <- newStdGen
  rs <- randoms <$> getStdGen
  return $ resample xws' rs
  where
    -- Map each uniform draw to the first sample whose cumulative weight is at
    -- least @r * total@. An empty sample set yields an empty stream instead
    -- of a stream of errors.
    resample :: [(b, Log Double)] -> [Double] -> [b]
    resample [] _ = []
    resample cumulative draws = map pick draws
      where
        -- Safe: this equation only matches a non-empty list.
        (lastX, total) = last cumulative
        pick r = case filter ((>= Exp (log r) * total) . snd) cumulative of
          (x, _) : _ -> x
          -- Only reachable if the comparison never holds (e.g. NaN weights);
          -- fall back to the last sample rather than crashing.
          [] -> lastX

    -- Running sum of weights. Each sample is emitted exactly once, so
    -- `take n` keeps n samples. The sum is forced as each cons cell is
    -- produced, which avoids building a deep chain of (+) thunks.
    accumulate :: Num t => [(b, t)] -> t -> [(b, t)]
    accumulate ((x, w) : rest) acc =
      let acc' = w + acc
       in acc' `seq` ((x, acc') : accumulate rest acc')
    accumulate [] _ = []