module Language.Hakaru.Sampler (Sampler, deterministic, sbind, smap) where
import Language.Hakaru.Mixture (Mixture, mnull, empty, scale, point)
import Language.Hakaru.Distribution (choose)
import System.Random (RandomGen)
type Sampler a = forall g. (RandomGen g) => g -> (Mixture a, g)
deterministic :: Mixture a -> Sampler a
deterministic m g = (m, g)
sbind :: Sampler a -> (a -> Sampler b) -> Sampler b
sbind s k g0 =
case s g0 of { (m1, g1) ->
if mnull m1 then (empty, g1) else
case choose m1 g1 of { (a, v, g2) ->
case k a g2 of { (m2, g) ->
(scale v m2, g) } } }
smap :: (a -> b) -> Sampler a -> Sampler b
smap f s = sbind s (\a -> deterministic (point (f a) 1))