module Language.Hakaru.Distribution where
import System.Random
import Language.Hakaru.Mixture
import Language.Hakaru.Types
import Data.Ix
import Data.Maybe (fromMaybe)
import Data.List (findIndex, foldl')
import Numeric.SpecFunctions
import qualified Data.Map.Strict as M
import qualified Data.Number.LogFloat as LF
-- | Transform the first component of a pair, passing the second component
-- through unchanged.
mapFst :: (t -> s) -> (t, u) -> (s, u)
mapFst f pair = case pair of
  (a, b) -> (f a, b)
-- | Point mass at @theta@.
--
-- The log-density is @0@ at @theta@ and @log 0@ (negative infinity) at
-- every other value. Sampling always produces @theta@ and hands back the
-- generator it was given, untouched.
dirac :: (Eq a) => a -> Dist a
dirac theta = Dist {logDensity = density,
                    distSample = (\ gen -> (Discrete theta, gen))}
  where
    density (Discrete x)
      | x == theta = 0
      | otherwise = log 0
-- | Bernoulli distribution: 'True' with probability @p@, 'False' with
-- probability @1 - p@.
--
-- Sampling draws @t@ uniformly from @[0, 1]@ and reports @t <= p@.
-- NOTE(review): @p@ is not range-checked; a value outside @[0, 1]@ makes
-- one of the two log-densities the log of a negative number (NaN).
bern :: Double -> Dist Bool
bern p = Dist {logDensity = (\ (Discrete x) -> log (if x then p else 1 - p)),
               distSample = (\ g -> case randomR (0, 1) g of
                                      (t, g') -> (Discrete $ t <= p, g'))}
-- | Continuous uniform distribution on the closed interval @[lo, hi]@.
--
-- The log-density is @log (1 / (hi - lo))@ inside the interval and
-- @log 0@ (negative infinity) outside it. Samples come from
-- @randomR (lo, hi)@.
uniform :: Double -> Double -> Dist Double
uniform lo hi =
  let uniformLogDensity lo' hi' x
        | lo' <= x && x <= hi' = log (recip (hi' - lo'))
      uniformLogDensity _ _ _ = log 0
  in Dist {logDensity = (\ (Lebesgue x) -> uniformLogDensity lo hi x),
           distSample = (\ g -> mapFst Lebesgue $ randomR (lo, hi) g)}
-- | Discrete uniform distribution over the 'Ix' range @(lo, hi)@.
--
-- Every index in the range gets probability @1 / rangeSize (lo, hi)@, and
-- anything outside it gets @log 0@. Membership is tested with 'inRange' so
-- the support agrees with the 'rangeSize' used for normalisation. For
-- tuple indices, the lexicographic @lo <= x && x <= hi@ test would accept
-- points outside the box. For scalar indices the two tests coincide.
uniformD :: (Ix a, Random a) => a -> a -> Dist a
uniformD lo hi =
  let density = recip (fromIntegral (rangeSize (lo, hi)))
      uniformLogDensity x
        | inRange (lo, hi) x = log density
        | otherwise = log 0
  in Dist {logDensity = (\ (Discrete x) -> uniformLogDensity x),
           distSample = (\ g -> mapFst Discrete $ randomR (lo, hi) g)}
-- | Marsaglia's polar method: produces a pair of independent standard
-- normal variates.
--
-- A point @(x, y)@ is drawn uniformly from the square @[-1, 1]^2@ and
-- rejected unless @s = x^2 + y^2@ satisfies @0 < s <= 1@. An accepted point
-- is scaled by @sqrt (-2 * log s / s)@. After a rejection, the method
-- retries with the most recent generator.
marsaglia :: (RandomGen g, Random a, Ord a, Floating a) => g -> ((a, a), g)
marsaglia g0 =
  let (x, g1) = randomR (-1, 1) g0
      (y, g ) = randomR (-1, 1) g1
      s = x * x + y * y
      q = sqrt ((-2) * log s / s)
  in if 1 >= s && s > 0 then ((x * q, y * q), g) else marsaglia g
-- | Draw one key from a 'Mixture' with probability proportional to its
-- weight.
--
-- The algorithm:
--
-- * Weights are divided by the largest weight before being converted to
--   'Double'.
-- * A threshold is drawn uniformly from @[0, total]@.
-- * The first key, in ascending key order, whose running weight reaches
--   the threshold is returned.
--
-- The result also carries the total mass (@total@ rescaled by the peak)
-- and the updated generator.
choose :: (RandomGen g) => Mixture k -> g -> (k, Prob, g)
choose (Mixture m) g0 = (picked, LF.logFloat total * peak, g)
  where
    peak = maximum (M.elems m)
    scaled = M.map (LF.fromLogFloat . (/ peak)) m
    total = foldl' (+) (0 :: Double) (M.elems scaled)
    (p, g) = randomR (0, total) g0
    picked = walk (M.toList scaled) 0
    -- Scan the cumulative weights. Running off the end means no partial
    -- sum reached the threshold.
    walk [] acc = failure acc
    walk ((k, v) : rest) acc =
      let acc' = acc + v
      in if p <= acc' then k else walk rest acc'
    failure p0 = error ("choose: failure p0=" ++ show p0 ++
                        " total=" ++ show total ++
                        " size=" ++ show (M.size m))
-- | Pick an index into a list of probabilities.
--
-- A value @u@ is drawn with 'random', and the result is the first index
-- whose running sum of @probs@ is at least @u@. If no running sum reaches
-- @u@, the index is an 'error'.
chooseIndex :: (RandomGen g) => [Double] -> g -> (Int, g)
chooseIndex probs g0 = (index, g1)
  where
    (u, g1) = random g0
    cumulative = scanl1 (+) probs
    index = case findIndex (u <=) cumulative of
      Just i -> i
      Nothing -> error ("chooseIndex: failure p=" ++ show u)
-- | Sample from a normal distribution with mean @mu@ and standard
-- deviation @sd@.
--
-- Uses the first of the two variates produced by 'marsaglia' and discards
-- the second. Calls 'error' unless @sd > 0@.
normal_rng :: (Real a, Floating a, Random a, RandomGen g) =>
              a -> a -> g -> (a, g)
normal_rng mu sd g
  | sd > 0 = shift (marsaglia g)
  | otherwise = error "normal: invalid parameters"
  where
    shift ((z, _), g') = (mu + sd * z, g')
-- | Log-density at @x@ of the normal distribution with mean @mu@ and
-- standard deviation @sd@.
--
-- It is written in terms of the precision @tau = 1 / sd^2@:
--
-- > (-tau * (x - mu)^2 + log (tau / (2 * pi))) / 2
normalLogDensity :: Floating a => a -> a -> a -> a
normalLogDensity mu sd x = (negate tau * square (x - mu)
                            + log (tau / pi / 2)) / 2
  where square y = y * y
        tau = 1 / square sd
-- | Normal distribution with mean @mu@ and standard deviation @sd@. Values
-- are wrapped in the 'Lebesgue' constructor.
normal :: Double -> Double -> Dist Double
normal mu sd =
  Dist { logDensity = \v -> normalLogDensity mu sd (fromLebesgue v)
       , distSample = \gen -> mapFst Lebesgue (normal_rng mu sd gen)
       }
-- | Log of the weight attached to @x@ in a @(value, weight)@ association
-- list.
--
-- Only the first matching pair counts, and weights are used as-is with no
-- normalisation. Values absent from the list get @log 0@.
categoricalLogDensity :: (Eq b, Floating a) => [(b, a)] -> b -> a
categoricalLogDensity list x = case lookup x list of
  Just w -> log w
  Nothing -> log 0
-- | Sample a value from a @(value, weight)@ association list with
-- probability proportional to its weight.
--
-- A threshold @p@ is drawn uniformly from @[0, total]@. The first value
-- whose running weight total reaches @p@ is returned. If no running total
-- reaches the threshold (e.g. for an empty list), this raises a descriptive
-- 'error' instead of the bare @head@ failure.
categoricalSample :: (Num b, Ord b, RandomGen g, Random b) =>
                     [(t,b)] -> g -> (t, g)
categoricalSample list g = (elem', g1)
  where
    (p, g1) = randomR (0, total) g
    -- Running totals, keeping each value paired with its cumulative weight.
    sumList = scanl1 (\acc (a, b) -> (a, b + snd acc)) list
    total = sum $ map snd list
    elem' = case filter (\(_, p0) -> p <= p0) sumList of
      (e, _) : _ -> e
      [] -> error "categoricalSample: no element reached the sampled threshold"
-- | Categorical distribution over the values in a @(value, weight)@
-- association list, with probability proportional to the paired weights.
-- Values are wrapped in the 'Discrete' constructor.
categorical :: Eq a => [(a,Double)] -> Dist a
categorical list =
  Dist { logDensity = \v -> categoricalLogDensity list (fromDiscrete v)
       , distSample = \gen -> mapFst Discrete (categoricalSample list gen)
       }
-- | Natural logarithm of @n!@, delegating to 'logFactorial'.
lnFact :: Integer -> Double
lnFact n = logFactorial n
-- | Sample from a Poisson distribution with mean @lambda@.
--
-- Which method is used depends on @lambda@:
--
-- * @lambda >= 10@: Hörmann's PTRS transformed-rejection algorithm.
-- * @lambda < 10@: Knuth's multiplication method. PTRS is only designed
--   for @lambda >= 10@. Below that its constants degenerate; for example,
--   when @b < 3.4@, @arep@ becomes negative, the acceptance test takes the
--   log of a negative number, and the rejection loop can never terminate.
--
-- Calls 'error' if @lambda@ is negative or NaN.
poisson_rng :: (RandomGen g) => Double -> g -> (Integer, g)
poisson_rng lambda g0
  | not (lambda >= 0) = error "poisson: invalid parameter"
  | lambda < 10 = knuth 0 1 g0
  | otherwise = make_poisson g0
  where
    -- Knuth: keep multiplying uniforms until the running product drops to
    -- exp (-lambda); the number of factors before that is the sample.
    limit = exp (negate lambda)
    knuth :: (RandomGen g) => Integer -> Double -> g -> (Integer, g)
    knuth k prod g =
      let (u, g1) = randomR (0, 1) g
          prod' = prod * u
      in if prod' > limit
           then let k' = k + 1 in k' `seq` knuth k' prod' g1
           else (k, g1)
    -- PTRS constants (Hörmann 1993). The published inv-alpha numerator is
    -- 1.1328.
    smu = sqrt lambda
    b = 0.931 + 2.53 * smu
    a = -0.059 + 0.02483 * b
    vr = 0.9277 - 3.6224 / (b - 2)
    arep = 1.1239 + 1.1328 / (b - 3.4)
    lnlam = log lambda
    make_poisson :: (RandomGen g) => g -> (Integer, g)
    make_poisson g =
      let (u, g1) = randomR (-0.5, 0.5) g
          (v, g2) = randomR (0, 1) g1
          us = 0.5 - abs u
          k = floor ((2 * a / us + b) * u + lambda + 0.43)
      in if us >= 0.07 && v <= vr then (k, g2)            -- fast accept
         else if k < 0 || (us <= 0.013 && v > us)
           then make_poisson g2                           -- fast reject
         else if accept_region us v k then (k, g2)
         else make_poisson g2
    -- Full acceptance test: log(v * arep / (a/us^2 + b)) <= log P(k).
    accept_region :: Double -> Double -> Integer -> Bool
    accept_region us v k = log (v * arep / (a / (us * us) + b)) <=
                           negate lambda + fromIntegral k * lnlam - lnFact k
-- | Poisson distribution with mean @l@ over the non-negative integers.
--
-- The log-probability of @x > 0@ is @x * log l - log (x!) - l@. For
-- @x = 0@ it is @-l@, so @l = 0@ puts all mass on 0. Negative @x@, and
-- every @x@ when @l@ is negative or NaN, get @log 0@.
poisson :: Double -> Dist Integer
poisson l =
  let poissonLogDensity l' x
        | l' > 0 && x > 0 = fromIntegral x * log l' - lnFact x - l'
        | l' >= 0 && x == 0 = negate l'
        | otherwise = log 0
  in Dist {logDensity = poissonLogDensity l . fromDiscrete,
           distSample = mapFst Discrete . poisson_rng l}
-- | Sample from a gamma distribution with the given @shape@ and scale
-- @scl@. The final draw is multiplied by @scl@, so the mean is
-- @shape * scl@.
--
-- For @shape >= 1@ this is Marsaglia and Tsang's method:
--
-- * Set @d = shape - 1/3@ and @c = 1 / sqrt (9 d)@.
-- * Draw @v ~ N(1, c)@ until @v > 0@.
-- * Accept @d * v^3@ using the squeeze test and then the log test below.
--
-- For @shape < 1@ it uses the boost @Gamma(k) = Gamma(k + 1) * U^(1/k)@.
-- Calls 'error' unless both parameters are strictly positive.
gamma_rng :: (RandomGen g) => Double -> Double -> g -> (Double, g)
gamma_rng shape _ _ | shape <= 0.0 = error "gamma: got a negative shape paramater"
gamma_rng _ scl _ | scl <= 0.0 = error "gamma: got a negative scale paramater"
gamma_rng shape scl g | shape < 1.0 = (gvar2, g2)
  where -- Draw at unit scale and apply scl exactly once below (scaling
        -- both the inner draw and the result would apply scl twice).
        (gvar1, g1) = gamma_rng (shape + 1) 1 g
        (w, g2) = randomR (0, 1) g1
        gvar2 = scl * gvar1 * (w ** recip shape)
gamma_rng shape scl g =
  let d = shape - 1/3
      c = recip $ sqrt $ 9 * d
      n = normal_rng 1 c
      -- v = 1 + c * z for a standard normal z; redraw until v > 0.
      (v, g2) = until (\y -> fst y > 0.0) (\ (_, g') -> normal_rng 1 c g') (n g)
      x = (v - 1) / c
      sqr = x * x
      v3 = v * v * v
      (u, g3) = randomR (0.0, 1.0) g2
      accept = u < 1.0 - 0.0331 * (sqr * sqr)
               || log u < 0.5 * sqr + d * (1.0 - v3 + log v3)
  in if accept then (scl * d * v3, g3) else gamma_rng shape scl g3
-- | Log-density of the gamma distribution with shape @k@ and scale @θ@.
-- This matches 'gamma_rng', which multiplies a unit-scale draw by @scl@.
--
-- > (k - 1) * log x - x / θ - k * log θ - logGamma k
--
-- Negative @x@, or non-positive (or NaN) parameters, get @log 0@.
gammaLogDensity :: Double -> Double -> Double -> Double
gammaLogDensity shape scl x
  | not (x >= 0 && shape > 0 && scl > 0) = log 0
  | otherwise = powerTerm - x / scl - shape * log scl - logGamma shape
  where
    -- 0 * log 0 would be NaN at x = 0. The x^(k-1) factor is exactly 1
    -- when k == 1.
    powerTerm
      | shape == 1 = 0
      | otherwise = (shape - 1) * log x
-- | Gamma distribution with the given shape and scale. Values are wrapped
-- in the 'Lebesgue' constructor.
gamma :: Double -> Double -> Dist Double
gamma shape scl =
  Dist { logDensity = \v -> gammaLogDensity shape scl (fromLebesgue v)
       , distSample = \gen -> mapFst Lebesgue (gamma_rng shape scl gen)
       }
-- | Sample from a beta distribution with shape parameters @a@ and @b@.
--
-- The method depends on the parameters:
--
-- * Both at most 1: Jöhnk's rejection method. Draw @x = u^(1/a)@ and
--   @y = v^(1/b)@, and accept @x / (x + y)@ once @x + y <= 1@.
-- * Otherwise: normalise two unit-scale gamma draws.
beta_rng :: (RandomGen g) => Double -> Double -> g -> (Double, g)
beta_rng a b g
  | a <= 1.0 && b <= 1.0 = johnk g
  | otherwise = viaGamma g
  where
    johnk g0 =
      let (u, g1) = randomR (0.0, 1.0) g0
          (v, g2) = randomR (0.0, 1.0) g1
          x = u ** recip a
          y = v ** recip b
      in if x + y <= 1.0 then (x / (x + y), g2) else beta_rng a b g2
    viaGamma g0 =
      let (ga, g1) = gamma_rng a 1 g0
          (gb, g2) = gamma_rng b 1 g1
      in (ga / (ga + gb), g2)
-- | Log-density of the beta distribution with shape parameters @a@ and @b@:
--
-- > logGamma (a + b) - logGamma a - logGamma b
-- >   + (a - 1) * log x + (b - 1) * log (1 - x)
--
-- Points outside @[0, 1]@ get @log 0@, like the uniform, gamma and Poisson
-- log-densities. Non-positive parameters are an 'error'.
betaLogDensity :: Double -> Double -> Double -> Double
betaLogDensity a b _ | a <= 0 || b <= 0 = error "beta: parameters must be positve"
betaLogDensity _ _ x | x < 0 || x > 1 = log 0
betaLogDensity a b x = logGamma (a + b)
                     - logGamma a
                     - logGamma b
                     + term (a - 1) x
                     + term (b - 1) (1 - x)
  where
    -- c * log y, taking 0 * log 0 as 0 (the y^0 factor is exactly 1).
    term c y
      | c == 0 = 0
      | otherwise = c * log y
-- | Beta distribution with shape parameters @a@ and @b@. Values are wrapped
-- in the 'Lebesgue' constructor.
beta :: Double -> Double -> Dist Double
beta a b =
  Dist { logDensity = \v -> betaLogDensity a b (fromLebesgue v)
       , distSample = \gen -> mapFst Lebesgue (beta_rng a b gen)
       }
-- | Sample from a Laplace distribution with location @mu@ and scale @sd@.
-- Despite its name, @sd@ is the scale parameter, as in
-- 'laplaceLogDensity'.
--
-- Sampling inverts the CDF using @u@ drawn from @[0, 1]@:
--
-- * @u < 1/2@: @mu + sd * log (2u)@
-- * otherwise: @mu - sd * log (2 - 2u)@
laplace_rng :: (RandomGen g) => Double -> Double -> g -> (Double, g)
laplace_rng mu sd g = sample (randomR (0.0, 1.0) g)
  where sample (u, g1)
          | u < 0.5 = (mu + sd * log (u + u), g1)
          | otherwise = (mu - sd * log (2.0 - u - u), g1)
-- | Log-density at @x@ of the Laplace distribution with location @mu@ and
-- scale @sd@:
--
-- > -log (2 * sd) - |x - mu| / sd
laplaceLogDensity :: Floating a => a -> a -> a -> a
laplaceLogDensity mu sd x = negate (log (2 * sd)) - abs (x - mu) / sd
-- | Laplace distribution with location @mu@ and scale @sd@. Values are
-- wrapped in the 'Lebesgue' constructor.
laplace :: Double -> Double -> Dist Double
laplace mu sd =
  Dist { logDensity = \v -> laplaceLogDensity mu sd (fromLebesgue v)
       , distSample = \gen -> mapFst Lebesgue (laplace_rng mu sd gen)
       }
-- | Sample a symmetric Dirichlet vector of dimension @n'@, with every
-- concentration parameter equal to @a@.
--
-- It draws @n'@ unit-scale gamma variates and divides each by their sum.
-- A non-positive dimension yields an empty list.
dirichlet_rng :: (RandomGen g) => Int -> Double -> g -> ([Double], g)
dirichlet_rng n' a g' = go n' [] 0 g'
  where
    -- Strict loop: each draw is consed onto the front (so the last draw
    -- comes first) and added to the running total before recursing.
    go n xs total g
      | n <= 0 = (map (/ total) xs, g)
      | otherwise =
          let (x, g1) = gamma_rng a 1 g
              total' = x + total
          in total' `seq` go (n - 1) (x : xs) total' g1
-- | Log-density at the point @x@ of a Dirichlet distribution with
-- concentration parameters @a@:
--
-- > logGamma (sum a) + sum [ (a_i - 1) * log x_i - logGamma a_i ]
--
-- Calls 'error' if any component of @x@ is non-positive.
-- NOTE(review): @a@ and @x@ are paired with 'zipWith', so extra elements of
-- the longer list are silently dropped from the per-component sum, while
-- @sum a@ still uses every element of @a@. Callers presumably pass
-- equal-length lists; confirm.
dirichletLogDensity :: [Double] -> [Double] -> Double
dirichletLogDensity a x
  | all (> 0) x = foldl' (+) 0 (zipWith logTerm a x) + logGamma (sum a)
  | otherwise = error "dirichlet: all values must be between 0 and 1"
  where logTerm b y = (b - 1) * log y - logGamma b