{-# LANGUAGE RankNTypes, NoMonomorphismRestriction, BangPatterns,
             DeriveDataTypeable, GADTs, ScopedTypeVariables,
             ExistentialQuantification, StandaloneDeriving #-}

-- | A single-site Metropolis-Hastings sampler over a database of named
-- random choices.
module Language.Hakaru.Metropolis where

import System.Random (RandomGen, StdGen, randomR, getStdGen)
import System.IO
import Control.Applicative (Applicative(..))
import Control.Monad
import Data.Dynamic
import Data.Function (on)
import Data.Maybe
import qualified Data.Map.Strict as M

import RandomChoice
import Visual

{-

Shortcomings of this implementation

* uses parent-conditional sampling for proposal distribution
* re-evaluates entire program at every sample
* lacks way to block sample groups of variables

-}

-- | An observed value, boxed so that observations of different types can
-- share one list.
type DistVal = Dynamic

-- | A distribution: a log-density together with a sampler that works for
-- any random generator.
data Dist a = Dist
  { logDensity :: a -> Likelihood
  , sample     :: forall g. RandomGen g => g -> (a, g)
  }
deriving instance Typeable1 Dist

-- | A value paired with the distribution it is stored with, with the value
-- type hidden.
data XRP = forall e. Typeable e => XRP (e, Dist e)

-- | Recover the typed contents of an 'XRP'; 'Nothing' when the requested
-- type does not match the stored one.
unXRP :: Typeable a => XRP -> Maybe (a, Dist a)
unXRP (XRP (e,f)) = cast (e,f)

type Var a = Int
type Likelihood = Double
type Visited = Bool
type Observed = Bool
type Cond = Maybe DistVal

type Subloc = Int
type Name = [Subloc]

-- | Every random choice recorded so far, keyed by its structural name.
type Database = M.Map Name (XRP, Likelihood, Visited, Observed)

newtype Measure a = Measure
  { unMeasure :: (RandomGen g)
              => (Name, Database, (Likelihood, Likelihood), [Cond], g)
              -> (a, Database, (Likelihood, Likelihood), [Cond], g)
  } deriving (Typeable)
-- n  is structural_name
-- d  is database
-- ll is likelihood of expression
-- conds is the observed data
-- g  is the random seed

lit :: (Eq a, Typeable a) => a -> a
lit = id

-- | Yield a value without touching the database, likelihoods,
-- observations or generator.
return_ :: a -> Measure a
return_ x = Measure (\ (_, d, l, conds, g) -> (x, d, l, conds, g))

-- | Look up (or create) the random choice called @n@.
--
-- * Entry exists: its value is reused (or replaced by the observation when
--   one is supplied), re-scored under the current distribution @dist'@ and
--   marked visited. Nothing is reported as fresh.
-- * Entry missing: the value comes from the observation if there is one,
--   otherwise it is drawn from @dist'@, and a new entry is inserted.
--
-- The result is @(value, database, log-density, fresh log-density, g)@.
-- Only values that were actually drawn count as fresh: an observed value
-- was not proposed, so it must not enter the proposal correction.
makeXRP :: (Typeable a, RandomGen g) => Cond -> Dist a
        -> Name -> Database -> g
        -> (a, Database, Likelihood, Likelihood, g)
makeXRP obs dist' n db g =
  case M.lookup n db of
    Just (xd, _, _, ob) ->
      let (xb, dist) =
            fromMaybe (error "makeXRP: stored value has an unexpected type")
                      (unXRP xd)
          x  = maybe xb observedValue obs
          l' = logDensity dist' x
          -- The entry keeps the distribution it was stored with.
          d1 = M.insert n (XRP (x, dist), l', True, ob) db
      in (x, d1, l', 0, g)
    Nothing ->
      let (xnew, g1) = case obs of
                         Just yd -> (observedValue yd, g)
                         Nothing -> sample dist' g
          l      = logDensity dist' xnew
          lFresh = if isJust obs then 0 else l
          d1     = M.insert n (XRP (xnew, dist'), l, True, isJust obs) db
      in (xnew, d1, l, lFresh, g1)
  where
    observedValue yd =
      fromMaybe (error "makeXRP: observation has an unexpected type")
                (fromDynamic yd)

-- | Fold the result of 'makeXRP' into the running (total, fresh)
-- log-likelihood pair.
updateLikelihood :: (Typeable a, RandomGen g) =>
                    Likelihood -> Likelihood ->
                    (a, Database, Likelihood, Likelihood, g) ->
                    [Cond] ->
                    (a, Database, (Likelihood, Likelihood), [Cond], g)
updateLikelihood llTotal llFresh (x,d,l,lf,g) conds =
  (x, d, (llTotal+l, llFresh+lf), conds, g)

-- | Point mass at @theta@.
dirac :: (Eq a, Typeable a) => a -> Cond -> Measure a
dirac theta obs = Measure $ \(n, d, (llTotal,llFresh), conds, g) ->
  let dist' = Dist { logDensity = \x -> if x == theta then 0 else log 0
                   , sample     = \g0 -> (theta, g0) }
      xrp = makeXRP obs dist' n d g
  in updateLikelihood llTotal llFresh xrp conds

-- | Boolean that is 'True' with probability @p@.
bern :: Double -> Cond -> Measure Bool
bern p obs = Measure $ \(n, d, (llTotal, llFresh), conds, g) ->
  let dist' = Dist { logDensity = \x -> log (if x then p else 1 - p)
                   , sample     = \g0 -> case randomR (0, 1) g0 of
                                           (t, g1) -> (t <= p, g1) }
      xrp = makeXRP obs dist' n d g
  in updateLikelihood llTotal llFresh xrp conds

-- | Poisson distribution with rate @l@.
poisson :: Double -> Cond -> Measure Int
poisson l obs = Measure $ \(n, d, (llTotal, llFresh), conds, g) ->
  let poissonLogDensity rate k
        | rate > 0 && k > 0 = (fromIntegral k)*(log rate) - lnFact k - rate
      poissonLogDensity rate k
        | k == 0 = -rate
      poissonLogDensity _ _ = log 0
      dist' = Dist { logDensity = poissonLogDensity l
                   , sample     = poisson_rng l }
      xrp = makeXRP obs dist' n d g
  in updateLikelihood llTotal llFresh xrp conds

-- | Gamma distribution; density and sampler come from "RandomChoice".
gamma :: Double -> Double -> Cond -> Measure Double
gamma shape scale obs = Measure $ \(n, d, (llTotal, llFresh), conds, g) ->
  let dist' = Dist { logDensity = gammaLogDensity shape scale
                   , sample     = gamma_rng shape scale }
      xrp = makeXRP obs dist' n d g
  in updateLikelihood llTotal llFresh xrp conds

-- | Beta distribution; density and sampler come from "RandomChoice".
beta :: Double -> Double -> Cond -> Measure Double
beta a b obs = Measure $ \(n, d, (llTotal, llFresh), conds, g) ->
  let dist' = Dist { logDensity = betaLogDensity a b
                   , sample     = beta_rng a b }
      xrp = makeXRP obs dist' n d g
  in updateLikelihood llTotal llFresh xrp conds

-- | Uniform distribution on the closed interval @[lo, hi]@.
uniform :: Double -> Double -> Cond -> Measure Double
uniform lo hi obs = Measure $ \(n, d, (llTotal,llFresh), conds, g) ->
  let uniformLogDensity a b x
        | a <= x && x <= b = log (recip (b - a))
      uniformLogDensity _ _ _ = log 0
      dist' = Dist { logDensity = uniformLogDensity lo hi
                   , sample     = \g0 -> randomR (lo, hi) g0 }
      xrp = makeXRP obs dist' n d g
  in updateLikelihood llTotal llFresh xrp conds

-- | Normal distribution; density and sampler come from "RandomChoice".
normal :: Double -> Double -> Cond -> Measure Double
normal mu sd obs = Measure $ \(n, d, (llTotal, llFresh), conds, g) ->
  let dist' = Dist { logDensity = normalLogDensity mu sd
                   , sample     = normal_rng mu sd }
      xrp = makeXRP obs dist' n d g
  in updateLikelihood llTotal llFresh xrp conds

-- | Laplace distribution; density and sampler come from "RandomChoice".
laplace :: Double -> Double -> Cond -> Measure Double
laplace mu sd obs = Measure $ \(n, d, (llTotal, llFresh), conds, g) ->
  let dist' = Dist { logDensity = laplaceLogDensity mu sd
                   , sample     = laplace_rng mu sd }
      xrp = makeXRP obs dist' n d g
  in updateLikelihood llTotal llFresh xrp conds

-- | Discrete distribution over the listed values with the given
-- (not necessarily normalised) weights.
--
-- The sampler draws in proportion to @weight / total@, so the log-density
-- is normalised by the same total to stay consistent with it.
categorical :: (Eq a, Typeable a) => [(a,Double)] -> Cond -> Measure a
categorical list obs = Measure $ \(n, d, (llTotal, llFresh), conds, g) ->
  -- The two helpers take the list as an argument so that they stay closed
  -- and the sampler remains polymorphic in the generator type.
  let categoricalLogDensity items x
        | total > 0 = log (fromMaybe 0 (lookup x items) / total)
        | otherwise = log 0
        where total = sum (map snd items)
      categoricalSample items g0 = (chosen, g1)
        where
          total      = sum (map snd items)
          (p, g1)    = randomR (0, total) g0
          cumulative = scanl1 (\(_, acc) (a, w) -> (a, acc + w)) items
          chosen     =
            case dropWhile (\(_, c) -> not (p <= c)) cumulative of
              (a, _) : _ -> a
              -- Round-off can leave p just above the last partial sum;
              -- fall back to the final item instead of crashing.
              [] -> case items of
                      [] -> error "categorical: cannot sample from an empty list"
                      _  -> fst (last items)
      dist' = Dist { logDensity = categoricalLogDensity list
                   , sample     = categoricalSample list }
      xrp = makeXRP obs dist' n d g
  in updateLikelihood llTotal llFresh xrp conds

-- | Add @l@ to the total log-likelihood; the fresh component is untouched.
factor :: Likelihood -> Measure ()
factor l = Measure $ \(_, d, (llTotal, llFresh), conds, g) ->
  ((), d, (llTotal + l, llFresh), conds, g)

-- | Draw a replacement value from the distribution stored with an entry.
--
-- Returns @(new entry, new log-density, forward, reverse, g)@ where
-- forward is the log-density of the new value and reverse that of the
-- old value, both under the stored distribution.
resample :: RandomGen g => XRP -> g ->
            (XRP, Likelihood, Likelihood, Likelihood, g)
resample (XRP (x, dist)) g =
  let (x', g1) = sample dist g
      fwd = logDensity dist x'
      rvs = logDensity dist x
  in (XRP (x', dist), fwd, fwd, rvs, g1)

-- | Sequence two measures. The first runs under name @0:n@ and the
-- continuation under @1:n@, so their random choices get distinct names.
bind :: Measure a -> (a -> Measure b) -> Measure b
bind (Measure m) cont = Measure $ \ (n,d,ll,conds,g) ->
  let (v, d1, ll1, conds1, g1) = m (0:n, d, ll, conds, g)
  in unMeasure (cont v) (1:n, d1, ll1, conds1, g1)

-- | Feed the next pending observation to @f@ and drop it from the list.
-- Fails with a descriptive error when no observations remain.
conditioned :: (Cond -> Measure a) -> Measure a
conditioned f = Measure $ \ (n,d,ll,conds,g) ->
  case conds of
    cond : rest -> unMeasure (f cond) (n, d, ll, rest, g)
    []          -> error "conditioned: no observations left to condition on"

-- | Run @f@ without an observation.
unconditioned :: (Cond -> Measure a) -> Measure a
unconditioned f = f Nothing

-- Functor and Applicative are superclasses of Monad from GHC 7.10 on;
-- both are derived from the monadic operations.
instance Functor Measure where
  fmap = liftM

instance Applicative Measure where
  pure  = return_
  (<*>) = ap

instance Monad Measure where
  return = return_
  (>>=)  = bind

lam :: (a -> b) -> (a -> b)
lam f = f

app :: (a -> b) -> a -> b
app f x = f x

fix :: ((a -> b) -> (a -> b)) -> (a -> b)
fix g = f where f = g f

ifThenElse :: Bool -> a -> a -> a
ifThenElse True  t _ = t
ifThenElse False _ f = f

-- | Execute a program once from an empty database using the global
-- generator; returns the value, the database and the total log-likelihood.
run :: Measure a -> [Cond] -> IO (a, Database, Likelihood)
run (Measure prog) conds = do
  g <- getStdGen
  let (v, d, ll, _, _) = prog ([0], M.empty, (0,0), conds, g)
  return (v, d, fst ll)

-- | Re-run a program against an existing database.
--
-- All entries are first marked unvisited; after the run, entries that were
-- not visited are dropped as stale. Returns
-- @(value, pruned database, total, fresh, stale, g)@. Only unobserved
-- stale entries contribute to the stale log-likelihood, mirroring how
-- only sampled values count as fresh in 'makeXRP'.
traceUpdate :: RandomGen g => Measure a -> Database -> [Cond] -> g
            -> (a, Database, Likelihood, Likelihood, Likelihood, g)
traceUpdate (Measure prog) d conds g =
  let d1 = M.map (\ (x, l, _, ob) -> (x, l, False, ob)) d
      (v, d2, (llTotal, llFresh), _, g1) = prog ([0], d1, (0,0), conds, g)
      (d3, staleD) = M.partition (\ (_, _, visited, _) -> visited) d2
      llStale = M.foldl' (\ acc (_, l, _, ob) -> if ob then acc else acc + l)
                         0 staleD
  in (v, d3, llTotal, llFresh, llStale, g1)

-- | First trace of a program, started from an empty database with the
-- global generator.
initialStep :: Measure a -> [Cond] ->
               IO (a, Database,
                   Likelihood, Likelihood, Likelihood, StdGen)
initialStep prog conds = do
  g <- getStdGen
  return $ traceUpdate prog M.empty conds g

-- TODO: Make a way of passing user-provided proposal distributions
-- | Resample the entry @name@ and store the new value in the database.
--
-- The generator returned is the one left over after resampling, so later
-- draws do not replay the random numbers used for the proposal.
updateDB :: (RandomGen g) =>
            Name -> Database -> Observed -> XRP -> g
         -> (Database, Likelihood, Likelihood, Likelihood, g)
updateDB name db ob xd g = (db', l', fwd, rvs, g1)
  where
    db' = M.insert name (x', l', True, ob) db
    (x', l', fwd, rvs, g1) = resample xd g
-- | Metropolis-Hastings chain as a lazy, infinite list of samples.
--
-- Each step picks one unconditioned entry uniformly at random, resamples
-- it with 'updateDB', re-runs the program with 'traceUpdate' and then
-- accepts or rejects the new trace.
--
-- Arguments: the program, its observations, the current value, the
-- current database, the current total log-likelihood and a generator.
--
-- The site-selection term of the acceptance ratio counts unconditioned
-- entries only, because those are the only ones the proposal can pick.
-- When the database holds no unconditioned entry there is nothing to
-- resample, so the chain simply repeats the current value.
transition :: (Typeable a, RandomGen g) => Measure a -> [Cond]
           -> a -> Database -> Likelihood -> g -> [a]
transition prog conds v db ll g
  | M.null uncondDb = repeat v
  | otherwise =
      let -- choose an unconditioned choice
          (choice, g1) = randomR (0, M.size uncondDb - 1) g
          (name, (xd, _, _, ob)) = M.elemAt choice uncondDb
          (db', _, fwd, rvs, g2) = updateDB name db ob xd g1
          (v', db2, llTotal, llFresh, llStale, g3) =
            traceUpdate prog db' conds g2
          a = llTotal - ll
              + rvs - fwd
              + log (fromIntegral (M.size uncondDb))
              - log (fromIntegral (latentCount db2))
              + llStale - llFresh
          (u, g4) = randomR (0 :: Double, 1) g3
      in if log u < a
           then v' : transition prog conds v' db2 llTotal g4
           else v  : transition prog conds v  db  ll      g4
  where
    isLatent (_, _, _, ob) = not ob
    uncondDb = M.filter isLatent db
    latentCount m = M.size (M.filter isLatent m)

-- | Run the program once to obtain an initial trace, then return the
-- lazy stream of samples produced by 'transition'.
mcmc :: Typeable a => Measure a -> [Cond] -> IO [a]
mcmc prog conds = do
  (v, d, llTotal, _, _, g) <- initialStep prog conds
  return $ transition prog conds v d llTotal g