{-# LANGUAGE RankNTypes, NoMonomorphismRestriction, BangPatterns, DeriveDataTypeable, GADTs, ScopedTypeVariables, ExistentialQuantification, StandaloneDeriving #-} {-# OPTIONS -Wall #-} module Language.Hakaru.Metropolis where import qualified System.Random.MWC as MWC import Control.Monad import Control.Monad.Primitive import Data.Dynamic import Data.Maybe import Control.Applicative import qualified Data.Map.Strict as M import Language.Hakaru.Types import System.IO.Unsafe {- Shortcomings of this implementation * uses parent-conditional sampling for proposal distribution * re-evaluates entire program at every sample * lacks way to block sample groups of variables -} type DistVal = Dynamic data XRP where XRP :: Typeable e => (Density e, Dist e) -> XRP unXRP :: Typeable a => XRP -> Maybe (Density a, Dist a) unXRP (XRP (e,f)) = cast (e,f) type Visited = Bool type Observed = Bool type LL = LogLikelihood -- The first component is the LogLikelihood of the trace -- The second is the LogLikelihood of the newly introduced -- choices. These are used to compute the acceptance ratio type LL2 = (LL,LL) type Subloc = Int type Name = [Subloc] data DBEntry = DBEntry { xrp :: XRP, llhd :: LL, vis :: Visited, observed :: Observed } type Database = M.Map Name DBEntry data SamplerState where S :: { ldb :: Database, -- ldb = local database -- (total likelihood, total likelihood of XRPs newly introduced) llh2 :: {-# UNPACK #-} !LL2, cnds :: [Cond] -- conditions left to process } -> SamplerState type Sampler a = PrimMonad m => SamplerState -> PRNG m -> m (a, SamplerState) sreturn :: a -> Sampler a sreturn x s _ = return (x, s) sbind :: Sampler a -> (a -> Sampler b) -> Sampler b sbind s k = \ st g -> do (v, s') <- s st g k v s' g smap :: (a -> b) -> Sampler a -> Sampler b smap f s = sbind s (\a -> sreturn (f a)) newtype Measure a = Measure {unMeasure :: Name -> Sampler a } deriving (Typeable) return_ :: a -> Measure a return_ x = Measure $ \ _ -> sreturn x updateXRP :: Typeable a => Name -> Cond -> Dist a -> Sampler a updateXRP n obs dist' s@(S {ldb = db}) g = do case M.lookup n db of Just (DBEntry xd _ _ ob) -> do let Just (x, _) = unXRP xd l' = logDensity dist' x d1 = M.insert n (DBEntry (XRP (x,dist')) l' True ob) db return (fromDensity x, s {ldb = d1, llh2 = updateLogLikelihood (l',0) (llh2 s)}) Nothing -> do (xnew2, l) <- case obs of Just xdnew -> do let Just xnew = fromDynamic xdnew return $ (xnew, logDensity dist' xnew) Nothing -> do xnew <- distSample dist' g return (xnew, logDensity dist' xnew) let d1 = M.insert n (DBEntry (XRP (xnew2, dist')) l True (isJust obs)) db return (fromDensity xnew2, s {ldb = d1, llh2 = updateLogLikelihood (l,l) (llh2 s)}) updateLogLikelihood :: LL2 -> LL2 -> LL2 updateLogLikelihood (llTotal,llFresh) (l,lf) = (llTotal+l, llFresh+lf) factor :: LL -> Measure () factor l = Measure $ \ _ -> \ s _ -> do let (llTotal, llFresh) = llh2 s return ((), s {llh2 = (llTotal + l, llFresh)}) condition :: Eq b => Measure (a, b) -> b -> Measure a condition (Measure m) b' = Measure $ \ n -> do let comp a b s | a /= b = s {llh2 = (log 0, 0)} comp _ _ s = s sbind (m n) (\ (a, b) s _ -> return (a, comp b b' s)) bind :: Measure a -> (a -> Measure b) -> Measure b bind (Measure m) cont = Measure $ \ n -> sbind (m (0:n)) (\ a -> unMeasure (cont a) (1:n)) conditioned :: Typeable a => Dist a -> Measure a conditioned dist = Measure $ \ n -> \s@(S {cnds = cond:conds }) -> updateXRP n cond dist s{cnds = conds} unconditioned :: Typeable a => Dist a -> Measure a unconditioned dist = Measure $ \ n -> updateXRP n Nothing dist instance Monad Measure where return = return_ (>>=) = bind instance Functor Measure where fmap f (Measure x) = Measure $ \n -> smap f (x n) instance Applicative Measure where pure = return_ (<*>) = app sapp :: (Sampler (a -> b)) -> Sampler a -> Sampler b sapp f s = \st g -> do (vf, s') <- f st g (vs, s'') <- s s' g sreturn (vf vs) s'' g app :: Measure (a -> b) -> Measure a -> Measure b app (Measure f) (Measure a) = Measure $ \n -> sapp (f n) (a n) run :: Measure a -> [Cond] -> IO (a, Database, LL) run (Measure prog) cds = do g <- MWC.create (v, S d ll []) <- (prog [0]) (S M.empty (0,0) cds) g return (v, d, fst ll) traceUpdate :: PrimMonad m => Measure a -> Database -> [Cond] -> PRNG m -> m (a, Database, LL, LL, LL) traceUpdate (Measure prog) d cds g = do -- let d1 = M.map (\ (x, l, _, ob) -> (x, l, False, ob)) d let d1 = M.map (\ s -> s { vis = False }) d (v, S d2 (llTotal, llFresh) []) <- (prog [0]) (S d1 (0,0) cds) g let (d3, stale_d) = M.partition vis d2 let llStale = M.foldl' (\ llStale' s -> llStale' + llhd s) 0 stale_d return (v, d3, llTotal, llFresh, llStale) initialStep :: Measure a -> [Cond] -> PRNG IO -> IO (a, Database, LL, LL, LL) initialStep prog cds g = traceUpdate prog M.empty cds g -- TODO: Make a way of passing user-provided proposal distributions resample :: PrimMonad m => Name -> Database -> Observed -> XRP -> PRNG m -> m (Database, LL, LL, LL) resample name db ob (XRP (x, dist)) g = do x' <- distSample dist g let fwd = logDensity dist x' rvs = logDensity dist x l' = fwd newEntry = DBEntry (XRP (x', dist)) l' True ob db' = M.insert name newEntry db return (db', l', fwd, rvs) transition :: (Typeable a) => Measure a -> [Cond] -> a -> Database -> LL -> PRNG IO -> IO [a] transition prog cds v db ll g = do let dbSize = M.size db -- choose an unconditioned choice (_, uncondDb) = M.partition observed db choice <- MWC.uniformR (0, (M.size uncondDb) -1) g let (name, (DBEntry xd _ _ ob)) = M.elemAt choice uncondDb (db', _, fwd, rvs) <- resample name db ob xd g (v', db2, llTotal, llFresh, llStale) <- traceUpdate prog db' cds g let a = llTotal - ll + rvs - fwd + log (fromIntegral dbSize) - log (fromIntegral $ M.size db2) + llStale - llFresh u <- MWC.uniformR (0 :: Double, 1) g if (log u < a) then liftM ((:) v') $ unsafeInterleaveIO (transition prog cds v' db2 llTotal g) else liftM ((:) v) $ unsafeInterleaveIO (transition prog cds v db ll g) mcmc :: Typeable a => Measure a -> [Cond] -> IO [a] mcmc prog cds = do g <- MWC.create (v, d, llTotal, _, _) <- initialStep prog cds g transition prog cds v d llTotal g sample :: Typeable a => Measure a -> [Cond] -> IO [(a, Double)] sample prog cds = do g <- MWC.create (v, d, llTotal, _, _) <- initialStep prog cds g (transition prog cds v d llTotal g) >>= return . map (\ x -> (x,1))