```{-# LANGUAGE RankNTypes #-}
module Language.Synthesis.Distribution (
Distr (Distr), sample, logProbability, negativeInfinity, sumByLogs,
categorical, uniform, randInt, replicate, mix, constant
) where

import           Prelude              hiding (replicate)

import           Control.Monad        (replicateM)
import           Control.Monad.Random (Rand, Random, RandomGen, getRandom,
                                       getRandomR)

-- | A discrete probability distribution, packaged as a sampler together
-- with a log-probability function.
data Distr a = Distr
  { sample :: forall g. RandomGen g => Rand g a
    -- ^ Sample a random item from the distribution.
  , logProbability :: a -> Double
    -- ^ Compute the log probability of a given value.
  }

-- | Negative infinity, the log of 0 probability.
--
-- Obtained as the IEEE result of negating @1 / 0@ on 'Double'.
negativeInfinity :: Double
negativeInfinity = negate (1 / 0)

-- | A distribution containing a single item: sampling always yields that
-- item, and the log probability reported for any argument is 0.
constant :: a -> Distr a
constant item = Distr (return item) (\_ -> 0)

-- | Computes @log . sum . map exp@ with more numeric precision.
--
-- The largest input is subtracted from every element before
-- exponentiating and added back after taking the log.
--
-- * An empty list is a sum of zero terms, so the result is negative
--   infinity (the log of 0).
-- * If the largest element is itself infinite, it is returned directly;
--   subtracting it would otherwise produce NaN (@-inf - -inf@).
sumByLogs :: [Double] -> Double
sumByLogs [] = negate (1 / 0)
sumByLogs xs
  | isInfinite high = high
  | otherwise = log (sum [exp (x - high) | x <- xs]) + high
  where
    high = maximum xs

-- | Samples from an (item, weight) list.
--
-- Walks the list front to back. Each item is accepted with probability
-- @weight / remaining@, where @remaining@ is the total weight of the items
-- not yet rejected. If the final item is reached it is returned
-- unconditionally. Calls 'error' on an empty list.
sampleCategorical :: RandomGen g => [(a, Double)] -> Rand g a
sampleCategorical items = go items (sum (map snd items))
  where
    go [] _ = error "Cannot sample from an empty list."
    go [(x, _)] _ = return x
    go ((x, weight) : rest) total = do
      acceptance <- getRandom
      -- Strict comparison so that a zero-weight item is not accepted when
      -- the draw is exactly 0. Assumes 'getRandom' yields a Double in
      -- [0, 1) -- TODO confirm against the Random instance in use.
      if acceptance < weight / total
        then return x
        else go rest (total - weight)

-- | A distribution from an (item, weight) list.
--
-- Weights need not sum to 1; they are normalised by their total. An item
-- that appears several times in the list has the weights of all its
-- entries added together, matching how 'sampleCategorical' can return it
-- from any of those entries. Items absent from the list have log
-- probability 'negativeInfinity'.
categorical :: Eq a => [(a, Double)] -> Distr a
categorical items = Distr (sampleCategorical items) logProb
  where
    total = sum (map snd items)
    logProb item = case [weight | (x, weight) <- items, x == item] of
      []      -> negativeInfinity
      weights -> log (sum weights / total)

-- | Uniform distribution: a 'categorical' distribution in which every
-- element of the list is given weight 1.
uniform :: Eq a => [a] -> Distr a
uniform xs = categorical (zip xs (repeat 1.0))

-- | A distribution over some integral type, inclusively between the 2
-- values.
--
-- Values inside @[low, high]@ each have log probability
-- @-log (high - low + 1)@; values outside have 'negativeInfinity'.
randInt :: (Integral i, Random i) => (i, i) -> Distr i
randInt range@(low, high) = Distr (getRandomR range) logProb
  where
    -- Count the values as an Integer so that @high - low + 1@ cannot
    -- overflow when @i@ is a bounded type.
    size = toInteger high - toInteger low + 1
    logProb item
      | low <= item && item <= high = negate (log (fromInteger size))
      | otherwise = negativeInfinity

-- | Generate n independent draws from a distribution.
--
-- The log probability of a list is the sum of the log probabilities of
-- its elements under the original distribution. A list whose length
-- differs from the number of draws can never be sampled, so it gets
-- 'negativeInfinity'.
replicate :: Int -> Distr a -> Distr [a]
replicate n orig = Distr (replicateM n (sample orig)) logProb
  where
    logProb xs
      -- 'replicateM' yields an empty list for non-positive counts, hence
      -- the clamp to 0.
      | length xs == max 0 n = sum (map (logProbability orig) xs)
      | otherwise = negativeInfinity

-- | Given (distribution, weight) pairs, mix the distributions.
--
-- Sampling first picks a component with 'sampleCategorical' (which
-- normalises the weights by their total) and then samples from it. The
-- log probability of a value is the log of the weighted sum of the
-- component probabilities, using the same normalised weights so that it
-- agrees with the sampler even when the weights do not sum to 1.
mix :: [(Distr a, Double)] -> Distr a
mix distrs = Distr (sampleCategorical distrs >>= \distr -> sample distr) logProb
  where
    total = sum (map snd distrs)
    logProb x =
      sumByLogs
        [ log (weight / total) + logProbability distr x
        | (distr, weight) <- distrs
        ]
```