{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, FlexibleContexts, RecordWildCards #-}

-- |A generic \"ziggurat algorithm\" implementation.  Fairly rough right
-- now.
--
-- There is a lot of room for improvement in 'findBin0' especially.
-- It needs a fair amount of cleanup and elimination of redundant
-- calculation, as well as either a justification for using the simple
-- 'findMinFrom' or a proper root-finding algorithm.
--
-- It would also be nice to add (preferably by pulling in an
-- external package) support for numerical integration and
-- differentiation, so that tables can be derived from only a
-- PDF (if the end user is willing to take the performance and
-- accuracy hit for the convenience).
module Data.Random.Distribution.Ziggurat
    ( Ziggurat(..)
    , mkZigguratRec
    , mkZiggurat
    , mkZiggurat_
    , findBin0
    , runZiggurat
    ) where

import Data.Random.Internal.Find

import Data.Random.Distribution.Uniform
import Data.Random.Distribution
import Data.Random.RVar

import Data.StorableVector as Vec

import Foreign.Storable

-- |Infix shorthand for 'index' on storable vectors (default fixity,
-- infixl 9, so @u * xs ! i@ parses as @u * (xs ! i)@).
(!) :: Storable a => Vector a -> Int -> a
vec ! i = index vec i

-- |A data structure containing all the data that is needed
-- to implement Marsaglia & Tang's \"ziggurat\" algorithm for
-- sampling certain kinds of random distributions.
--
-- The documentation here is probably not sufficient to tell a user exactly
-- how to build one of these from scratch, but it is not really intended to
-- be.  There are several helper functions that will build 'Ziggurat's.
-- The pathologically curious may wish to read the 'runZiggurat' source.
-- That is the ultimate specification of the semantics of all these fields.
--
-- The fields are deliberately lazy: 'mkZiggurat_' builds the ratio and Y
-- tables from the X table of the same record, and 'mkZigguratRec' stores
-- a lazily-built chain of tail distributions.
data Ziggurat t = Ziggurat {
        -- |The X locations of each bin in the distribution.  Bin 0 is the
        -- 'infinite' one.
        --
        -- In the case of bin 0, the value given is sort of magical - x[0] is
        -- defined to be V/f(R).  It's not actually the location of any bin,
        -- but a value computed to make the algorithm more concise and slightly
        -- faster by not needing to specially-handle bin 0 quite as often.
        -- If you really need to know why it works, see the 'runZiggurat'
        -- source or \"the literature\" - it's a fairly standard setup.
        zTable_xs         :: Vector t,

        -- |The ratio of the next bin's X value to each bin's own X value,
        -- i.e. element @i@ is @x[i+1] / x[i]@.  'runZiggurat' accepts a
        -- sample immediately when the absolute value of its uniform
        -- fraction is below this ratio.
        zTable_x_ratios   :: Vector t,

        -- |The Y value (zFunc x) of each bin
        zTable_ys         :: Vector t,

        -- |An RVar providing a random tuple consisting of:
        --
        --  * a bin index, uniform over [0,c) :: Int (where @c@ is the
        --    number of bins in the tables)
        --
        --  * a uniformly distributed fractional value, from -1 to 1
        --    if mirrored, from 0 to 1 otherwise.
        --
        -- This is provided as a single 'RVar' because it can be implemented
        -- more efficiently than naively sampling 2 separate values - a
        -- single random word (64 bits) can be efficiently converted to
        -- a double (using 52 bits) and a bin number (using up to 12 bits),
        -- for example.
        zGetIU            :: RVar (Int, t),

        -- |The distribution for the final \"virtual\" bin
        -- (the ziggurat algorithm does not handle distributions
        -- that wander off to infinity, so another distribution is needed
        -- to handle the last \"bin\" that stretches to infinity)
        zTailDist         :: RVar t,

        -- |A copy of the uniform RVar generator for the base type,
        -- so that @Distribution Uniform t@ is not needed when sampling
        -- from a Ziggurat (makes it a bit more self-contained).
        zUniform          :: t -> t -> RVar t,

        -- |The (one-sided antitone) PDF, not necessarily normalized
        zFunc             :: t -> t,

        -- |A flag indicating whether the distribution should be
        -- mirrored about the origin (the ziggurat algorithm in
        -- its native form only samples from one-sided distributions.
        -- By mirroring, we can extend it to symmetric distributions
        -- such as the normal distribution)
        zMirror           :: Bool
    }

-- |Sample from the distribution encoded in a 'Ziggurat' data structure.
{-# INLINE runZiggurat #-}
runZiggurat :: (Num a, Ord a, Storable a) => Ziggurat a -> RVar a
runZiggurat Ziggurat{..} = draw
    where
        -- One attempt: take a bin and a uniform fraction (-1..1 when
        -- mirroring, 0..1 otherwise) from 'zGetIU', scale the fraction by
        -- that bin's X value, then decide what to do with the candidate.
        draw = zGetIU >>= \(bin, frac) ->
            classify bin frac (frac * zTable_xs ! bin)

        -- A fraction "clearly inside" the bin is accepted at once.  Outside
        -- that region, bin 0 defers to the tail distribution and every other
        -- bin falls back to an accept/reject test.
        classify bin frac candidate
            | abs frac < zTable_x_ratios ! bin = return $! candidate
            | bin == 0                         = fromTail candidate
            | otherwise                        = wedgeTest bin candidate

        -- Accept/reject in the strip between this bin's Y value and the
        -- next bin's Y value, judged against the target PDF; on rejection,
        -- start over from scratch.
        wedgeTest bin candidate = do
            height <- zUniform (zTable_ys ! (bin + 1)) (zTable_ys ! bin)
            if zFunc (abs candidate) > height
                then return $! candidate
                else draw

        -- The "infinite" bin: hand off to the tail distribution, negating
        -- its result when the candidate was negative so the sign already
        -- chosen by 'zGetIU' is kept.
        fromTail candidate
            | candidate < 0 = fmap negate zTailDist
            | otherwise     = zTailDist

-- |Build the tables to implement the \"ziggurat algorithm\" devised by
-- Marsaglia & Tang from caller-supplied R and V values (use 'mkZiggurat'
-- to have them computed automatically).
--
-- Arguments:
--
-- * flag indicating whether to mirror the distribution
--
-- * the (one-sided antitone) PDF, not necessarily normalized
--
-- * the inverse of the PDF
--
-- * the number of bins
--
-- * R, the x value at which the tail begins (entry 1 of the X table)
--
-- * V, the volume of each bin
--
-- * an RVar providing the 'zGetIU' random tuple
--
-- * an RVar sampling from the tail (the region where x > R)
--
mkZiggurat_ :: (RealFloat t, Storable t, Distribution Uniform t) =>
               Bool
               -> (t -> t) -> (t -> t)
               -> Int -> t -> t
               -> RVar (Int, t)
               -> RVar t
               -> Ziggurat t
mkZiggurat_ m f fInv c r v getIU tailDist =
    Ziggurat
        { zTable_xs       = xs
        , zTable_x_ratios = precomputeRatios xs
        , zTable_ys       = Vec.map f xs
        , zGetIU          = getIU
        , zUniform        = uniform
        , zFunc           = f
        , zTailDist       = tailDist
        , zMirror         = m
        }
    where
        -- The X table is computed once and shared by the ratio and Y tables.
        xs = zigguratTable f fInv c r v

-- |Build the tables to implement the \"ziggurat algorithm\" devised by
-- Marsaglia & Tang, attempting to automatically compute the R and V
-- values (via 'findBin0').
--
-- Arguments are the same as for 'mkZigguratRec', with an additional
-- argument for the tail distribution as a function of the selected
-- R value.
mkZiggurat :: (RealFloat t, Storable t, Distribution Uniform t) =>
              Bool
              -> (t -> t) -> (t -> t) -> (t -> t) -> t
              -> Int
              -> RVar (Int, t)
              -> (t -> RVar t)
              -> Ziggurat t
mkZiggurat m f fInv fInt fVol c getIU tailDist =
    let (r, v) = findBin0 c f fInv fInt fVol
    in  mkZiggurat_ m f fInv c r v getIU (tailDist r)

-- |Build a lazy recursive ziggurat.  Uses a lazily-constructed ziggurat
-- as its tail distribution (with another as its tail, ad nauseam).
--
-- Arguments:
--
-- * flag indicating whether to mirror the distribution
--
-- * the (one-sided antitone) PDF, not necessarily normalized
--
-- * the inverse of the PDF
--
-- * the integral of the PDF (definite, from 0)
--
-- * the estimated volume under the PDF (from 0 to +infinity)
--
-- * the chunk size (number of bins in each layer).  64 seems to
--   perform well in practice.
--
-- * an RVar providing the 'zGetIU' random tuple
--
mkZigguratRec :: (RealFloat t, Storable t, Distribution Uniform t) =>
                 Bool
                 -> (t -> t) -> (t -> t) -> (t -> t) -> t
                 -> Int
                 -> RVar (Int, t)
                 -> Ziggurat t
mkZigguratRec m f fInv fInt fVol c getIU =
    mkZiggurat m f fInv fInt fVol c getIU tailBeyond
    where
        -- Tail past R: sample a recursive ziggurat built for this layer's
        -- PDF shifted left by R, then push the sample back out by R.
        --
        -- The nested ziggurat is built from the *shifted* functions (by
        -- recursing through 'mkZigguratRec'), so its own tail, past R + R',
        -- is again derived from this layer's PDF.  (Closing the recursion
        -- over the top-level functions instead would make deeper tails
        -- sample from f(x + R') rather than f(x + R + R').)
        tailBeyond r = do
            x <- rvar (mkZigguratRec m f' fInv' fInt' fVol' c getIU)
            -- Keep the sign of the inner sample; an inner sample of exactly
            -- 0 maps to +R rather than 0, so the result never lands
            -- strictly inside (-R, R).
            return (if x < 0 then x - r else x + r)
          where
            fIntR = fInt r

            -- PDF, inverse, integral and volume of the region beyond R,
            -- re-expressed with R as the new origin.
            f' y | y < 0     = f r
                 | otherwise = f (y + r)
            fInv' = subtract r . fInv
            fInt' y | y < 0     = 0
                    | otherwise = fInt (y + r) - fIntR
            fVol' = fVol - fIntR

-- |Build the X table for a ziggurat with @c@ bins from R and V (see
-- 'zigguratXs' for how the entries are derived).
zigguratTable :: (Fractional a, Storable a, Ord a) =>
                 (a -> a) -> (a -> a) -> Int -> a -> a -> Vector a
zigguratTable f fInv c r v = pack (fst (zigguratXs f fInv c r v))

-- |The \"excess\" (see 'zigguratXs') for a candidate R and V.  'findBin0'
-- searches for an R that makes this non-negative.
zigguratExcess :: (Fractional a, Ord a) =>
                  (a -> a) -> (a -> a) -> Int -> a -> a -> a
zigguratExcess f fInv c r v = snd (zigguratXs f fInv c r v)

-- |Compute the @c + 1@ X-table entries for a ziggurat with @c@ bins,
-- together with the excess of the topmost bin.
--
-- * x[0] = V / f(R): not a real bin edge, but chosen so 'runZiggurat' can
--   treat bin 0 like the others.
--
-- * x[1] = R.
--
-- * x[i+1] = fInv (f(x[i]) + V / x[i]) for the intermediate entries, so
--   that x[i] * (f(x[i+1]) - f(x[i])) = V; once an entry is <= 0, the
--   next one is the sentinel -1.
--
-- * x[c] = 0 (for c >= 2; with c == 1 the R entry takes precedence).
--
-- The excess is x[c-1] * (f 0 - f(x[c-1])) - V, i.e. the volume of the
-- topmost bin minus V.
--
-- Entries are looked up with '!!', which is quadratic in @c@.  That is
-- acceptable for the small bin counts used here.
zigguratXs :: (Fractional a, Ord a) =>
              (a -> a) -> (a -> a) -> Int -> a -> a -> ([a], a)
zigguratXs f fInv c r v = (xs, excess)
    where
        xs = Prelude.map x [0 .. c]
        ys = Prelude.map f xs

        x 0 = v / f r
        x 1 = r
        x i | i == c    = 0
            -- (formerly the n+k pattern @x (i+1) = next i@, which is
            -- rejected by Haskell 2010 without NPlusKPatterns)
            | otherwise = next (i - 1)

        next i =
            let x_i = xs !! i
            in  if x_i <= 0 then -1 else fInv (ys !! i + v / x_i)

        excess = xs !! (c - 1) * (f 0 - ys !! (c - 1)) - v

-- |Ratios of consecutive table entries: element @i@ is
-- @xs ! (i+1) / xs ! i@, for @i@ from 0 to @length xs - 2@.
precomputeRatios :: (Fractional a, Storable a) => Vector a -> Vector a
precomputeRatios xs = sample (n - 1) (\i -> xs ! (i + 1) / xs ! i)
    where n = Vec.length xs

-- |Search the distribution for an appropriate R and V.
--
-- I suspect this isn't completely right, but it works well so far.
--
-- Arguments:
--
-- * Number of bins
--
-- * target function (one-sided antitone PDF, not necessarily normalized)
--
-- * function inverse
--
-- * function definite integral (from 0 to _)
--
-- * estimate of total volume under function (integral from 0 to infinity)
--
-- Result: (R,V)
findBin0 :: (RealFloat b) =>
            Int -> (b -> b) -> (b -> b) -> (b -> b) -> b -> (b, b)
findBin0 cInt f fInv fInt fVol = (rBest, binVolume rBest)
    where
        bins = fromIntegral cInt

        -- Volume of bin 0 for a candidate R: the rectangle of width R and
        -- height f(R), plus the area under the PDF beyond R.
        binVolume x = x * f x + fVol - fInt x

        -- Initial guess via 'findMin': an R whose bin-0 volume is at most an
        -- even (1/c) share of the total volume.
        -- NOTE(review): exact search semantics live in
        -- Data.Random.Internal.Find.
        rGuess = findMin (\x -> binVolume x <= fVol / bins)

        -- Refinement via 'findMinFrom', starting at the guess, looking for an
        -- R whose table leaves a non-negative, non-NaN top-bin excess.
        rBest = findMinFrom rGuess 1 acceptable

        acceptable x = e >= 0 && not (isNaN e)
            where e = excessAt x

        excessAt x = zigguratExcess f fInv cInt x (binVolume x)

instance (Num t, Ord t, Storable t) => Distribution Ziggurat t where
    rvar = runZiggurat