module Language.Synthesis.MCMC (mhList) where

import Control.Monad.Random (Rand, RandomGen, getRandom, getSplit, runRand)
import Data.Functor ((<$>))

import Language.Synthesis.Distribution (Distr)
import qualified Language.Synthesis.Distribution as Distr

-- The chain state is a triple (value, aux, density).
-- A density function maps a value to its auxiliary result paired with its
-- density. Densities are added to log-probabilities and compared against the
-- log of a random draw, so they are treated as log-densities here.

-- |Use the Metropolis-Hastings algorithm to sample a list of values.
--
-- The first element of the result is the initial state; every later element
-- is obtained from its predecessor by one propose/accept-or-reject step.
-- The list is produced lazily and has no terminating case, so callers are
-- expected to consume only as much of it as they need.
mhList :: RandomGen g
       => a                        -- ^The initial value.
       -> (a -> (b, Double))       -- ^Density function.
       -> (a -> Distr a)           -- ^Jumping distribution.
       -> Rand g [(a, b, Double)]  -- ^List of (value, aux, density).
mhList startValue density jump = chain initial <$> getSplit
  where
    -- Starting state: the initial value together with its aux and density.
    initial =
        let (aux, dens) = density startValue
        in (startValue, aux, dens)

    -- Unfold the chain, threading the split-off generator through each step.
    chain current gen = current : chain following gen'
      where
        (following, gen') = runRand (step current) gen

    -- One transition: draw a candidate from the jumping distribution, then
    -- keep it or stay at the current state.
    step current@(x, _, xDensity) = do
        candidate <- Distr.sample (jump x)
        let (candAux, candDensity) = density candidate
            forward  = Distr.logProbability (jump x) candidate
            backward = Distr.logProbability (jump candidate) x
            -- Log acceptance ratio: density change plus the correction for
            -- an asymmetric jumping distribution.
            logRatio = candDensity - xDensity + backward - forward
            proposed = (candidate, candAux, candDensity)
            choose u
                | logRatio >= log u = proposed
                | otherwise         = current
        choose <$> getRandom