{-# LANGUAGE GeneralizedNewtypeDeriving #-}

{- | An IO-based sampling monad.
-}

module Sampler (
  -- * Sampler monad
    Sampler
  , liftS
  , sampleIO
  , sampleIOFixed
  , createSampler
  -- * Sampling functions
  -- $Sampling-functions
  , sampleRandom
  , sampleCauchy
  , sampleNormal
  , sampleUniform
  , sampleDiscreteUniform
  , sampleGamma
  , sampleBeta
  , sampleBernoulli
  , sampleBinomial
  , sampleCategorical
  , sampleDiscrete
  , samplePoisson
  , sampleDirichlet
  ) where

import Control.Monad ( replicateM )
import Control.Monad.Trans (MonadIO, MonadTrans, lift)
import Control.Monad.Trans.Reader (ReaderT, ask, mapReaderT, runReaderT)
import Data.Map (Map)
import GHC.Word ( Word32 )
import qualified Data.Vector as V
import qualified System.Random.MWC as MWC
import qualified System.Random.MWC.Distributions as MWC.Dist
import qualified System.Random.MWC.Probability as MWC.Probability
import Statistics.Distribution ( ContGen(genContVar) )
import Statistics.Distribution.CauchyLorentz ( cauchyDistribution )
import System.Random.MWC ( initialize )

-- | Sampler type, for running @IO@ computations alongside a random number
-- generator.
--
-- A @Sampler a@ is an @IO@ action that can read a shared 'MWC.GenIO'
-- (through 'ReaderT') and produces a value of type @a@. Run it with
-- 'sampleIO' or 'sampleIOFixed', which supply the generator.
newtype Sampler a = Sampler {runSampler :: ReaderT MWC.GenIO IO a}
  -- All instances are inherited from the underlying 'ReaderT' via
  -- GeneralizedNewtypeDeriving. 'MonadIO' is included so arbitrary @IO@ can
  -- be embedded with 'liftIO' as well as with 'liftS'.
  deriving (Functor, Applicative, Monad, MonadIO)

-- | Lift an @IO@ computation into @Sampler@.
--
-- The lifted action does not see the generator; use 'createSampler' for
-- actions that need it.
liftS :: IO a -> Sampler a
liftS = Sampler . lift

-- | Takes a @Sampler@, provides it a random generator, and runs the sampler in
-- the @IO@ context.
--
-- A new generator is created with 'MWC.createSystemRandom' on every call, so
-- results are not reproducible between runs (see 'sampleIOFixed' for that).
sampleIO :: Sampler a -> IO a
sampleIO m = MWC.createSystemRandom >>= runReaderT (runSampler m)

-- | Takes a @Sampler@, provides it a fixed generator, and runs the sampler in
-- the @IO@ context.
--
-- The generator comes from 'MWC.create', which uses a fixed seed, so every
-- call starts from the same random stream (useful for reproducible runs and
-- debugging).
sampleIOFixed :: Sampler a -> IO a
sampleIOFixed m = MWC.create >>= runReaderT (runSampler m)

-- | Takes a @Sampler@, provides it a generator seeded from an 'Int', and runs
-- the sampler in the @IO@ context. The same seed always produces the same
-- initial generator state.
--
-- The seed is narrowed to a single 'Word32' with 'fromIntegral', so it wraps
-- modulo 2^32. For example, @-1@ and @4294967295@ yield the same generator.
--
-- NOTE(review): this function is not listed in the module's export list.
sampleIOCustom :: Int -> Sampler a -> IO a
sampleIOCustom n m = do
  gen <- initialize (V.singleton seed)
  runReaderT (runSampler m) gen
  where
    seed = fromIntegral n :: Word32

-- | Takes a distribution which awaits a generator, and returns a @Sampler@.
--
-- When the resulting @Sampler@ runs, @f@ is applied to the sampler's shared
-- generator. The @sample*@ functions in this module have the required shape
-- once their parameters are supplied, e.g. @createSampler (sampleNormal 0 1)@.
createSampler :: (MWC.GenIO -> IO a) -> Sampler a
createSampler f = Sampler (ask >>= lift . f)

{- $Sampling-functions
  Given their distribution parameters, these functions await a generator and
  then sample a value from the distribution in the @IO@ monad.
-}

-- | Draw a uniformly distributed 'Double' using 'MWC.uniform'.
--
-- NOTE(review): mwc-random documents the floating-point range of 'uniform'
-- as (0,1]; confirm against the library version in use.
sampleRandom
  :: MWC.GenIO
  -> IO Double
sampleRandom = MWC.uniform

-- | Sample from a Cauchy (Lorentz) distribution, built with
-- 'cauchyDistribution' and sampled through 'genContVar'.
sampleCauchy
  :: Double -- ^ location
  -> Double -- ^ scale
  -> (MWC.GenIO -> IO Double)
sampleCauchy μ σ = genContVar (cauchyDistribution μ σ)

-- | Sample from a normal (Gaussian) distribution via 'MWC.Dist.normal'.
sampleNormal
  :: Double -- ^ mean
  -> Double -- ^ standard deviation
  -> (MWC.GenIO -> IO Double)
sampleNormal = MWC.Dist.normal

-- | Sample a 'Double' uniformly between the two bounds via 'MWC.uniformR'.
sampleUniform
  :: Double -- ^ lower-bound
  -> Double -- ^ upper-bound
  -> (MWC.GenIO -> IO Double)
-- Bounds are named lo/hi rather than min/max to avoid shadowing the Prelude.
sampleUniform lo hi = MWC.uniformR (lo, hi)

-- | Sample an 'Int' uniformly between the two bounds via 'MWC.uniformR'.
--
-- NOTE(review): mwc-random documents integral 'uniformR' ranges as inclusive
-- of both bounds; confirm for the library version in use.
sampleDiscreteUniform
  :: Int -- ^ lower-bound
  -> Int -- ^ upper-bound
  -> (MWC.GenIO -> IO Int)
-- Bounds are named lo/hi rather than min/max to avoid shadowing the Prelude.
sampleDiscreteUniform lo hi = MWC.uniformR (lo, hi)

-- | Sample from a gamma distribution (shape/scale parameterisation) via
-- 'MWC.Dist.gamma'.
sampleGamma
  :: Double -- ^ shape k
  -> Double -- ^ scale θ
  -> (MWC.GenIO -> IO Double)
sampleGamma = MWC.Dist.gamma

-- | Sample from a beta distribution via 'MWC.Dist.beta'.
sampleBeta
  :: Double -- ^ shape α
  -> Double -- ^ shape β
  -> (MWC.GenIO -> IO Double)
sampleBeta = MWC.Dist.beta

-- | Sample a single Bernoulli trial via 'MWC.Dist.bernoulli'.
sampleBernoulli
  :: Double -- ^ probability of @True@
  -> (MWC.GenIO -> IO Bool)
sampleBernoulli = MWC.Dist.bernoulli

-- | Run @n@ independent Bernoulli trials, each succeeding with probability @p@.
--
-- Returns the outcome of every trial, in order, rather than the number of
-- successes; apply @length . filter id@ to obtain the binomial count. A
-- non-positive @n@ yields @[]@ (the behaviour of 'replicateM').
sampleBinomial
  :: Int    -- ^ number of trials
  -> Double -- ^ probability of successful trial
  -> (MWC.GenIO -> IO [Bool])
sampleBinomial n p gen = replicateM n (sampleBernoulli p gen)

-- | Sample a category index, where element @i@ of the vector is the
-- probability of category @i@. Delegates to 'MWC.Dist.categorical'.
--
-- NOTE(review): mwc-random appears to normalise by the total weight and to
-- reject an empty vector; confirm before relying on either behaviour.
sampleCategorical
  :: V.Vector Double -- ^ probabilities
  -> (MWC.GenIO -> IO Int)
sampleCategorical = MWC.Dist.categorical

-- | Like 'sampleCategorical', but takes the category probabilities as a list.
-- Returns the index (into @ps@) of the sampled category.
sampleDiscrete
  :: [Double] -- ^ probabilities
  -> (MWC.GenIO -> IO Int)
sampleDiscrete = sampleCategorical . V.fromList

-- | Sample from a Poisson distribution, using the @mwc-probability@
-- 'MWC.Probability.poisson' sampler run on the given generator.
samplePoisson
  :: Double   -- ^ rate λ
  -> (MWC.GenIO -> IO Int)
samplePoisson λ = MWC.Probability.sample (MWC.Probability.poisson λ)

-- | Sample a probability vector from a Dirichlet distribution with the given
-- concentration parameters via 'MWC.Dist.dirichlet'. The result is a list
-- with one entry per concentration.
sampleDirichlet
  :: [Double] -- ^ concentrations
  -> (MWC.GenIO -> IO [Double])
sampleDirichlet = MWC.Dist.dirichlet