module Data.Stochastic.Types (
Distribution (..)
, Sampleable (..)
, Sample (..)
, StochProcess (..)
, Sampler (..)
, Mean (..)
, StDev (..)
, marsagliaTsang
) where
import Control.Monad
import Control.Monad.State
import Control.Monad.Writer
import Data.Stochastic.Internal
import qualified Data.Sequence as S
import System.Random
-- | The probability distributions this module knows how to sample.
--
-- The type index is the type of a single sample drawn from the
-- distribution.
data Distribution t where
  -- | Normal distribution: a Box-Muller draw scaled by the standard
  -- deviation and shifted by the mean.
  Normal :: Mean -> StDev -> Distribution Double
  -- | Yields 'True' when a uniform draw is at most the given probability.
  Bernoulli :: Double -> Distribution Bool
  -- | Finite distribution over @(value, probability)@ pairs; the
  -- probabilities are expected to be normalized (sum to 1).
  Discrete :: [(t, Double)] -> Distribution t
  -- | Picks one element of the list, each with equal probability.
  DiscreteUniform :: [t] -> Distribution t
  -- | A single uniform draw produced by @closedRnd@ (range presumably
  -- [0, 1] -- confirm in "Data.Stochastic.Internal").
  Uniform :: Distribution Double
  -- | Always yields the wrapped value.
  Certain :: t -> Distribution t
  -- | Gamma distribution with shape @alpha@ and scale @beta@; both must
  -- be strictly positive.
  Gamma :: Double -> Double -> Distribution Double
  -- | Beta distribution with shape parameters @alpha@ and @beta@, built
  -- from two Gamma draws.
  Beta :: Double -> Double -> Distribution Double
-- | Type constructors whose values can be sampled with a random generator.
class Sampleable dist where
  -- | Lift a plain value into the sampleable type (for 'Distribution'
  -- this is the 'Certain' constructor).
  certainDist :: v -> dist v

  -- | Draw one value, returning it together with the generator to use for
  -- subsequent draws.
  sampleFrom :: (RandomGen gen) => dist v -> gen -> (v, gen)
-- | Sampling for every 'Distribution' constructor.
--
-- Each case draws from the supplied generator and returns the sample
-- together with the generator left after its LAST draw, so consecutive
-- samples never reuse random values.
instance Sampleable Distribution where
  sampleFrom da g =
    case da of
      Normal mean stdev ->
        -- Box-Muller consumes two draws; hand back the generator after the
        -- second one (g''), not the intermediate g'.
        let (a, g') = closedRnd g
            (a', g'') = closedRnd g'
            s = (stdev * boxMuller a a') + mean
        in (s, g'')
      Bernoulli prob ->
        let (a, g') = closedRnd g
        in (a <= prob, g')
      Discrete [] ->
        error "cannot sample from empty discrete distribution"
      Discrete l ->
        let (a, g') = closedRnd g
            -- Subtract each entry's probability from the draw until the
            -- remainder falls inside an entry's mass. Running off the end
            -- means the probabilities summed to less than the draw.
            scan _ [] = error "not normalized discrete dist"
            scan lim ((v, p) : rest)
              | lim <= p = v
              | otherwise = scan (lim - p) rest
        in (scan a l, g')
      DiscreteUniform [] ->
        error "cannot sample from empty discrete distribution"
      DiscreteUniform l ->
        let (a, g') = closedRnd g
            n = length l
            -- closedRnd presumably may return exactly 1; clamp so that
            -- draw maps to the last element instead of indexing past it.
            idx = min (n - 1) (floor (a * fromIntegral n))
        in (l !! idx, g')
      Uniform ->
        closedRnd g
      Gamma alpha beta
        | alpha <= 0 || beta <= 0 ->
            error "alpha and beta parameter cannot be less than or equal to zero in gamma distribution"
        | alpha < 1 ->
            -- Small shape: sample Gamma(alpha + 1) and scale it by
            -- U ** (1 / alpha), with U drawn from openRnd.
            let (a, g') = sampleFrom (Gamma (alpha + 1) beta) g
                (uni, g'') = openRnd g'
            in (a * (uni ** (1 / alpha)), g'')
        | otherwise ->
            -- Marsaglia-Tsang setup: d = alpha - 1/3, c = 1 / sqrt (9 d);
            -- the unit-scale result is multiplied by the scale beta.
            let d = alpha - (1 / 3)
                c = 1 / sqrt (9 * d)
                (m, g') = marsagliaTsang d c g
            in (m * beta, g')
      Beta alpha beta ->
        -- X ~ Gamma(alpha, 1), Y ~ Gamma(beta, 1)  =>  X / (X + Y).
        let (x, g') = sampleFrom (Gamma alpha 1) g
            (y, g'') = sampleFrom (Gamma beta 1) g'
        in (x / (x + y), g'')
      Certain val ->
        -- The value is fixed, but the generator is still advanced by one
        -- openRnd draw.
        (val, snd $ openRnd g)

  certainDist = Certain
-- | Marsaglia-Tsang rejection sampler for a unit-scale Gamma variate.
--
-- Called with @d = alpha - 1/3@ and @c = 1 / sqrt (9 * d)@ (as the 'Gamma'
-- case of 'sampleFrom' does), it returns a Gamma(alpha, 1) draw and the
-- advanced generator. It draws a standard normal @x@ and a uniform @u@,
-- sets @v = (1 + c x)^3@, and accepts @d * v@ when @x > -1/c@ (so @v > 0@)
-- and @log u < x^2/2 + d - d v + d log v@. Otherwise it retries with the
-- advanced generator.
marsagliaTsang :: (RandomGen g) => Double -> Double -> g -> (Double, g)
marsagliaTsang d c g =
  let (norm, g') = sampleFrom (Normal 0 1) g
      (uni, g'') = sampleFrom Uniform g'
      v = (1 + (c * norm)) ** 3
      -- '&&' short-circuits, so 'log v' is only evaluated once v > 0.
      accepted =
        norm > ((-1) / c)
          && log uni < ((norm ** 2) / 2 + d - (d * v) + (d * log v))
  in if accepted
       then (d * v, g'')
       else marsagliaTsang d c g''
-- | Renders a distribution as its constructor name followed by its
-- parameters, e.g. @Normal 0.0 1.0@.
--
-- Every constructor is covered; previously 'Uniform', 'Gamma' and 'Beta'
-- fell through the match and crashed at runtime.
instance (Show a) => Show (Distribution a) where
  show da = case da of
    Normal mean stdev -> "Normal " ++ show mean ++ " " ++ show stdev
    Bernoulli prob -> "Bernoulli " ++ show prob
    Discrete l -> "Discrete " ++ show l
    DiscreteUniform l -> "DiscreteUniform " ++ show l
    Uniform -> "Uniform"
    Certain val -> "Certain " ++ show val
    Gamma alpha beta -> "Gamma " ++ show alpha ++ " " ++ show beta
    Beta alpha beta -> "Beta " ++ show alpha ++ " " ++ show beta
-- | A sampling computation: a 'State' action over a generator of type
-- @gen@ that yields a distribution @dist v@. Binding it draws a concrete
-- value of type @v@ from that distribution.
newtype Sample gen dist v = Sample
  { runSample :: (RandomGen gen, Sampleable dist) => State gen (dist v)
  }
-- | A 'StdGen'-driven sampling computation that records a sequence of
-- 'Double's through 'WriterT' and yields a 'Double'.
type StochProcess = WriterT (S.Seq Double) (Sample StdGen Distribution) Double
-- | Sequencing of sampling computations.
--
-- 'return' advances the generator by one 'next' step and wraps the value
-- with 'certainDist'. '>>=' advances the generator, runs the left-hand
-- computation to obtain a distribution, draws one value from it, stores
-- the generator produced by that draw, and continues with the right-hand
-- computation.
--
-- NOTE(review): because 'return' advances the generator, the monad laws
-- hold only up to the generator state.
instance (RandomGen g, Sampleable d) => Monad (Sample g d) where
  return x = Sample $ do
    modify (snd . next)
    return $ certainDist x
  ma >>= f = Sample $ do
    modify (snd . next)
    dist <- runSample ma
    g <- get
    let (a, g') = sampleFrom dist g
    -- Keep the generator returned by the draw. Discarding it made later
    -- draws reuse the very random values that produced 'a'.
    put g'
    runSample (f a)
-- | Maps over the sampled value by binding it and returning the mapped
-- result (the definition 'liftM' expands to).
instance (RandomGen gen, Sampleable dist) => Functor (Sample gen dist) where
  fmap f m = m >>= (return . f)
-- | Applicative operations defined via the 'Monad' instance: 'pure' is
-- 'return', and '<*>' runs the function computation, then the argument
-- computation, and applies one to the other (the definition 'ap' expands
-- to).
instance (RandomGen gen, Sampleable dist) => Applicative (Sample gen dist) where
  pure x = return x
  mf <*> mx = do
    f <- mf
    x <- mx
    return (f x)
-- | A 'Sample' computation fixed to 'StdGen' and 'Distribution'.
type Sampler v = Sample StdGen Distribution v

-- | Mean parameter of a 'Normal' distribution.
type Mean = Double

-- | Standard-deviation parameter of a 'Normal' distribution.
type StDev = Double