{-# LANGUAGE FlexibleInstances #-}
{-# language GeneralizedNewtypeDeriving #-}
{-# language MultiParamTypeClasses #-}
{-# language UndecidableInstances #-}
{-# options_ghc -Wno-unused-imports #-}
{-|
Random samplers for a few common distributions, with an interface similar to that of @mwc-probability@.

= Usage

Compose your random sampler out of simpler ones thanks to the Applicative and Monad interface, e.g. this is how you would declare and sample a binary mixture of Gaussian random variables:

@
import Control.Monad (replicateM)
import System.Random.SplitMix.Distributions (Gen, sample, bernoulli, normal)

process :: `Gen` Double
process = do
  coin <- `bernoulli` 0.7
  if coin
    then
      `normal` 0 2
    else
      normal 3 1

dataset :: [Double]
dataset = `sample` 1234 $ replicateM 20 process
@

and sample your data in a pure (`sample`) or monadic (`sampleT`) setting.

Initializing the PRNG with a fixed seed makes all results fully reproducible across runs. If this behavior is not desired, one can sample the random seed itself from an IO-based entropy pool, and run the samplers with `sampleIO` and `samplesIO`.

== Implementation details

The library is built on top of @splitmix@ ( https://hackage.haskell.org/package/splitmix ), which provides fast pseudorandom number generation utilities.


-}
module System.Random.SplitMix.Distributions (
  -- * Distributions
  -- ** Continuous
  stdUniform, uniformR,
  exponential,
  stdNormal, normal,
  beta,
  gamma,
  pareto,
  dirichlet,
  logNormal,
  laplace,
  weibull,
  -- ** Discrete
  bernoulli, fairCoin,
  multinomial,
  categorical,
  discrete,
  zipf,
  crp,
  -- * PRNG
  -- ** Pure
  Gen, sample, samples,
  -- ** Monadic
  GenT, sampleT, samplesT,
  -- ** IO-based
  sampleIO, samplesIO,
  -- ** splitmix utilities
  withGen
                                            ) where

import Control.Monad (replicateM)
import Control.Monad.IO.Class (MonadIO(..))
import Data.Foldable (toList)
import Data.Functor.Identity (Identity(..))
import Data.List (findIndex)
import Data.Monoid (Sum(..))
import GHC.Word (Word64)

-- containers
import qualified Data.IntMap as IM
-- erf
import Data.Number.Erf (InvErf(..))
-- exceptions
import Control.Monad.Catch (MonadThrow(..))
-- mtl
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Reader (MonadReader(..))
import Control.Monad.State (MonadState(..), modify)
-- splitmix
import System.Random.SplitMix (SMGen, mkSMGen, initSMGen, unseedSMGen, splitSMGen, nextDouble)
-- transformers
import Control.Monad.Trans.Reader (ReaderT(..))
import Control.Monad.Trans.State (StateT(..), runStateT, evalStateT, State, runState, evalState)

-- | Random generator
--
-- Wraps the @splitmix@ generator state ('SMGen') in a 'StateT' layer over an
-- arbitrary base monad @m@.
--
-- Useful for embedding random generation inside a larger effect stack.
--
-- All instances listed in the @deriving@ clause are obtained from the
-- underlying 'StateT' via @GeneralizedNewtypeDeriving@. 'MonadState' is
-- deliberately /not/ derived: it is defined by hand so that it refers to the
-- state of the base monad rather than to the PRNG state.
newtype GenT m a = GenT { unGen :: StateT SMGen m a }
  deriving (Functor, Applicative, Monad, MonadTrans, MonadIO, MonadThrow, MonadReader r)

-- | State operations are delegated to the base monad @m@ through 'lift',
-- so user code sees the state @s@ of the surrounding stack and never the
-- 'SMGen' carried internally by 'GenT'.
instance MonadState s m => MonadState s (GenT m) where
  get = lift get
  put = lift . put
  state = lift . state


-- | Pure random generation
--
-- 'GenT' specialised to the 'Identity' base monad.
type Gen = GenT Identity

-- | Sample in a monadic context
--
-- Builds the initial generator from the seed with 'mkSMGen', runs the
-- sampler and discards the final generator state.
sampleT :: Monad m =>
           Word64 -- ^ random seed
        -> GenT m a
        -> m a
sampleT seed gg = evalStateT (unGen gg) (mkSMGen seed)

-- | Initialize a splitmix random generator from system entropy (current time etc.) and sample from the PRNG.
--
-- Only the first component returned by 'unseedSMGen' is kept and used as the
-- seed for 'sampleT'; the second component is discarded.
sampleIO :: MonadIO m => GenT m b -> m b
sampleIO gg = do
  (s, _) <- unseedSMGen <$> liftIO initSMGen
  sampleT s gg

-- | Sample a batch
--
-- All @n@ draws are made from a single generator seeded once, so the
-- generator state is threaded from one draw to the next.
samplesT :: Monad m =>
            Int -- ^ size of sample
         -> Word64 -- ^ random seed
         -> GenT m a
         -> m [a]
samplesT n seed gg = sampleT seed (replicateM n gg)

-- | Initialize a splitmix random generator from system entropy (current time etc.) and sample n times from the PRNG.
--
-- Only the first component returned by 'unseedSMGen' is kept and used as the
-- seed for 'samplesT'; the second component is discarded.
samplesIO :: MonadIO m => Int -> GenT m a -> m [a]
samplesIO n gg = do
  (s, _) <- unseedSMGen <$> liftIO initSMGen
  samplesT n s gg

-- | Pure sampling
--
-- Builds the initial generator from the seed with 'mkSMGen', runs the
-- sampler and discards the final generator state. The same seed always
-- produces the same result.
sample :: Word64 -- ^ random seed
        -> Gen a
        -> a
sample seed gg = evalState (unGen gg) (mkSMGen seed)

-- | Sample a batch
--
-- All @n@ draws are made from a single generator seeded once, so the
-- generator state is threaded from one draw to the next.
samples :: Int -- ^ sample size
        -> Word64 -- ^ random seed
        -> Gen a
        -> [a]
samples n seed gg = sample seed (replicateM n gg)

-- | Bernoulli trial
--
-- Thin wrapper that lifts the pure, generator-passing @bernoulliF@ into
-- 'GenT' by means of 'withGen'.
bernoulli :: Monad m =>
             Double -- ^ bias parameter \( 0 \lt p \lt 1 \)
          -> GenT m Bool
bernoulli p = withGen (bernoulliF p)

-- | A fair coin toss returns either value with probability 0.5
--
-- Defined as @'bernoulli' 0.5@.
fairCoin :: Monad m => GenT m Bool
fairCoin = bernoulli 0.5

-- | Multinomial distribution
--
-- Performs @n@ independent draws and returns the list of the @n@ selected
-- category /indices/ (0-based positions in the weight vector), in draw order.
-- Note that the result is not a vector of per-category counts.
--
-- Each draw picks a uniform value @z@ between 0 and the sum of the weights and
-- selects the first category whose running total is strictly greater than @z@.
--
-- NB : returns @Nothing@ as soon as one draw cannot be matched to a category,
-- i.e. when no running total exceeds @z@. This happens, for instance, with an
-- empty or all-zero weight vector (and @n > 0@). Negative weights are not
-- validated and give unspecified results.
multinomial :: (Monad m, Foldable t) =>
               Int -- ^ number of Bernoulli trials \( n \gt 0 \)
            -> t Double -- ^ probability vector \( p_i \gt 0 , \forall i \) (does not need to be normalized)
            -> GenT m (Maybe [Int])
multinomial n ps = do
    let (cumulative, total) = runningTotals (toList ps)
    ms <- replicateM n $ do
      z <- uniformR 0 total
      -- first category whose cumulative weight lies strictly above z
      pure $ findIndex (> z) cumulative
    -- a single failed draw turns the whole result into Nothing
    pure $ sequence ms
  where
    -- Cumulative sums of the weights, paired with their grand total.
    runningTotals :: Num a => [a] -> ([a], a)
    runningTotals xs = let adds = scanl1 (+) xs in (adds, sum xs)
{-# INLINABLE multinomial #-}


-- | Categorical distribution
--
-- Picks one index out of a discrete set with probability proportional to those supplied as input parameter vector
--
-- Implemented as a single-trial 'multinomial'; yields @Nothing@ whenever that
-- call does not return exactly one index.
categorical :: (Monad m, Foldable t) =>
               t Double -- ^ probability vector \( p_i \gt 0 , \forall i \) (does not need to be normalized)
            -> GenT m (Maybe Int)
categorical ps = do
  xs <- multinomial 1 ps
  case xs of
    Just [x] -> pure $ Just x
    _ -> pure Nothing


-- | The Zipf-Mandelbrot distribution.
--
--  Note that values of the parameter close to 1 are very computationally intensive.
--
--  Implemented as a rejection sampler: each round draws two standard uniform
--  values, builds an integer candidate from the first one and accepts or
--  rejects it using the second one; rejected candidates trigger a new round.
--
--  >>> samples 10 1234 (zipf 1.1)
--  [3170051793,2,668775891,146169301649651,23,36,5,6586194257347,21,37911]
--
--  >>> samples 10 1234 (zipf 1.5)
--  [79,1,58,680,3,1,2,1,366,1]
zipf :: (Monad m, Integral i) =>
        Double -- ^ \( \alpha \gt 1 \)
     -> GenT m i
zipf a = do
  let
    b = 2 ** (a - 1)
    go = do
        u <- stdUniform
        v <- stdUniform
        let -- integer candidate and its floating-point counterpart
            xInt = floor (u ** (- 1 / (a - 1)))
            x = fromIntegral xInt
            t = (1 + 1 / x) ** (a - 1)
        -- acceptance test; retry with fresh uniforms on rejection
        if v * x * (t - 1) / (b - 1) <= t / b
          then return xInt
          else go
  go
{-# INLINABLE zipf #-}

-- | Discrete distribution
--
-- Pick one item with probability proportional to those supplied as input parameter vector
--
-- The weights are fed to 'categorical' and the resulting index is used to
-- look up the corresponding item. Yields @Nothing@ when 'categorical' does,
-- or when the index falls outside the item list.
discrete :: (Monad m, Foldable t) =>
            t (Double, b) -- ^ (probability, item) vector \( p_i \gt 0 , \forall i \) (does not need to be normalized)
         -> GenT m (Maybe b)
discrete d = do
  let (ps, xs) = unzip (toList d)
  midx <- categorical ps
  pure $ midx >>= \i -> safeIndex i xs
  where
    -- Total replacement for '!!': out-of-range positions give Nothing
    -- instead of a runtime error.
    safeIndex i ys = case drop i ys of
      (y : _) -> Just y
      [] -> Nothing


-- | Chinese restaurant process
--
-- Starts from one customer seated at the first table and seats the remaining
-- customers one at a time with 'crpSingle'. Returns the number of customers
-- at each table, in table order.
--
-- For \( n \le 1 \) the initial configuration (a single table with one
-- customer) is returned.
--
-- >>> sample 1234 $ crp 1.02 50
-- [24,18,7,1]
--
-- >>> sample 1234 $ crp 2 50
-- [17,8,13,3,3,3,2,1]
--
-- >>> sample 1234 $ crp 10 50
-- [5,7,1,6,1,3,5,1,1,3,1,1,1,4,3,1,3,1,1,1]
crp :: Monad m =>
       Double -- ^ concentration parameter \( \alpha \gt 1 \)
    -> Int -- ^ number of customers \( n > 0 \)
    -> GenT m [Integer]
crp a n = do
    ts <- go crpInitial 1
    pure $ toList (fmap getSum ts)
  where
    -- @i@ counts the customers seated so far. The stop test is @>=@ rather
    -- than @==@ so that a non-positive @n@ terminates immediately instead of
    -- counting upwards without ever hitting the bound.
    go acc i
      | i >= n = pure acc
      | otherwise = do
          acc' <- crpSingle i acc a
          go acc' (i + 1)
{-# INLINABLE crp #-}

-- | Update step of the CRP
--
-- Given the number @i@ of customers already seated, the current tables @zs@
-- and the concentration parameter @a@, draws a table for the next customer:
-- an existing table holding @m@ customers has weight @m / (i - 1 + a)@ and a
-- new table has weight @a / (i - 1 + a)@. The drawn position is then passed
-- to 'crpInsert'.
crpSingle :: (Monad m, Integral a) =>
             Int -> CRPTables (Sum a) -> Double -> GenT m (CRPTables (Sum a))
crpSingle i zs a = do
    znm1 <- categorical probs
    case znm1 of
      Just zn1 -> pure $ crpInsert zn1 zs
      -- NOTE(review): a failed draw discards every table seated so far;
      -- confirm this is intended rather than returning @zs@ unchanged.
      _ -> pure mempty
  where
    -- weights of the existing tables followed by the weight of a new table
    probs = pms <> [pm1]
    acc m = fromIntegral m / (fromIntegral i - 1 + a)
    pms = toList $ fmap (acc . getSum) zs
    pm1 = a / (fromIntegral i - 1 + a)

-- Tables at the Chinese Restaurant
--
-- A map from table index to a per-table payload (the customer count in this
-- module). The instances are derived from the underlying 'IM.IntMap'.
newtype CRPTables c = CRP {
    getCRPTables :: IM.IntMap c
  } deriving (Eq, Show, Functor, Foldable, Semigroup, Monoid)

-- Initial state of the CRP : one customer sitting at table #0
crpInitial :: CRPTables (Sum Integer)
crpInitial = crpInsert 0 mempty

-- Seat one customer at table 'k'
--
-- Adds 1 to the count stored at key 'k', creating the entry with count 1
-- when the table does not exist yet.
crpInsert :: Num a => IM.Key -> CRPTables (Sum a) -> CRPTables (Sum a)
crpInsert k (CRP ts) = CRP $ IM.insertWith (<>) k (Sum 1) ts




-- | Uniform between two values.
--
-- A 'stdUniform' draw @u@ is mapped affinely to @u * (hi - lo) + lo@.
uniformR :: Monad m =>
            Double -- ^ low
         -> Double -- ^ high
         -> GenT m Double
uniformR lo hi = fmap (\u -> u * (hi - lo) + lo) stdUniform

-- | Standard normal distribution, i.e. mean 0 and standard deviation 1.
stdNormal :: Monad m => GenT m Double
stdNormal = withGen (normalF 0 1)

-- | Uniform in [0, 1).
--
-- Each draw threads the generator state through splitmix's 'nextDouble'.
stdUniform :: Monad m => GenT m Double
stdUniform = withGen unitDouble
  where
    unitDouble = nextDouble

-- | Beta distribution, from two standard uniform samples.
--
-- Two 'stdUniform' draws are raised to the powers @1/a@ and @1/b@; the pair is
-- accepted when the two powers sum to at most 1, otherwise a fresh pair is
-- drawn. On acceptance the first power divided by the sum is returned.
beta :: Monad m =>
        Double -- ^ shape parameter \( \alpha \gt 0 \)
     -> Double -- ^ shape parameter \( \beta \gt 0 \)
     -> GenT m Double
beta a b = loop
  where
    loop = do
      u1 <- stdUniform
      u2 <- stdUniform
      let y1 = u1 ** (1 / a)
          y2 = u2 ** (1 / b)
      -- keep the '<=' form of the test: a NaN sum must trigger a retry
      if y1 + y2 <= 1
        then pure (y1 / (y1 + y2))
        else loop

-- | Gamma distribution, using Ahrens-Dieter accept-reject (algorithm GD):
--
-- Ahrens, J. H.; Dieter, U (January 1982). "Generating gamma variates by a modified rejection technique". Communications of the ACM. 25 (1): 47–54
--
-- The shape @k@ is split into its integer part @n = floor k@ and the
-- remainder @delta = k - n@. A variate @xi@ for the fractional shape is
-- obtained by accept-reject, @n@ logarithms of standard uniforms are drawn for
-- the integer part, and the result is @th * (xi - sum of the logs)@.
gamma :: Monad m =>
         Double -- ^ shape parameter \( k \gt 0 \)
      -> Double -- ^ scale parameter \( \theta \gt 0 \)
      -> GenT m Double
gamma k th = do
  xi <- sampleXi
  us <- replicateM n (log <$> stdUniform)
  -- The scale multiplies the whole unit-scale variate. Without the
  -- parentheses, '*' binds tighter than '-' and only 'xi' would be scaled.
  pure $ th * (xi - sum us)
  where
    -- integer and fractional parts of the shape parameter
    n :: Int
    n = floor k
    delta = k - fromIntegral n
    ee = exp 1
    -- accept-reject loop: redraw until the candidate passes the test
    sampleXi = do
      (xi, eta) <- sample2
      if eta > xi ** (delta - 1) * exp (negate xi)
        then sampleXi
        else pure xi
    -- one candidate @xi@ and its comparison value @eta@, from three uniforms
    sample2 = f <$> stdUniform <*> stdUniform <*> stdUniform
      where
        f u v w
          | u <= ee / (ee + delta) =
            let xi = v ** (1 / delta)
            in (xi, w * xi ** (delta - 1))
          | otherwise =
            let xi = 1 - log v
            in (xi, w * exp (negate xi))

-- | Pareto distribution.
--
-- Computed as @xmin * exp y@, where @y@ is an 'exponential' draw with rate
-- equal to the shape parameter.
pareto :: Monad m =>
          Double -- ^ shape parameter \( \alpha \gt 0 \)
       -> Double -- ^ scale parameter \( x_{min} \gt 0 \)
       -> GenT m Double
pareto a xmin = (\y -> xmin * exp y) <$> exponential a
{-# INLINABLE pareto #-}

-- | The Dirichlet distribution with the provided concentration parameters.
--   The dimension of the distribution is determined by the number of
--   concentration parameters supplied.
--
--   One 'gamma' variate with scale 1 is drawn per concentration parameter,
--   and every variate is then divided by the sum of all of them.
--
--   >>> sample 1234 (dirichlet [0.1, 1, 10])
--   [2.3781130220132788e-11,6.646079701567026e-2,0.9335392029605486]
dirichlet :: (Monad m, Traversable f) =>
             f Double -- ^ concentration parameters \( \gamma_i \gt 0 , \forall i \)
          -> GenT m (f Double)
dirichlet as = normalize <$> traverse (\a -> gamma a 1) as
  where
    normalize zs = fmap (/ sum zs) zs
{-# INLINABLE dirichlet #-}


-- | Normal distribution.
--
-- Implemented by lifting the explicit-generator sampler 'normalF' with
-- 'withGen'.
normal :: Monad m =>
          Double -- ^ mean
       -> Double -- ^ standard deviation \( \sigma \gt 0 \)
       -> GenT m Double
normal mu sig = withGen draw
  where
    draw = normalF mu sig

-- | Exponential distribution.
--
-- Implemented by lifting the explicit-generator sampler 'exponentialF' with
-- 'withGen'.
exponential :: Monad m =>
               Double -- ^ rate parameter \( \lambda > 0 \)
            -> GenT m Double
exponential = withGen . exponentialF




-- | Log-normal distribution: the exponential of a 'normal' variate.
--
-- Both parameters are those of the underlying normal distribution (i.e. of the
-- logarithm of the result), not of the log-normal variate itself.
logNormal :: Monad m =>
             Double -- ^ mean of the underlying normal
          -> Double -- ^ standard deviation \( \sigma \gt 0 \) of the underlying normal
          -> GenT m Double
logNormal m sd = fmap exp (normal m sd)
{-# INLINABLE logNormal #-}


-- | Laplace or double-exponential distribution with provided location and
--   scale parameters.
--
--   A uniform draw @u@ in [-0.5, 0.5) is transformed into
--   @mu - b * signum u * log (1 - 2 * abs u)@, where @b@ is the second
--   argument divided by @sqrt 2@.
laplace :: Monad m =>
           Double -- ^ location parameter
        -> Double  -- ^ scale parameter \( s \gt 0 \); divided by \( \sqrt{2} \) before use
        -> GenT m Double
laplace mu sigma = transform <$> uniformR (-0.5) 0.5
  where
    b = sigma / sqrt 2
    transform u = mu - b * signum u * log (1 - 2 * abs u)
{-# INLINABLE laplace #-}

-- | Weibull distribution with provided shape and scale parameters.
--
-- Sampled by inverting the CDF \( F(x) = 1 - \exp(-(x/b)^a) \): for a standard
-- uniform draw @x@ the result is @b * (- log (1 - x)) ** (1 / a)@.
weibull :: Monad m =>
           Double -- ^ shape \( a \gt 0 \)
        -> Double -- ^ scale \( b \gt 0 \)
        -> GenT m Double
weibull a b = do
  x <- stdUniform
  -- The exponent must be parenthesised: '**' binds tighter than '/', so
  -- @y ** 1/a@ would mean @(y ** 1) / a@ and yield a plain exponential variate.
  return $ b * (negate (log (1 - x)) ** (1 / a))
{-# INLINABLE weibull #-}







-- | Wrap a 'splitmix' PRNG function.
--
-- The current generator is read from the state, the function is applied to
-- it, the returned generator is stored back, and the sampled value is
-- returned.
withGen :: Monad m =>
           (SMGen -> (a, SMGen)) -- ^ explicit generator passing (e.g. 'nextDouble')
        -> GenT m a
withGen f = GenT $
  get >>= \g0 ->
    let (x, g1) = f g0
    in put g1 >> pure x

-- | Explicit-generator exponential sampler: one 'nextDouble' draw is pushed
-- through 'exponentialICDF' with rate @l@; the advanced generator is returned
-- alongside the sample.
exponentialF :: Double -> SMGen -> (Double, SMGen)
exponentialF l g =
  let (u, g') = nextDouble g
  in (exponentialICDF l u, g')

-- | Explicit-generator normal sampler: one 'nextDouble' draw is pushed through
-- 'normalICDF' with mean @mu@ and standard deviation @sig@; the advanced
-- generator is returned alongside the sample.
normalF :: Double -> Double -> SMGen -> (Double, SMGen)
normalF mu sig g =
  let (u, g') = nextDouble g
  in (normalICDF mu sig u, g')

-- | Explicit-generator Bernoulli sampler: the result is 'True' exactly when a
-- 'nextDouble' draw is strictly below @p@; the advanced generator is returned
-- alongside the outcome.
bernoulliF :: Double -> SMGen -> (Bool, SMGen)
bernoulliF p g =
  let (u, g') = nextDouble g
  in (u < p, g')


-- | Inverse CDF of a normal random variable, expressed through the inverse
-- error function: @mu + sig * sqrt 2 * inverf (2 * p - 1)@.
normalICDF :: InvErf a =>
              a -- ^ mean
           -> a -- ^ std dev
           -> a -- ^ probability at which the inverse CDF is evaluated
           -> a
normalICDF mu sig p = mu + sig * sqrt 2 * z
  where
    z = inverf (2 * p - 1)

-- | Inverse CDF of an exponential random variable:
-- @-(1 / l) * log (1 - p)@.
exponentialICDF :: Floating a =>
                   a -- ^ rate
                -> a -- ^ probability at which the inverse CDF is evaluated
                -> a
exponentialICDF l p = negate (1 / l) * log (1 - p)