{-# LANGUAGE RankNTypes, BangPatterns, GADTs #-} {-# OPTIONS -Wall #-} module Language.Hakaru.Distribution where import Control.Monad import Control.Monad.Primitive import Control.Monad.Loops import qualified System.Random.MWC as MWC import Language.Hakaru.Mixture import Language.Hakaru.Types import Data.Ix import Data.Maybe (fromMaybe) import Data.List (findIndex, foldl') import Numeric.SpecFunctions import qualified Data.Map.Strict as M import qualified Data.Number.LogFloat as LF mapFst :: (t -> s) -> (t, u) -> (s, u) mapFst f (a,b) = (f a, b) dirac :: (Eq a) => a -> Dist a dirac theta = Dist {logDensity = (\ (Discrete x) -> if x == theta then 0 else log 0), distSample = (\ _ -> return $ Discrete theta)} bern :: Double -> Dist Bool bern p = Dist {logDensity = (\ (Discrete x) -> log (if x then p else 1 - p)), distSample = (\ g -> do t <- MWC.uniformR (0,1) g return $ Discrete (t <= p))} uniform :: Double -> Double -> Dist Double uniform lo hi = let uniformLogDensity lo' hi' x | lo' <= x && x <= hi' = log (recip (hi' - lo')) uniformLogDensity _ _ _ = log 0 in Dist {logDensity = (\ (Lebesgue x) -> uniformLogDensity lo hi x), distSample = (\ g -> liftM Lebesgue $ MWC.uniformR (lo, hi) g)} uniformD :: (Ix a, MWC.Variate a) => a -> a -> Dist a uniformD lo hi = let uniformLogDensity lo' hi' x | lo' <= x && x <= hi' = log density uniformLogDensity _ _ _ = log 0 density = recip (fromInteger (toInteger (rangeSize (lo,hi)))) in Dist {logDensity = (\ (Discrete x) -> uniformLogDensity lo hi x), distSample = (\ g -> liftM Discrete $ MWC.uniformR (lo, hi) g)} marsaglia :: (MWC.Variate a, Ord a, Floating a, PrimMonad m) => PRNG m -> m (a, a) marsaglia g = do -- "Marsaglia polar method" x <- MWC.uniformR (-1,1) g y <- MWC.uniformR (-1,1) g let s = x * x + y * y q = sqrt ((-2) * log s / s) if 1 >= s && s > 0 then return (x * q, y * q) else marsaglia g choose :: (PrimMonad m) => Mixture k -> PRNG m -> m (k, Prob) choose (Mixture m) g = do let peak = maximum (M.elems m) unMix = M.map (LF.fromLogFloat . (/peak)) m total = M.foldl' (+) (0::Double) unMix p <- MWC.uniformR (0, total) g let f !k !v b !p0 = let p1 = p0 + v in if p <= p1 then k else b p1 err p0 = error ("choose: failure p0=" ++ show p0 ++ " total=" ++ show total ++ " size=" ++ show (M.size m)) return $ (M.foldrWithKey f err unMix 0, LF.logFloat total * peak) chooseIndex :: (PrimMonad m) => [Double] -> PRNG m -> m Int chooseIndex probs g = do p <- MWC.uniform g return $ fromMaybe (error ("chooseIndex: failure p=" ++ show p)) (findIndex (p <=) (scanl1 (+) probs)) normal_rng :: (Real a, Floating a, MWC.Variate a, PrimMonad m) => a -> a -> PRNG m -> m a normal_rng mu sd g | sd > 0 = do (x, _) <- marsaglia g return (mu + sd * x) normal_rng _ _ _ = error "normal: invalid parameters" normalLogDensity :: Floating a => a -> a -> a -> a normalLogDensity mu sd x = (-tau * square (x - mu) + log (tau / pi / 2)) / 2 where square y = y * y tau = 1 / square sd normal :: Double -> Double -> Dist Double normal mu sd = Dist {logDensity = normalLogDensity mu sd . fromLebesgue, distSample = (\g -> liftM Lebesgue $ normal_rng mu sd g)} categoricalLogDensity :: (Eq b, Floating a) => [(b, a)] -> b -> a categoricalLogDensity list x = log $ fromMaybe 0 (lookup x list) categoricalSample :: (Num b, Ord b, PrimMonad m, MWC.Variate b) => [(t,b)] -> PRNG m -> m t categoricalSample list g = do let total = sum $ map snd list p <- MWC.uniformR (0, total) g let sumList = scanl1 (\acc (a, b) -> (a, b + snd(acc))) list elem' = fst $ head $ filter (\(_,p0) -> p <= p0) sumList return elem' categorical :: Eq a => [(a,Double)] -> Dist a categorical list = Dist {logDensity = categoricalLogDensity list . fromDiscrete, distSample = (\g -> liftM Discrete $ categoricalSample list g)} lnFact :: Int -> Double lnFact = logFactorial -- Makes use of Atkinson's algorithm as described in: -- Monte Carlo Statistical Methods pg. 55 -- -- Further discussion at: -- http://www.johndcook.com/blog/2010/06/14/generating-poisson-random-values/ poisson_rng :: (PrimMonad m) => Double -> PRNG m -> m Int poisson_rng lambda g' = make_poisson g' where smu = sqrt lambda b = 0.931 + 2.53*smu a = -0.059 + 0.02483*b vr = 0.9277 - 3.6224/(b - 2) arep = 1.1239 + 1.1368/(b-3.4) lnlam = log lambda make_poisson :: (PrimMonad m) => PRNG m -> m Int make_poisson g = do u <- MWC.uniformR (-0.5,0.5) g v <- MWC.uniformR (0,1) g let us = 0.5 - abs u k = floor $ (2*a / us + b)*u + lambda + 0.43 case () of () | us >= 0.07 && v <= vr -> return k () | k < 0 -> make_poisson g () | us <= 0.013 && v > us -> make_poisson g () | accept_region us v k -> return k _ -> make_poisson g accept_region :: Double -> Double -> Int -> Bool accept_region us v k = log (v * arep / (a/(us*us)+b)) <= -lambda + (fromIntegral k)*lnlam - lnFact k poisson :: Double -> Dist Int poisson l = let poissonLogDensity l' x | l' > 0 && x> 0 = (fromIntegral x)*(log l') - lnFact x - l' poissonLogDensity l' x | x==0 = -l' poissonLogDensity _ _ = log 0 in Dist {logDensity = poissonLogDensity l . fromDiscrete, distSample = (\g -> liftM Discrete $ poisson_rng l g)} -- Direct implementation of "A Simple Method for Generating Gamma Variables" -- by George Marsaglia and Wai Wan Tsang. gamma_rng :: (PrimMonad m) => Double -> Double -> PRNG m -> m Double gamma_rng shape _ _ | shape <= 0.0 = error "gamma: got a negative shape paramater" gamma_rng _ scl _ | scl <= 0.0 = error "gamma: got a negative scale paramater" gamma_rng shape scl g | shape < 1.0 = do gvar1 <- gamma_rng (shape + 1) scl g w <- MWC.uniformR (0,1) g return $ scl * gvar1 * (w ** recip shape) gamma_rng shape scl g = do let d = shape - 1/3 c = recip $ sqrt $ 9*d -- Algorithm recommends inlining normal generator -- n = normal_rng 1 c v <- iterateUntil (> 0.0) $ normal_rng 1 c g -- (v, g2) = until (\y -> fst y > 0.0) (\ (_, g') -> normal_rng 1 c g') (n g) let x = (v - 1) / c sqr = x * x v3 = v * v * v u <- MWC.uniformR (0.0, 1.0) g let accept = u < 1.0 - 0.0331*(sqr*sqr) || log u < 0.5*sqr + d*(1.0 - v3 + log v3) case accept of True -> return $ scl*d*v3 False -> gamma_rng shape scl g gammaLogDensity :: Double -> Double -> Double -> Double gammaLogDensity shape scl x | x>= 0 && shape > 0 && scl > 0 = scl * log shape - scl * x + (shape - 1) * log x - logGamma shape gammaLogDensity _ _ _ = log 0 gamma :: Double -> Double -> Dist Double gamma shape scl = Dist {logDensity = gammaLogDensity shape scl . fromLebesgue, distSample = (\g -> liftM Lebesgue $ gamma_rng shape scl g)} beta_rng :: (PrimMonad m) => Double -> Double -> PRNG m -> m Double beta_rng a b g | a <= 1.0 && b <= 1.0 = do u <- MWC.uniformR (0.0, 1.0) g v <- MWC.uniformR (0.0, 1.0) g let x = u ** (recip a) y = v ** (recip b) case (x+y) <= 1.0 of True -> return $ x / (x + y) False -> beta_rng a b g beta_rng a b g = do ga <- gamma_rng a 1 g gb <- gamma_rng b 1 g return $ ga / (ga + gb) betaLogDensity :: Double -> Double -> Double -> Double betaLogDensity _ _ x | x < 0 || x > 1 = error "beta: value must be between 0 and 1" betaLogDensity a b _ | a <= 0 || b <= 0 = error "beta: parameters must be positve" betaLogDensity a b x = (logGamma (a + b) - logGamma a - logGamma b + (a - 1) * log x + (b - 1) * log (1 - x)) beta :: Double -> Double -> Dist Double beta a b = Dist {logDensity = betaLogDensity a b . fromLebesgue, distSample = (\g -> liftM Lebesgue $ beta_rng a b g)} laplace_rng :: (PrimMonad m) => Double -> Double -> PRNG m -> m Double laplace_rng mu sd g = MWC.uniformR (0.0, 1.0) g >>= sample where sample u = return $ case u < 0.5 of True -> mu + sd * log (u + u) False -> mu - sd * log (2.0 - u - u) laplaceLogDensity :: Floating a => a -> a -> a -> a laplaceLogDensity mu sd x = - log (2 * sd) - abs (x - mu) / sd laplace :: Double -> Double -> Dist Double laplace mu sd = Dist {logDensity = laplaceLogDensity mu sd . fromLebesgue, distSample = (\g -> liftM Lebesgue $ laplace_rng mu sd g)} -- Consider having dirichlet return Vector -- Note: This is actually symmetric dirichlet dirichlet_rng :: (PrimMonad m) => Int -> Double -> PRNG m -> m [Double] dirichlet_rng n' a g' = liftM normalize $ gammas g' n' where gammas _ 0 = return ([], 0) gammas g n = do (xs, total) <- gammas g (n-1) x <- gamma_rng a 1 g return ((x : xs), x+total) normalize (b, total) = map (/ total) b dirichletLogDensity :: [Double] -> [Double] -> Double dirichletLogDensity a x | all (> 0) x = sum' (zipWith logTerm a x) + logGamma (sum a) where sum' = foldl' (+) 0 logTerm b y = (b-1) * log y - logGamma b dirichletLogDensity _ _ = error "dirichlet: all values must be between 0 and 1"