{-# LANGUAGE RankNTypes, NoMonomorphismRestriction, BangPatterns #-}
{-# OPTIONS -W #-}

module Language.Hakaru.ImportanceSampler where

-- This is an interpreter that's like Interpreter except conditioning is
-- checked at run time rather than by static types.  In other words, we allow
-- models to be compiled whose conditioned parts do not match the observation
-- inputs.  In exchange, we get to make Measure an instance of Monad, and we
-- can express models whose number of observations is unknown at compile time.

import Types (Cond(..), CSampler(CSampler))
import RandomChoice (normal_rng, chooseIndex)
import Mixture (Prob, empty, point, Mixture(..))
import Sampler (Sampler, deterministic, smap, sbind)

import System.Random
import Data.Monoid
import Data.Ix
import Data.Dynamic
import Data.List
-- Explicit import list on purpose: Control.Applicative also exports a name
-- 'empty', which would clash with the one imported from Mixture.
import Control.Applicative (Applicative(..))
import Control.Monad
import qualified Data.Map.Strict as M
import qualified Data.Number.LogFloat as LF

-- | Point mass at @x@.
--
-- Unconditioned, it yields @x@ with weight 1.  Conditioned on a 'Discrete'
-- observation, it yields @x@ with weight 1 when the observation equals @x@
-- and the empty mixture otherwise.  Calls 'error' when the observation
-- cannot be extracted from its 'Dynamic' at type @a@, or when the condition
-- is neither 'Unconditioned' nor 'Discrete'.
dirac :: (Eq a, Typeable a) => a -> CSampler a
dirac x = CSampler c where
  c Unconditioned = deterministic (point x 1)
  c (Discrete y) = case fromDynamic y of
    Just obs -> deterministic (if x == obs then point x 1 else empty)
    Nothing -> error "dirac: did not get data from dynamic source"
  c _ = error "dirac: got a non-discrete sampler"

-- | Bernoulli distribution with success probability @theta@.
--
-- Unconditioned, it draws a number from @randomR (0, 1)@ and reports whether
-- it is @<= theta@, with weight 1.  Conditioned on a 'Discrete' observation,
-- it returns that observation weighted by @theta@ (for 'True') or
-- @1 - theta@ (for 'False').  Calls 'error' when @theta@ lies outside
-- @[0, 1]@, when the observation is not a 'Bool', or when the condition is
-- neither 'Unconditioned' nor 'Discrete'.
bern :: Double -> CSampler Bool
bern theta | 0 <= theta && theta <= 1 = CSampler c where
  c Unconditioned = \g0 -> case randomR (0, 1) g0 of
    (x, g) -> (point (x <= theta) 1, g)
  c (Discrete y) = case fromDynamic y of
    Just obs -> deterministic (point obs (LF.logFloat (if obs then theta else 1 - theta)))
    Nothing -> error "bern: did not get data from dynamic source"
  c _ = error "bern: got a non-discrete sampler"
bern theta = error ("bernoulli: invalid parameter " ++ show theta)

-- | Continuous uniform distribution between @lo@ and @hi@ (requires
-- @lo < hi@, otherwise 'error').
--
-- Unconditioned, it draws with @randomR (lo, hi)@ and weight 1.  Conditioned
-- on a 'Lebesgue' observation, it returns the observation weighted by
-- @1 / (hi - lo)@ when it lies strictly between the bounds, and the empty
-- mixture otherwise.
uniformC :: (Fractional a, Real a, Random a, Typeable a) => a -> a -> CSampler a
uniformC lo hi | lo < hi = CSampler c where
  c Unconditioned = \g0 -> case randomR (lo,hi) g0 of
    (x, g) -> (point x 1, g)
  c (Lebesgue y) = case fromDynamic y of
    Just obs -> deterministic (if lo < obs && obs < hi then point obs density else empty)
    Nothing -> error "uniformC: did not get data from dynamic source"
  c _ = error "uniformC: got a discrete sampler"
  -- Round-trips through Rational to convert from @a@ to the weight type.
  density = fromRational (toRational (recip (hi - lo)))
uniformC _ _ = error "uniformC: invalid parameters"

-- | Discrete uniform distribution over the inclusive range @lo .. hi@
-- (requires @lo <= hi@, otherwise 'error').
--
-- Unconditioned, it draws with @randomR (lo, hi)@ and weight 1.  Conditioned
-- on a 'Discrete' observation, it returns the observation weighted by
-- @1 / rangeSize (lo, hi)@ when it lies within the bounds, and the empty
-- mixture otherwise.
uniformD :: (Ix a, Random a, Typeable a) => a -> a -> CSampler a
uniformD lo hi | lo <= hi = CSampler c where
  c Unconditioned = \g0 -> case randomR (lo,hi) g0 of
    (x, g) -> (point x 1, g)
  c (Discrete y) = case fromDynamic y of
    Just obs -> deterministic (if lo <= obs && obs <= hi then point obs density else empty)
    Nothing -> error "uniformD: did not get data from dynamic source"
  c _ = error "uniformD: got a non-discrete sampler"
  density = recip (fromInteger (toInteger (rangeSize (lo,hi))))
uniformD _ _ = error "uniformD: invalid parameters"

-- | Poisson distribution with rate @l@ (requires @0 <= l@, otherwise
-- 'error').
--
-- Unconditioned, it builds the lazy, infinite list of probability masses
-- @exp (-l) * l^k / k!@ (each term derived from the previous one) and hands
-- it to 'chooseIndex' to pick @k@.  Conditioned on a 'Discrete' observation
-- @k@, it returns @k@ weighted by that same mass when @0 <= k@, and the
-- empty mixture otherwise.
poisson :: (Integral a, Typeable a) => Double -> CSampler a
poisson !l | 0 <= l = CSampler c where
  c Unconditioned = \g0 ->
    let probs = exp (-l) : zipWith (\k p -> p * l / k) [1..] probs
        (k, g) = chooseIndex probs g0
    in (point (fromInteger (toInteger k)) 1, g)
  c (Discrete k) = case fromDynamic k of
    Just obs -> deterministic
      (if 0 <= obs
         then point obs (LF.logToLogFloat (-l) * LF.logFloat l ^ obs
                           / product (map fromIntegral [1..obs]))
         else empty)
    Nothing -> error "poisson: did not get data from dynamic source"
  c _ = error "poisson: got a non-discrete sampler"
poisson _ = error "poisson: invalid parameter"

-- | Normal distribution with the given mean and standard deviation
-- (requires @std > 0@, otherwise 'error').
--
-- Conditioned on a 'Lebesgue' observation @y@, it returns @y@ weighted by
-- the Gaussian density
-- @exp (-((y - mean) / std)^2 / 2) / std / sqrt (2 * pi)@.
normal :: (Real a, Floating a, Random a, Typeable a) => a -> a -> CSampler a
normal !mean !std | std > 0 = CSampler c where
  -- NOTE(review): 'normal_rng' is already given @mean@ and @std@.  If it
  -- returns a draw from N(mean, std) rather than a standard-normal variate,
  -- the additional @mean + std * x@ below rescales the sample twice.
  -- Confirm against RandomChoice before relying on unconditioned draws.
  c Unconditioned = \g0 ->
    let (x, g) = normal_rng mean std g0
    in (point (mean + std * x) 1, g)
  c (Lebesgue y) = case fromDynamic y of
    Just obs ->
      let density = exp (square ((obs - mean) / std) / (-2)) / std / sqrt (2 * pi)
          square z = z * z
      in deterministic (point obs (fromRational (toRational density)))
      -- TODO: use log-density and LogFloat directly
    Nothing -> error "normal: did not get data from dynamic source"
  c _ = error "normal: got a discrete sampler"
normal _ _ = error "normal: invalid parameters"

-- | Categorical distribution over the given (value, weight) pairs; the
-- weights need not be normalized.
--
-- Every weight is divided by the largest one (@peak@) before being
-- converted out of the log domain, and @list'@ pairs each value with its
-- rescaled weight and the running total so far.  Unconditioned, a number is
-- drawn from @randomR (0, total)@ and the first value whose running total
-- reaches it is returned with weight 1; this needs a non-empty list.
-- Conditioned on a 'Discrete' observation, the observation is returned
-- weighted by its rescaled weight divided by @total@, or the empty mixture
-- if it is not in the list.
categorical :: (Typeable a, Eq a) => [(a, Prob)] -> CSampler a
categorical list = CSampler c where
  peak :: LF.LogFloat
  peak = maximum (map snd list)
  total :: Double
  (total, list') = mapAccumL f 0 list
    where f acc (a,b) = (acc', (a, (b', acc')))
            where b' = b/peak
                  acc' :: Double
                  acc' = acc + LF.fromLogFloat b'
  c Unconditioned = \g0 ->
    let (p, g) = randomR (0, total) g0
        (choice, _) : _ = filter (\(_,(_,p0)) -> p <= p0) list'
    in (point choice 1, g)
  c (Discrete y) = case fromDynamic y of
    Just obs -> deterministic
      (maybe empty (point obs . (/ LF.logFloat total) . fst) (lookup obs list'))
    Nothing -> error "categorical: did not get data from dynamic source"
  c _ = error "categorical: got a non-discrete sampler"

-- Conditioned sampling

-- | A measure consumes a list of pending observations and produces a
-- sampler for a result paired with the observations it did not consume.
newtype Measure a = Measure { unMeasure :: [Cond] -> Sampler (a, [Cond]) }

-- | Sequence two measures: run the first on the incoming observations, then
-- run the continuation on whatever observations are left.
bind :: Measure a -> (a -> Measure b) -> Measure b
bind measure continuation =
  Measure (\conds ->
    sbind (unMeasure measure conds)
          (\(a,rest) -> unMeasure (continuation a) rest))

-- 'Functor' and 'Applicative' are superclasses of 'Monad' from GHC 7.10
-- onwards, so they must be provided for the 'Monad' instance to be accepted.
-- Both are defined in terms of the monadic operations.
instance Functor Measure where
  fmap = liftM

instance Applicative Measure where
  pure x = Measure (\conds -> deterministic (point (x,conds) 1))
  (<*>)  = ap

instance Monad Measure where
  return = pure
  (>>=)  = bind

-- | 'conditioned' feeds the next pending observation to the sampler and
-- passes the remaining ones on; it calls 'error' if no observation is left.
-- 'unconditioned' runs the sampler with 'Unconditioned' and leaves the
-- observation list untouched.
conditioned, unconditioned :: CSampler a -> Measure a
conditioned (CSampler f) =
  Measure (\observations -> case observations of
    cond : conds -> smap (\a -> (a, conds)) (f cond)
    []           -> error "conditioned: no observation left to condition on")
unconditioned (CSampler f) =
  Measure (\conds -> smap (\a -> (a, conds)) (f Unconditioned))

-- | Multiply the weight of the current run by @p@ without consuming any
-- observation.
factor :: Prob -> Measure ()
factor p = Measure (\conds -> deterministic (point ((), conds) p))

-- Our language also includes the usual goodies of a lambda calculus

-- | Variable reference; the identity function.
var :: a -> a
var = id

-- | Literal; the identity function.
lit :: a -> a
lit = id

-- | Lambda abstraction; returns the function unchanged.
lam :: (a -> b) -> (a -> b)
lam f = f

-- | Function application.
app :: (a -> b) -> a -> b
app f x = f x

-- | Fixed point of a function transformer, for defining recursive functions.
fix :: ((a -> b) -> (a -> b)) -> (a -> b)
fix g = f where f = g f

-- | Conditional as an ordinary function.
ifThenElse :: Bool -> a -> a -> a
ifThenElse True  t _ = t
ifThenElse False _ e = e

-- Drivers for testing

-- | Strip the (expected to be empty) list of leftover observations from
-- every outcome of a mixture.  Calls 'error' if some outcome still carries
-- unconsumed observations.
finish :: Mixture (a, [Cond]) -> Mixture a
finish (Mixture m) = Mixture (M.mapKeysMonotonic dropConds m)
  where
    dropConds (a, []) = a
    dropConds (_, _)  = error "finish: some observations were never consumed"

-- | Run the measure @n@ times against the given observations, using the
-- global standard generator, and combine all results with 'mappend'.
-- A non-positive @n@ yields the empty mixture.
sample :: (Ord a) => Int -> Measure a -> [Cond] -> IO (Mixture a)
sample !n measure conds = go n empty
  where once = getStdRandom (unMeasure measure conds)
        -- @k <= 0@ rather than @k == 0@ so that a negative count terminates
        -- immediately instead of counting down past zero.
        go k m | k <= 0    = return m
               | otherwise = once >>= \result -> go (k - 1) $! mappend m (finish result)

-- | Run the measure @n@ times against the given observations, printing the
-- mixture produced by each individual run.
sample_ :: (Ord a, Show a) => Int -> Measure a -> [Cond] -> IO ()
sample_ !n measure conds = replicateM_ n (once >>= pr)
  where once = getStdRandom (unMeasure measure conds)
        pr   = print . finish

-- | Computes @1 / (1 + exp (-x))@.  Note that, despite the name, this is the
-- logistic (inverse-logit) function.
logit :: Floating a => a -> a
logit !x = 1 / (1 + exp (- x))