{-# LANGUAGE Rank2Types          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections       #-}

-- | Provides primitives for high-level cryptographic sampling.
module Yggdrasil.Distribution
  ( Distribution
  , DistributionT(DistributionT, runDistT)
  , Sampler(..)
  , liftDistribution
  , coin
  , uniform
  ) where

import           Control.Monad             (ap, (>=>))
import           Control.Monad.Trans.Class (MonadTrans (lift))
import           Crypto.Random             (SystemDRG, randomBytesGenerate)
import           Data.Bits                 ((.&.))
import qualified Data.ByteArray            as B
import           Data.Functor.Identity     (Identity (Identity), runIdentity)
import           Data.Maybe                (fromJust)

-- | Allows randomly sampling elements of type @b@.
type Distribution = DistributionT Identity

-- | Allows randomly sampling elements of type @b@ in the context of monad @m@.
--
-- A distribution is a function that takes any 'Sampler' and returns, inside
-- @m@, the sampled value together with the updated sampler.
newtype DistributionT m b = DistributionT
  { runDistT :: forall s. Sampler s => s -> m (b, s)
  }

instance Monad m => Functor (DistributionT m) where
  fmap f x = pure f <*> x

instance Monad m => Applicative (DistributionT m) where
  -- Yields the value and passes the sampler through unchanged.
  pure x = DistributionT $ pure . (x, )
  (<*>) = ap

instance Monad m => Monad (DistributionT m) where
  -- Runs @a@, then feeds its result and the sampler it returned into @b@.
  a >>= b = DistributionT $ runDistT a >=> (\(a', s') -> runDistT (b a') s')

instance MonadTrans DistributionT where
  -- Runs the underlying action and leaves the sampler untouched.
  lift m = DistributionT $ \s -> (, s) <$> m

-- | Provides randomness.
class Sampler s where
  -- | Produce a bit of randomness.
  sampleCoin :: s -> (Bool, s)
  -- | Samples a distribution.
  sample :: s -> DistributionT m b -> m (b, s)
  sample s d = runDistT d s
  -- | Samples a distribution, discarding the result randomness.
  sample' :: Monad m => s -> DistributionT m b -> m b
  sample' s d = fst <$> sample s d

instance Sampler SystemDRG where
  -- Draws one byte from the generator and uses its lowest bit as the coin.
  sampleCoin s = (b .&. 1 == 1, s')
    where
      (ba :: B.Bytes, s') = randomBytesGenerate 1 s
      -- fromJust is safe, as the array is not empty.
      (b, _) = fromJust $ B.uncons ba

-- | Lifts a 'Distribution' to an arbitrary monadic 'DistributionT'.
liftDistribution :: Monad m => Distribution b -> DistributionT m b
liftDistribution d = DistributionT $ return . runIdentity . runDistT d

-- | Tosses a fair coin.
coin :: Distribution Bool
coin = DistributionT (Identity . sampleCoin)

-- | A uniform 'Distribution' over all elements of @[a]@.
--
-- Works by rejection sampling: an index is assembled from the smallest number
-- of coin tosses that can represent every valid index, and the draw is
-- repeated whenever the index falls outside the list. Every element is
-- therefore selected with equal probability, whatever the list's length.
--
-- Calls 'error' when given an empty list, as there is nothing to sample.
uniform :: [a] -> Distribution a
uniform [] = error "attempted to sample uniformly from an empty list"
uniform xs = draw
  where
    l = length xs
    -- Number of bits needed to write the largest valid index, @l - 1@.
    bits = bitLength (l - 1)
    -- Since @bits@ is minimal, @2 ^ bits < 2 * l@, so each attempt is
    -- accepted with probability greater than one half.
    draw = do
      n <- samplen bits
      -- Indices in @[l, 2 ^ bits)@ do not name an element: reject and retry.
      -- The guard also ensures that '!!' is only used with an in-range index.
      if n >= l
        then draw
        else return (xs !! n)
    -- Number of binary digits of a non-negative integer (0 for 0).
    bitLength :: Int -> Int
    bitLength m
      | m <= 0 = 0
      | otherwise = bitLength (m `div` 2) + 1
    -- Builds an integer out of @lg@ coin tosses, i.e. a uniformly distributed
    -- value in @[0, 2 ^ lg)@.
    samplen :: Int -> Distribution Int
    samplen 0 = return 0
    samplen lg
      | lg > 0 = do
        n' <- samplen (lg - 1)
        c <- coin
        return $ (n' * 2) + if c then 1 else 0
    samplen _ = error "attempted to sample negative logarithm"