{-# language GeneralizedNewtypeDeriving #-}
{-# options_ghc -Wno-unused-imports #-}
{-|
Random samplers for a few common distributions, with an interface similar to that of @mwc-probability@.

= Usage

Compose your random sampler out of simpler ones thanks to the Applicative and Monad interface, e.g. this is how you would declare and sample a binary mixture of Gaussian random variables:

@
import Control.Monad (replicateM)
import System.Random.SplitMix.Distributions (Gen, sample, bernoulli, normal)

process :: `Gen` Double
process = do
  coin <- `bernoulli` 0.7
  if coin
    then
      `normal` 0 2
    else
      normal 3 1

dataset :: [Double]
dataset = `sample` 1234 $ replicateM 20 process
@

and sample your data in a pure (`sample`) or monadic (`sampleT`) setting.

== Implementation details

The library is built on top of @splitmix@, so the caveats on safety and performance that apply there are relevant here as well.


-}
module System.Random.SplitMix.Distributions (
  -- * Distributions
  -- ** Continuous
  stdUniform, uniformR,
  exponential,
  stdNormal, normal,
  beta,
  gamma,
  pareto,
  dirichlet,
  -- ** Discrete
  bernoulli, fairCoin,
  multinomial,
  categorical,
  discrete,
  zipf,
  -- * PRNG
  -- ** Pure
  Gen, sample, samples,
  -- ** Monadic
  GenT, sampleT, samplesT,
  withGen
                                            ) where

import Control.Monad (replicateM)
import Control.Monad.IO.Class (MonadIO(..))
import Data.Foldable (toList)
import Data.Functor.Identity (Identity(..))
import Data.List (findIndex)
import GHC.Word (Word64)

-- erf
import Data.Number.Erf (InvErf(..))
-- mtl
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.State (MonadState(..), modify)
-- splitmix
import System.Random.SplitMix (SMGen, mkSMGen, splitSMGen, nextInt, nextInteger, nextDouble)
-- transformers
import Control.Monad.Trans.State (StateT(..), runStateT, evalStateT, State, runState, evalState)

-- | Random generator
--
-- wraps @splitmix@ state-passing inside a 'StateT' monad: the state is the
-- current 'SMGen', and every sampler reads it and writes back the advanced
-- generator.
--
-- useful for embedding random generation inside a larger effect stack
--
-- All instances are obtained from the underlying 'StateT' via
-- @GeneralizedNewtypeDeriving@.
newtype GenT m a = GenT { unGen :: StateT SMGen m a }
  deriving (Functor, Applicative, Monad, MonadState SMGen, MonadTrans, MonadIO)

-- | Pure random generation: 'GenT' specialised to the 'Identity' base monad
type Gen = GenT Identity

-- | Sample in a monadic context
--
-- Builds a fresh 'SMGen' from the seed with 'mkSMGen', runs the generator
-- against it and discards the final generator state.
sampleT :: Monad m =>
           Word64 -- ^ random seed
        -> GenT m a
        -> m a
sampleT seed gg = evalStateT (unGen gg) (mkSMGen seed)

-- | Sample a batch
--
-- All @n@ draws are threaded through a single generator initialised from the
-- seed, so successive elements come from successive generator states.
samplesT :: Monad m =>
            Int -- ^ size of sample
         -> Word64 -- ^ random seed
         -> GenT m a
         -> m [a]
samplesT n seed gg = sampleT seed (replicateM n gg)

-- | Pure sampling
--
-- Builds a fresh 'SMGen' from the seed with 'mkSMGen', runs the generator
-- against it and discards the final generator state.
sample :: Word64 -- ^ random seed
        -> Gen a
        -> a
sample seed gg = evalState (unGen gg) (mkSMGen seed)

-- | Sample a batch
--
-- All @n@ draws are threaded through a single generator initialised from the
-- seed, so successive elements come from successive generator states.
samples :: Int -- ^ sample size
        -> Word64 -- ^ random seed
        -> Gen a
        -> [a]
samples n seed gg = sample seed (replicateM n gg)

-- | Bernoulli trial
--
-- Returns 'True' when a freshly drawn standard uniform value is strictly
-- below the bias parameter.
bernoulli :: Monad m =>
             Double -- ^ bias parameter \( 0 \lt p \lt 1 \)
          -> GenT m Bool
bernoulli p = withGen (bernoulliF p)

-- | A fair coin toss returns either value with probability 0.5
fairCoin :: Monad m => GenT m Bool
fairCoin = bernoulli 0.5

-- | Multinomial distribution
--
-- Performs @n@ independent trials. Each trial draws a uniform value in
-- @[0, total)@ and picks the first position whose cumulative weight exceeds
-- it. The result lists the (0-based) index drawn at each trial, in order; it
-- is /not/ a vector of per-category counts.
--
-- NB : returns @Nothing@ if any of the input probabilities is negative, or if
-- some trial cannot be assigned an index (e.g. the weights are empty or all
-- zero while @n@ is positive)
multinomial :: (Monad m, Foldable t) =>
               Int -- ^ number of Bernoulli trials \( n \gt 0 \)
            -> t Double -- ^ probability vector \( p_i \gt 0 , \forall i \) (does not need to be normalized)
            -> GenT m (Maybe [Int])
multinomial n ps
  -- Negative weights make the cumulative sums non-monotonic, so the index
  -- lookup below would silently return a meaningless answer: reject up front.
  | any (< 0) weights = pure Nothing
  | otherwise = do
      ms <- replicateM n $ do
        z <- uniformR 0 total
        pure $ findIndex (> z) cumulative
      -- a single failed trial invalidates the whole batch
      pure $ sequence ms
  where
    weights = toList ps
    (cumulative, total) = runningTotals weights
    -- running (cumulative) sums of the weights, paired with their grand total
    runningTotals :: Num a => [a] -> ([a], a)
    runningTotals xs = let adds = scanl1 (+) xs in (adds, sum xs)
{-# INLINABLE multinomial #-}


-- | Categorical distribution
--
-- Picks one index out of a discrete set with probability proportional to those supplied as input parameter vector
--
-- Implemented as a single-trial 'multinomial'; yields @Nothing@ whenever that
-- call does not produce exactly one index.
categorical :: (Monad m, Foldable t) =>
               t Double -- ^ probability vector \( p_i \gt 0 , \forall i \) (does not need to be normalized)
            -> GenT m (Maybe Int)
categorical ps = do
  xs <- multinomial 1 ps
  pure $ case xs of
    Just [x] -> Just x
    _ -> Nothing


-- | The Zipf-Mandelbrot distribution.
--
--  Note that values of the parameter close to 1 are very computationally intensive.
--
--  Implemented as a rejection sampler: each round draws two standard uniform
--  values, proposes a candidate from the first and accepts or rejects it using
--  the second; rejected rounds are retried with fresh draws.
--
--  >>> samples 10 1234 (zipf 1.1)
--  [3170051793,2,668775891,146169301649651,23,36,5,6586194257347,21,37911]
--
--  >>> samples 10 1234 (zipf 1.5)
--  [79,1,58,680,3,1,2,1,366,1]
zipf :: (Monad m, Integral i) =>
        Double -- ^ \( \alpha \gt 1 \)
     -> GenT m i
zipf a = go
  where
    b = 2 ** (a - 1)
    go = do
      -- NB: the order of the two draws is part of the observable sample stream
      u <- stdUniform
      v <- stdUniform
      let xInt = floor (u ** (- 1 / (a - 1))) -- integer candidate
          x = fromIntegral xInt
          t = (1 + 1 / x) ** (a - 1)
      if v * x * (t - 1) / (b - 1) <= t / b
        then return xInt
        else go
{-# INLINABLE zipf #-}

-- | Discrete distribution
--
-- Pick one item with probability proportional to those supplied as input parameter vector
--
-- The weights are handed to 'categorical'; the index it returns (if any) is
-- used to look up the corresponding item.
discrete :: (Monad m, Foldable t) =>
            t (Double, b) -- ^ (probability, item) vector \( p_i \gt 0 , \forall i \) (does not need to be normalized)
         -> GenT m (Maybe b)
discrete d = do
  let (ps, xs) = unzip (toList d)
  midx <- categorical ps
  -- the index is a position within @ps@, which has the same length as @xs@,
  -- so the lookup stays in range
  pure $ (xs !!) <$> midx

-- | Uniform between two values
--
-- Obtained by affinely rescaling a standard uniform draw @x@ to
-- @x * (hi - lo) + lo@.
uniformR :: Monad m =>
            Double -- ^ low
         -> Double -- ^ high
         -> GenT m Double
uniformR lo hi = scale <$> stdUniform
  where
    scale x = x * (hi - lo) + lo

-- | Standard normal distribution (zero mean, unit standard deviation)
stdNormal :: Monad m => GenT m Double
stdNormal = normal 0 1

-- | Uniform in [0, 1)
--
-- Thin wrapper around @splitmix@'s 'nextDouble'.
stdUniform :: Monad m => GenT m Double
stdUniform = withGen nextDouble

-- | Beta distribution, from two standard uniform samples
--
-- Rejection sampler: two uniform draws are raised to the powers @1/a@ and
-- @1/b@; the pair is accepted when its sum does not exceed 1, in which case
-- the first component divided by the sum is returned. Otherwise a fresh pair
-- is drawn.
beta :: Monad m =>
        Double -- ^ shape parameter \( \alpha \gt 0 \) 
     -> Double -- ^ shape parameter \( \beta \gt 0 \)
     -> GenT m Double
beta a b = go
  where
    go = do
      (y1, y2) <- powered
      let s = y1 + y2
      if s <= 1
        then pure (y1 / s)
        else go
    -- one candidate pair: independent uniforms raised to the inverse shapes
    powered = raise <$> stdUniform <*> stdUniform
    raise u1 u2 = (u1 ** (1 / a), u2 ** (1 / b))

-- | Gamma distribution, using Ahrens-Dieter accept-reject (algorithm GD):
--
-- Ahrens, J. H.; Dieter, U (January 1982). "Generating gamma variates by a modified rejection technique". Communications of the ACM. 25 (1): 47–54
--
-- The shape @k@ is split into its integer part @n@ and fractional part
-- @delta@. The fractional part is handled by an accept-reject loop
-- ('sampleXi'), the integer part by subtracting the logarithms of @n@ standard
-- uniform draws. The scale @th@ multiplies the /whole/ unit-scale variate.
gamma :: Monad m =>
         Double -- ^ shape parameter \( k \gt 0 \)
      -> Double -- ^ scale parameter \( \theta \gt 0 \)
      -> GenT m Double
gamma k th = do
  xi <- sampleXi
  us <- replicateM n (log <$> stdUniform)
  -- the scale must apply to both contributions, not just to @xi@
  pure $ th * (xi - sum us)
  where
    n :: Int
    n = floor k
    -- fractional part of the shape
    delta = k - fromIntegral n
    ee :: Double
    ee = exp 1
    -- accept-reject loop: redraw until the auxiliary value falls under the
    -- acceptance bound
    sampleXi = do
      (xi, eta) <- sample2
      if eta > xi ** (delta - 1) * exp (- xi)
        then sampleXi
        else pure xi
    -- one proposal: a candidate together with its auxiliary comparison value,
    -- built from three independent uniform draws
    sample2 = f <$> stdUniform <*> stdUniform <*> stdUniform
      where
        f u v w
          | u <= ee / (ee + delta) =
            let xi = v ** (1 / delta)
            in (xi, w * xi ** (delta - 1))
          | otherwise =
            let xi = 1 - log v
            in (xi, w * exp (- xi))

-- | Pareto distribution
--
-- Obtained by exponentiating an 'exponential' draw with rate @a@ and
-- multiplying by the scale.
pareto :: Monad m =>
          Double -- ^ shape parameter \( \alpha \gt 0 \)
       -> Double -- ^ scale parameter \( x_{min} \gt 0 \)
       -> GenT m Double
pareto a xmin = rescale <$> exponential a
  where
    rescale y = xmin * exp y
{-# INLINABLE pareto #-}

-- | The Dirichlet distribution with the provided concentration parameters.
--   The dimension of the distribution is determined by the number of
--   concentration parameters supplied.
--
--   Each component is drawn from a unit-scale 'gamma' with the corresponding
--   concentration as shape, and the draws are then normalized by their sum.
--
--   >>> sample 1234 (dirichlet [0.1, 1, 10])
--   [2.3781130220132788e-11,6.646079701567026e-2,0.9335392029605486]
dirichlet :: (Monad m, Traversable f) =>
             f Double -- ^ concentration parameters \( \gamma_i \gt 0 , \forall i \)
          -> GenT m (f Double)
dirichlet as = do
  zs <- traverse (`gamma` 1) as
  let total = sum zs
  return $ fmap (/ total) zs
{-# INLINABLE dirichlet #-}


-- | Normal distribution
--
-- Sampled by pushing one standard uniform draw through the inverse CDF.
normal :: Monad m =>
          Double -- ^ mean
       -> Double -- ^ standard deviation \( \sigma \gt 0 \)
       -> GenT m Double
normal mu sig = withGen (normalF mu sig)

-- | Exponential distribution
--
-- Sampled by pushing one standard uniform draw through the inverse CDF.
exponential :: Monad m =>
               Double -- ^ rate parameter \( \lambda > 0 \)
            -> GenT m Double
exponential l = withGen (exponentialF l)

-- | Wrap a @splitmix@ PRNG function
--
-- Reads the current generator, applies the function to it, stores the
-- generator it returns as the new state and yields the produced value.
withGen :: Monad m =>
           (SMGen -> (a, SMGen)) -- ^ explicit generator passing (e.g. 'nextDouble')
        -> GenT m a
withGen f = GenT $ do
  gen <- get
  let (b, gen') = f gen
  put gen'
  pure b

-- | Explicit generator-passing exponential sampler: draws one uniform value
-- with 'nextDouble', maps it through 'exponentialICDF' and returns the
-- advanced generator alongside the sample.
exponentialF :: Double -> SMGen -> (Double, SMGen)
exponentialF l g = (exponentialICDF l x, g')
  where
    (x, g') = nextDouble g

-- | Explicit generator-passing normal sampler: draws one uniform value with
-- 'nextDouble', maps it through 'normalICDF' and returns the advanced
-- generator alongside the sample.
normalF :: Double -> Double -> SMGen -> (Double, SMGen)
normalF mu sig g = (normalICDF mu sig x, g')
  where
    (x, g') = nextDouble g

-- | Explicit generator-passing Bernoulli sampler: draws one uniform value
-- with 'nextDouble' and reports whether it is strictly below @p@, returning
-- the advanced generator alongside the outcome.
bernoulliF :: Double -> SMGen -> (Bool, SMGen)
bernoulliF p g = (x < p, g')
  where
    (x, g') = nextDouble g


-- | inverse CDF of normal rv
--
-- Computes @mu + sig * sqrt 2 * inverf (2 * p - 1)@.
normalICDF :: InvErf a =>
              a -- ^ mean
           -> a -- ^ std dev
           -> a -- ^ probability at which to evaluate the inverse CDF
           -> a
normalICDF mu sig p = mu + sig * sqrt 2 * inverf (2 * p - 1)

-- | inverse CDF of exponential rv
--
-- Computes @(- 1 / l) * log (1 - p)@.
exponentialICDF :: Floating a =>
                   a -- ^ rate
                -> a -- ^ probability at which to evaluate the inverse CDF
                -> a
exponentialICDF l p = (- 1 / l) * log (1 - p)