{-# LANGUAGE RankNTypes, NoMonomorphismRestriction, BangPatterns #-}
{-# OPTIONS -W #-}
module Language.Hakaru.ImportanceSampler where
-- This is an interpreter that's like Interpreter except conditioning is
-- checked at run time rather than by static types. In other words, we allow
-- models to be compiled whose conditioned parts do not match the observation
-- inputs. In exchange, we get to make Measure an instance of Monad, and we
-- can express models whose number of observations is unknown at compile time.
import Control.Applicative (Applicative(..))
import Data.Dynamic
import Data.Monoid
import System.IO.Unsafe

import Control.Monad.Primitive
import qualified Data.Map.Strict as M
import qualified Data.Number.LogFloat as LF
import qualified System.Random.MWC as MWC

import Language.Hakaru.Mixture (Prob, empty, point, Mixture(..))
import Language.Hakaru.Sampler (Sampler, deterministic, smap, sbind)
import Language.Hakaru.Types
-- Conditioned sampling
-- | A measure is a function from the list of observations that have not yet
-- been consumed to a sampler producing a result paired with the observations
-- that are still left over afterwards.
newtype Measure a = Measure
  { unMeasure :: [Cond] -> Sampler (a, [Cond])
  }
-- | Sequence two measures: run the first on the incoming observations, then
-- hand its result to the continuation and run the resulting measure on
-- whatever observations the first one left unconsumed.
bind :: Measure a -> (a -> Measure b) -> Measure b
bind m k =
  Measure (\observations ->
    unMeasure m observations `sbind` (\(x, leftover) ->
      unMeasure (k x) leftover))
-- | Mapping over a measure transforms each result and passes the pending
-- observations through unchanged (it is defined via 'bind' and 'pure').
instance Functor Measure where
  fmap f m = m `bind` (\a -> pure (f a))

-- | 'pure' yields its argument with weight 1 and consumes no observations;
-- '<*>' runs the function measure first, then the argument measure.
--
-- These instances are required because 'Applicative' (and hence 'Functor')
-- is a superclass of 'Monad' from GHC 7.10 onwards.
instance Applicative Measure where
  pure x = Measure (\conds -> deterministic (point (x, conds) 1))
  mf <*> mx = mf `bind` (\f -> mx `bind` (\a -> pure (f a)))

-- | Monadic sequencing threads the list of unconsumed observations from one
-- measure to the next; see 'bind'.
instance Monad Measure where
  return = pure
  (>>=) = bind
-- | Turn a distribution into a sampler, optionally guided by an observation.
--
-- * With no observation ('Nothing'), draw a value from the distribution using
--   the supplied generator and return it with weight 1.
-- * With an observation ('Just'), do not sample at all: return the observed
--   value, weighted by the distribution's density at that value.
--
-- Calls 'error' if the observation's dynamic payload cannot be converted to
-- the type this distribution expects.
updateMixture :: Typeable a => Cond -> Dist a -> Sampler a
updateMixture cond dist =
  case cond of
    Nothing -> \gen -> do
      drawn <- distSample dist gen
      return (point (fromDensity drawn) 1)
    Just observed ->
      case fromDynamic observed of
        Nothing -> error "did not get data from dynamic source"
        Just y ->
          -- logDensity yields a log-domain value; convert it to a LogFloat weight.
          let weight = LF.logToLogFloat (logDensity dist y)
          in deterministic (point (fromDensity y) weight)
-- | Lift a distribution into a 'Measure'.
--
-- 'conditioned' consumes the next entry of the observation list and passes it
-- to 'updateMixture'; the remaining observations are threaded onwards. It
-- calls 'error' with a descriptive message when the observation list is
-- already exhausted (previously this surfaced as an anonymous incomplete
-- pattern-match failure in a lambda).
--
-- 'unconditioned' never touches the observation list: it always passes
-- 'Nothing' to 'updateMixture' and hands all observations on unchanged.
conditioned, unconditioned :: Typeable a => Dist a -> Measure a
conditioned dist =
  Measure (\observations ->
    case observations of
      cond : conds -> smap (\a -> (a, conds)) (updateMixture cond dist)
      []           -> error "conditioned: ran out of observations")
unconditioned dist =
  Measure (\conds -> smap (\a -> (a, conds)) (updateMixture Nothing dist))
-- | A measure that returns unit with the given weight attached and leaves the
-- observation list untouched.
factor :: Prob -> Measure ()
factor weight = Measure (\remaining -> deterministic (weighted remaining))
  where
    -- The single-point mixture carrying the requested weight.
    weighted cs = point ((), cs) weight
-- | Condition a measure over pairs on its second component by equality.
--
-- The underlying measure is run; each outcome whose second component equals
-- the expected value is kept (first component only, weight 1), and every
-- other outcome is replaced by the empty mixture.
condition :: Eq b => Measure (a, b) -> b -> Measure a
condition m expected =
  Measure (\observations ->
    unMeasure m observations `sbind` (\((x, actual), remaining) ->
      deterministic (keepIfMatching x actual remaining)))
  where
    keepIfMatching x actual remaining
      | actual == expected = point (x, remaining) 1
      | otherwise          = empty
-- Drivers for testing
-- | Strip the (expected to be empty) list of leftover observations from every
-- key of a mixture.
--
-- 'M.mapKeysMonotonic' is used so that no 'Ord' instance is needed on the
-- keys; when every key carries the same empty observation list, dropping it
-- keeps the keys in their existing order.
--
-- Calls 'error' with a descriptive message if any key still carries
-- unconsumed observations (previously this surfaced as an anonymous
-- incomplete pattern-match failure in a lambda).
finish :: Mixture (a, [Cond]) -> Mixture a
finish (Mixture m) = Mixture (M.mapKeysMonotonic dropConds m)
  where
    dropConds (a, []) = a
    dropConds (_, _)  =
      error "finish: measure did not consume all of its observations"
-- | Run a measure @n@ times against the given observations and combine the
-- resulting mixtures with 'mappend'.
--
-- A fresh generator is obtained from 'MWC.create' on every call (per the
-- mwc-random documentation this uses a fixed default seed, so repeated calls
-- see the same random stream).
--
-- A non-positive @n@ yields the empty mixture. The loop terminates on
-- @k <= 0@ rather than on exactly @0@, so a negative count no longer
-- recurses forever.
empiricalMeasure :: (PrimMonad m, Ord a) => Int -> Measure a -> [Cond] -> m (Mixture a)
empiricalMeasure !n measure conds = do
  gen <- MWC.create
  go n gen empty
  where
    -- One run of the measure; still waiting for a generator.
    once = unMeasure measure conds
    go k g acc
      | k <= 0    = return acc
      | otherwise = once g >>= \result ->
          -- Force the accumulator each round to avoid building a thunk chain.
          go (k - 1) g $! mappend acc (finish result)
-- | Produce a lazy, unbounded list of weighted samples from a measure run
-- against the given observations.
--
-- Each element comes from one run of the measure: the run's mixture is passed
-- through 'finish' and its first entry (in 'M.toList' order) is taken; any
-- further entries of that mixture are discarded. The list is built with
-- 'unsafeInterleaveIO', so each run is only performed when the corresponding
-- part of the list is demanded.
--
-- A generator is obtained from 'MWC.create' on every call (per the mwc-random
-- documentation this uses a fixed default seed).
--
-- Forcing an element whose run produced an empty mixture (for instance after
-- a failed 'condition') calls 'error' with a descriptive message instead of
-- the former @Prelude.head: empty list@.
sample :: Measure a -> [Cond] -> IO [(a, Prob)]
sample measure conds = do
  gen <- MWC.create
  unsafeInterleaveIO $ sampleNext gen
  where
    -- One run of the measure; still waiting for a generator.
    once = unMeasure measure conds
    -- Total replacement for @head . M.toList . unMixture@.
    firstEntry mix =
      case M.toList (unMixture mix) of
        entry : _ -> entry
        []        -> error "sample: measure produced an empty mixture"
    sampleNext g = do
      u <- once g
      let x = firstEntry (finish u)
      xs <- unsafeInterleaveIO $ sampleNext g
      return (x : xs)
-- | Compute @1 / (1 + exp (-x))@, forcing the argument first.
--
-- Note: despite its name, this is the logistic (sigmoid) function, i.e. the
-- inverse of the logit, mapping the real line into the interval (0, 1).
logit :: Floating a => a -> a
logit x = x `seq` 1 / (1 + exp (negate x))