{-# LANGUAGE
    MultiParamTypeClasses,
    RankNTypes,
    FlexibleInstances,
    FlexibleContexts,
    RecordWildCards,
    BangPatterns
  #-}

-- |A generic \"ziggurat algorithm\" implementation.  Fairly rough right
-- now.
--
-- There is a lot of room for improvement in 'findBin0' especially.
-- It needs a fair amount of cleanup and elimination of redundant
-- calculation, as well as either a justification for using the simple
-- 'findMinFrom' or a proper root-finding algorithm.
--
-- It would also be nice to add (preferably by pulling in an
-- external package) support for numerical integration and
-- differentiation, so that tables can be derived from only a
-- PDF (if the end user is willing to take the performance and
-- accuracy hit for the convenience).
module Data.Random.Distribution.Ziggurat
    ( Ziggurat(..)
    , mkZigguratRec
    , mkZiggurat
    , mkZiggurat_
    , findBin0
    , runZiggurat
    ) where

import Data.Random.Internal.Find

import Data.Random.Distribution.Uniform
import Data.Random.Distribution
import Data.Random.RVar
import Data.Vector.Generic as Vec
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as UV

-- |A data structure containing all the data that is needed
-- to implement Marsaglia & Tang's \"ziggurat\" algorithm for
-- sampling certain kinds of random distributions.
--
-- The documentation here is probably not sufficient to tell a user exactly
-- how to build one of these from scratch, but it is not really intended to
-- be.  There are several helper functions that will build 'Ziggurat's.
-- The pathologically curious may wish to read the 'runZiggurat' source.
-- That is the ultimate specification of the semantics of all these fields.
data Ziggurat v t = Ziggurat {
        -- |The X locations of each bin in the distribution.  Bin 0 is the
        -- 'infinite' one.
        --
        -- In the case of bin 0, the value given is sort of magical - x[0] is
        -- defined to be V/f(R).  It's not actually the location of any bin,
        -- but a value computed to make the algorithm more concise and slightly
        -- faster by not needing to specially-handle bin 0 quite as often.
        -- If you really need to know why it works, see the 'runZiggurat'
        -- source or \"the literature\" - it's a fairly standard setup.
        zTable_xs         :: !(v t),
        -- |The ratio of the next bin's X value to this bin's X value
        -- (@x[i+1] / x[i]@).  Despite the field name, 'mkZiggurat_' derives
        -- this table from 'zTable_xs', and 'runZiggurat' compares it against
        -- the magnitude of the uniform sample to decide whether a point is
        -- \"clearly inside\" the selected bin.
        zTable_y_ratios   :: !(v t),
        -- |The Y value (zFunc x) of each bin
        zTable_ys         :: !(v t),
        -- |An RVar providing a random tuple consisting of:
        --
        --  * a bin index, uniform over [0,c) :: Int (where @c@ is the
        --    number of bins in the tables)
        --
        --  * a uniformly distributed fractional value, from -1 to 1
        --    if mirrored, from 0 to 1 otherwise.
        --
        -- This is provided as a single 'RVar' because it can be implemented
        -- more efficiently than naively sampling 2 separate values - a
        -- single random word (64 bits) can be efficiently converted to
        -- a double (using 52 bits) and a bin number (using up to 12 bits),
        -- for example.
        zGetIU            :: !(forall m. RVarT m (Int, t)),

        -- |The distribution for the final \"virtual\" bin
        -- (the ziggurat algorithm does not handle distributions
        -- that wander off to infinity, so another distribution is needed
        -- to handle the last \"bin\" that stretches to infinity)
        --
        -- Note that this is the only non-strict field of the record.
        zTailDist         :: (forall m. RVarT m t),

        -- |A copy of the uniform RVar generator for the base type,
        -- so that @Distribution Uniform t@ is not needed when sampling
        -- from a Ziggurat (makes it a bit more self-contained).
        zUniform          :: !(forall m. t -> t -> RVarT m t),

        -- |The (one-sided antitone) PDF, not necessarily normalized
        zFunc             :: !(t -> t),

        -- |A flag indicating whether the distribution should be
        -- mirrored about the origin (the ziggurat algorithm in
        -- its native form only samples from one-sided distributions.
        -- By mirroring, we can extend it to symmetric distributions
        -- such as the normal distribution)
        zMirror           :: !Bool
    }

-- |Sample from the distribution encoded in a 'Ziggurat' data structure.
{-# INLINE runZiggurat #-}
{-# SPECIALIZE runZiggurat :: Ziggurat UV.Vector Float  -> RVarT m Float #-}
{-# SPECIALIZE runZiggurat :: Ziggurat UV.Vector Double -> RVarT m Double #-}
{-# SPECIALIZE runZiggurat :: Ziggurat  V.Vector Float  -> RVarT m Float #-}
{-# SPECIALIZE runZiggurat :: Ziggurat  V.Vector Double -> RVarT m Double #-}
runZiggurat :: (Num a, Ord a, Vector v a) =>
               Ziggurat v a -> RVarT m a
runZiggurat !Ziggurat{..} = go
    where
        {-# NOINLINE go #-}
        go = do
            -- Select a bin (I) and a uniform value (U) from -1 to 1
            -- (or 0 to 1 if not mirroring the distribution).
            -- Let X be U scaled to the size of the selected bin.
            (!i,!u) <- zGetIU

            -- if the uniform value U falls in the area "clearly inside" the
            -- bin, accept X immediately.
            -- Otherwise, depending on the bin selected, use either the
            -- tail distribution or an accept/reject test.
            if abs u < zTable_y_ratios ! i
                then return $! (u * zTable_xs ! i)
                else if i == 0
                    then sampleTail u
                    else sampleGreyArea i $! (u * zTable_xs ! i)

        -- when the sample falls in the "grey area" (the area between
        -- the Y values of the selected bin and the bin after that one),
        -- use an accept/reject method based on the target PDF.
        -- On rejection the whole procedure restarts from 'go'.
        {-# INLINE sampleGreyArea #-}
        sampleGreyArea i x = do
            !v <- zUniform (zTable_ys ! (i+1)) (zTable_ys ! i)
            if v < zFunc (abs x)
                then return $! x
                else go

        -- if the selected bin is the "infinite" one, call it quits and
        -- defer to the tail distribution (mirroring if needed to ensure
        -- the result has the sign already selected by zGetIU)
        {-# INLINE sampleTail #-}
        sampleTail x
            | zMirror && x < 0  = fmap negate zTailDist
            | otherwise         = zTailDist

-- |Build the tables to implement the \"ziggurat algorithm\" devised by
-- Marsaglia & Tang, from explicitly supplied R and V values (see
-- 'mkZiggurat' for a variant that searches for them via 'findBin0').
--
-- Arguments:
--
--  * flag indicating whether to mirror the distribution
--
--  * the (one-sided antitone) PDF, not necessarily normalized
--
--  * the inverse of the PDF
--
--  * the number of bins
--
--  * R, the x value of the first bin
--
--  * V, the volume of each bin
--
--  * an RVar providing the 'zGetIU' random tuple
--
--  * an RVar sampling from the tail (the region where x > R)
--
{-# INLINE mkZiggurat_ #-}
{-# SPECIALIZE mkZiggurat_ :: Bool -> (Float  ->  Float) -> (Float  ->  Float)
        -> Int -> Float  -> Float
        -> (forall m. RVarT m (Int, Float))  -> (forall m. RVarT m Float )
        -> Ziggurat UV.Vector Float #-}
{-# SPECIALIZE mkZiggurat_ :: Bool -> (Double -> Double) -> (Double -> Double)
        -> Int -> Double -> Double
        -> (forall m. RVarT m (Int, Double)) -> (forall m. RVarT m Double)
        -> Ziggurat UV.Vector Double #-}
{-# SPECIALIZE mkZiggurat_ :: Bool -> (Float  ->  Float) -> (Float  ->  Float)
        -> Int -> Float  -> Float
        -> (forall m. RVarT m (Int, Float))  -> (forall m. RVarT m Float )
        -> Ziggurat V.Vector Float #-}
{-# SPECIALIZE mkZiggurat_ :: Bool -> (Double -> Double) -> (Double -> Double)
        -> Int -> Double -> Double
        -> (forall m. RVarT m (Int, Double)) -> (forall m. RVarT m Double)
        -> Ziggurat V.Vector Double #-}
mkZiggurat_ :: (RealFloat t, Vector v t,
               Distribution Uniform t) =>
              Bool
              -> (t -> t)
              -> (t -> t)
              -> Int
              -> t
              -> t
              -> (forall m. RVarT m (Int, t))
              -> (forall m. RVarT m t)
              -> Ziggurat v t
mkZiggurat_ m f fInv c r v getIU tailDist = Ziggurat
    { zTable_xs         = xs
    , zTable_y_ratios   = precomputeRatios xs
    , zTable_ys         = Vec.map f xs
    , zGetIU            = getIU
    , zUniform          = uniformT
    , zFunc             = f
    , zTailDist         = tailDist
    , zMirror           = m
    }
    where
        -- the X table is computed once and shared by all derived tables
        xs = zigguratTable f fInv c r v

-- |Build the tables to implement the \"ziggurat algorithm\" devised by
-- Marsaglia & Tang, attempting to automatically compute the R and V
-- values.
--
-- Arguments are the same as for 'mkZigguratRec', with an additional
-- argument for the tail distribution as a function of the selected
-- R value.
mkZiggurat :: (RealFloat t, Vector v t,
               Distribution Uniform t) =>
              Bool
              -> (t -> t)
              -> (t -> t)
              -> (t -> t)
              -> t
              -> Int
              -> (forall m. RVarT m (Int, t))
              -> (forall m. t -> RVarT m t)
              -> Ziggurat v t
mkZiggurat m f fInv fInt fVol c getIU tailDist =
    mkZiggurat_ m f fInv c r v getIU (tailDist r)
    where
        -- R and V are searched for rather than supplied by the caller
        (r,v) = findBin0 c f fInv fInt fVol

-- |Build a lazy recursive ziggurat.  Uses a lazily-constructed ziggurat
-- as its tail distribution (with another as its tail, ad nauseam).
--
-- Arguments:
--
--  * flag indicating whether to mirror the distribution
--
--  * the (one-sided antitone) PDF, not necessarily normalized
--
--  * the inverse of the PDF
--
--  * the integral of the PDF (definite, from 0)
--
--  * the estimated volume under the PDF (from 0 to +infinity)
--
--  * the chunk size (number of bins in each layer).  64 seems to
--    perform well in practice.
--
--  * an RVar providing the 'zGetIU' random tuple
--
mkZigguratRec ::
  (RealFloat t, Vector v t,
   Distribution Uniform t) =>
  Bool
  -> (t -> t)
  -> (t -> t)
  -> (t -> t)
  -> t
  -> Int
  -> (forall m. RVarT m (Int, t))
  -> Ziggurat v t
mkZigguratRec m f fInv fInt fVol c getIU = z
    where
        -- a rank-2 fixed-point combinator, used to tie each tail
        -- sampler to the sampler for the tail beyond it
        fix :: ((forall m. a -> RVarT m a) -> (forall m. a -> RVarT m a))
            -> (forall m. a -> RVarT m a)
        fix g = g (fix g)

        -- 'z' is passed to 'mkTail' only to pin down the vector type of
        -- the nested ziggurats (see the 'asTypeOf' in 'mkTail')
        z = mkZiggurat m f fInv fInt fVol c getIU
                (fix (mkTail m f fInv fInt fVol c getIU z))

-- Build a sampler for the tail beyond @r@: construct a ziggurat for the
-- target function shifted left by @r@, sample from it, and shift the
-- result back out by @r@ (away from the origin, preserving its sign).
--
-- The 'Ziggurat' argument is used only as a type witness; the
-- @(forall m. a -> RVarT m a)@ argument samples the tail of the
-- shifted ziggurat in turn.
mkTail ::
    (RealFloat a, Vector v a, Distribution Uniform a) =>
    Bool
    -> (a -> a) -> (a -> a) -> (a -> a)
    -> a
    -> Int
    -> (forall m. RVarT m (Int, a))
    -> Ziggurat v a
    -> (forall m. a -> RVarT m a)
    -> (forall m. a -> RVarT m a)
mkTail m f fInv fInt fVol c getIU typeRep nextTail r = do
    x <- rvarT (mkZiggurat m f' fInv' fInt' fVol' c getIU nextTail `asTypeOf` typeRep)
    return (shiftOut x)
    where
        -- Move the sample away from the origin by r.  An explicit sign
        -- test is used instead of @x + r * signum x@ so that a sample of
        -- exactly zero lands on the tail boundary r rather than on 0,
        -- which is not in the tail at all.
        shiftOut x
            | x < 0     = x - r
            | otherwise = x + r

        fIntR = fInt r

        -- the target function, its inverse, its integral and its total
        -- volume, all re-expressed relative to the new origin at r
        f' x    | x < 0     = f r
                | otherwise = f (x+r)
        fInv'               = subtract r . fInv
        fInt' x | x < 0     = 0
                | otherwise = fInt (x+r) - fIntR

        fVol' = fVol - fIntR

-- Compute the table of bin X values (see 'zigguratXs') as a vector.
zigguratTable :: (Fractional a, Vector v a, Ord a) =>
                 (a -> a) -> (a -> a) -> Int -> a -> a -> v a
zigguratTable f fInv c r v = case zigguratXs f fInv c r v of
    (xs, _excess) -> fromList xs

-- Compute only the \"excess\" figure for a candidate R and V
-- (see 'zigguratXs'); 'findBin0' searches on this value.
zigguratExcess :: (Fractional a, Ord a) => (a -> a) -> (a -> a) -> Int -> a -> a -> a
zigguratExcess f fInv c r v = snd (zigguratXs f fInv c r v)

-- Compute the bin X values for @c@ bins given R and V, together with
-- the excess of the innermost bin.
--
-- The resulting list has @c + 1@ entries:
--
--  * entry 0 is @v / f r@,
--
--  * entry 1 is @r@,
--
--  * each following entry is derived from its predecessor @x@ as
--    @fInv (f x + v / x)@ (or @-1@ once a predecessor is not positive),
--
--  * the final entry (index @c@, when @c >= 2@) is 0.
--
-- The excess is @x * (f 0 - f x) - v@ for the entry @x@ at index @c - 1@.
zigguratXs :: (Fractional a, Ord a) => (a -> a) -> (a -> a) -> Int -> a -> a -> ([a], a)
zigguratXs f fInv c r v = (xs, excess)
    where
        x0 = v / f r

        -- entries 1 .. c-1, built in a single pass instead of by repeated
        -- list indexing
        interior = take (c - 1) (iterate step r)

        -- the small cases reproduce what mapping over [0..c] would give
        xs  | c <  0    = []
            | c == 0    = [x0]
            | c == 1    = [x0, r]
            | otherwise = x0 : interior ++ [0]

        step x
            | x <= 0    = -1
            | otherwise = fInv (f x + (v / x))

        -- a single lookup; still fails for c <= 0, where there is no
        -- innermost bin to measure
        xLast  = xs !! (c-1)
        excess = xLast * (f 0 - f xLast) - v

-- Build the table of ratios between consecutive X values,
-- @xs!(i+1) / xs!i@, which has one entry fewer than the X table.
precomputeRatios :: (Vector v a, Fractional a) => v a -> v a
precomputeRatios xs = generate (c-1) $ \i -> xs!(i+1) / xs!i
    where
        c = Vec.length xs

-- |I suspect this isn't completely right, but it works well so far.
-- Search the distribution for an appropriate R and V.
--
-- Arguments:
--
--  * Number of bins
--
--  * target function (one-sided antitone PDF, not necessarily normalized)
--
--  * function inverse
--
--  * function definite integral (from 0 to _)
--
--  * estimate of total volume under function (integral from 0 to infinity)
--
-- Result: (R,V)
findBin0 :: (RealFloat b) =>
    Int -> (b -> b) -> (b -> b) -> (b -> b) -> b -> (b, b)
findBin0 cInt f fInv fInt fVol = (rMin,v rMin)
    where
        c = fromIntegral cInt

        -- volume of bin 0 for a given R: the rectangle r * f r plus
        -- whatever volume lies beyond r
        v r = r * f r + fVol - fInt r

        -- initial R guess:
        r0 = findMin (\r -> v r <= fVol / c)
        -- find a better R:
        -- (a NaN excess is explicitly treated as failing the predicate)
        rMin = findMinFrom r0 1 $ \r ->
            let e = exc r
             in e >= 0 && not (isNaN e)

        exc x = zigguratExcess f fInv cInt x (v x)

-- Sampling a 'Ziggurat' as a distribution is just 'runZiggurat'.
instance (Num t, Ord t, Vector v t) => Distribution (Ziggurat v) t where
    rvar = runZiggurat