{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module Yggdrasil.Distribution
( Distribution
, DistributionT(DistributionT, runDistT)
, Sampler(..)
, liftDistribution
, coin
, uniform
) where
import Control.Monad (ap, (>=>))
import Control.Monad.Trans.Class (MonadTrans (lift))
import Crypto.Random (SystemDRG, randomBytesGenerate)
import Data.Bits ((.&.))
import qualified Data.ByteArray as B
import Data.Functor.Identity (Identity (Identity), runIdentity)
import Data.Maybe (fromJust)
-- | A distribution whose base monad is pure: 'DistributionT' specialised to
-- 'Identity'. Kept eta-reduced so it can be used partially applied (as a
-- type constructor of kind @* -> *@).
type Distribution = DistributionT Identity
-- | A computation in base monad @m@ producing a value of type @b@ from
-- randomness supplied by a 'Sampler'. The sampler is threaded through:
-- 'runDistT' takes a sampler and returns the result paired with the
-- sampler's updated state. The field is rank-2 (@forall s.@), so a single
-- distribution can be run with any 'Sampler' instance.
newtype DistributionT m b = DistributionT
  { runDistT :: forall s. Sampler s =>
      s -> m (b, s)
  }
-- | Applies a function to the sampled value, passing the sampler state
-- through untouched. Implemented directly, rather than via '<*>', so that
-- only a 'Functor' constraint on the base monad is required; every
-- @Monad m@ still satisfies this.
instance Functor m => Functor (DistributionT m) where
  fmap f d =
    DistributionT (\s -> fmap (\(b, s') -> (f b, s')) (runDistT d s))
-- | 'pure' yields a value without touching the sampler. '<*>' runs the
-- function-producing distribution first and then the argument, threading
-- the sampler through both (the same sequencing as 'Control.Monad.ap').
instance Monad m => Applicative (DistributionT m) where
  pure x = DistributionT (\s -> pure (x, s))
  df <*> dx = do
    f <- df
    x <- dx
    return (f x)
-- | Sequencing: run the first distribution with the incoming sampler, then
-- run the continuation on its result with the updated sampler.
instance Monad m => Monad (DistributionT m) where
  d >>= k =
    DistributionT
      ( \s -> do
          (x, s') <- runDistT d s
          runDistT (k x) s'
      )
-- | Lifts a base-monad action into a distribution. The sampler is returned
-- unchanged, so the lifted action consumes no randomness.
instance MonadTrans DistributionT where
  lift action = DistributionT (\s -> fmap (\x -> (x, s)) action)
-- | A source of randomness that can be threaded through a 'DistributionT'.
class Sampler s where
  -- | Sample a single coin flip, returning the outcome together with the
  -- updated sampler.
  sampleCoin :: s -> (Bool, s)

  -- | Run a distribution with this sampler, returning the sampled value and
  -- the sampler's final state.
  sample :: s -> DistributionT m b -> m (b, s)
  sample sampler dist = runDistT dist sampler

  -- | Like 'sample', but discards the final sampler state.
  sample' :: Monad m => s -> DistributionT m b -> m b
  sample' sampler dist = fmap fst (sample sampler dist)
-- | Generates one random byte from the DRG and reports whether its lowest
-- bit is set. Every coin flip therefore consumes a full byte.
instance Sampler SystemDRG where
  sampleCoin drg = (lowBitSet, drg')
    where
      (oneByte :: B.Bytes, drg') = randomBytesGenerate 1 drg
      -- Exactly one byte was requested, so 'B.uncons' yields 'Just'.
      firstByte = fst (fromJust (B.uncons oneByte))
      lowBitSet = firstByte .&. 1 == 1
-- | Embeds a pure 'Distribution' into an arbitrary base monad. It runs the
-- distribution with the given sampler and 'return's the result together
-- with the updated sampler.
liftDistribution :: Monad m => Distribution b -> DistributionT m b
liftDistribution d =
  DistributionT (\s -> return (runIdentity (runDistT d s)))
-- | A single coin flip, delegated to the sampler's 'sampleCoin'.
coin :: Distribution Bool
coin = DistributionT (\s -> Identity (sampleCoin s))
-- | The uniform distribution over the elements of a non-empty list.
--
-- With @len@ elements, it draws @bits = ceiling (logBase 2 len)@ coins to
-- build a candidate index in @[0, 2^bits)@. Candidates @>= len@ are
-- rejected and redrawn. Since @2^(bits-1) < len@, each attempt is accepted
-- with probability greater than 1/2.
--
-- Calls 'error' on an empty list, which has no uniform distribution.
uniform :: [a] -> Distribution a
uniform [] = error "uniform: cannot sample from an empty list"
uniform xs = attempt
  where
    len :: Int
    len = length xs

    -- Enough bits to represent every valid index 0 .. len - 1. Using
    -- floor(log2 len) instead would make every index >= 2^floor(log2 len)
    -- unreachable (e.g. the last element of a 3-element list).
    bits :: Int
    bits = bitLength (len - 1)

    -- Rejection sampling. 'len' and 'bits' are computed once and shared
    -- across retries.
    attempt = do
      candidate <- samplen bits
      if candidate < len
        then return (xs !! candidate)
        else attempt

    -- Number of binary digits of a non-negative integer (0 for 0).
    bitLength :: Int -> Int
    bitLength k
      | k <= 0 = 0
      | otherwise = 1 + bitLength (k `div` 2)

    -- An integer in [0, 2^k) built from k coin flips. The first coin drawn
    -- becomes the most significant bit.
    samplen :: Int -> Distribution Int
    samplen k
      | k <= 0 = return 0
      | otherwise = do
          higher <- samplen (k - 1)
          c <- coin
          return (2 * higher + fromEnum c)