{-
 -      ``Data/Random/Distribution/Ziggurat''
 -
 -  A generic "ziggurat algorithm" implementation.  Fairly rough right
 -  now.
 -
 -  There is a lot of room for improvement in 'findBin0' especially.
 -  It needs a fair amount of cleanup and elimination of redundant
 -  calculation, as well as either a justification for using the simple
 -  'findMinFrom' or a proper root-finding algorithm.
 -
 -  It would also be nice to add (preferably via its own library)
 -  support for numerical integration and differentiation, so that
 -  tables can be derived from only a PDF (if the end user is
 -  willing to take the performance hit for the convenience).
 -}
{-# LANGUAGE
        MultiParamTypeClasses,
        FlexibleInstances, FlexibleContexts,
        RecordWildCards
  #-}

module Data.Random.Distribution.Ziggurat
    ( Ziggurat(..)
    , mkZigguratRec
    , mkZiggurat
    , mkZiggurat_
    , findBin0
    , runZiggurat
    ) where

import Data.Random.Internal.Find

import Data.Random.Distribution.Uniform
import Data.Random.Distribution
import Data.Random.RVar
import Data.StorableVector as Vec
import Foreign.Storable

-- |Element lookup in a storable vector, as an infix operator.
-- No fixity is declared, so it takes the default (infixl 9) and binds
-- tighter than the arithmetic and comparison operators used below.
(!) :: Storable a => Vector a -> Int -> a
vec ! i = index vec i

-- |The precomputed tables and samplers consumed by 'runZiggurat'.
data Ziggurat t = Ziggurat
    { -- |Per-bin x coordinates; the fraction from 'zGetIU' is multiplied
      -- by the selected entry to obtain the candidate sample.
      zTable_xs         :: Vector t
      -- |Per-bin thresholds: a candidate whose fraction has an absolute
      -- value below the selected entry is accepted immediately.
    , zTable_x_ratios   :: Vector t
      -- |Per-bin y values; entries @i+1@ and @i@ are the bounds handed
      -- to 'zUniform' during the accept/reject test for bin @i@.
    , zTable_ys         :: Vector t
      -- |Source of a (bin index, fraction) pair.
    , zGetIU            :: RVar (Int, t)
      -- |Distribution sampled when bin 0 is selected and the quick
      -- acceptance test fails.
    , zTailDist         :: RVar t
      -- |Sampler taking two bounds, used to draw the comparison height
      -- in the accept/reject test.
    , zUniform          :: t -> t -> RVar t
      -- |Function evaluated at @abs x@ in the accept/reject test.
    , zFunc             :: t -> t
      -- |Mirroring flag.  'runZiggurat' does not consult it; the sign of
      -- the fraction from 'zGetIU' alone decides the sign of the result.
    , zMirror           :: Bool
    }

-- |Draw one sample using the given tables.
--
-- A (bin, fraction) pair is drawn and the fraction is scaled by the
-- bin's x entry to give a candidate.  The candidate is returned at once
-- when the fraction is below the bin's ratio threshold.  Otherwise bin 0
-- defers to the tail distribution (negated when the candidate is
-- negative), and any other bin runs an accept/reject test, starting
-- over from scratch on rejection.
{-# INLINE runZiggurat #-}
runZiggurat :: (Num a, Ord a, Storable a) => Ziggurat a -> RVar a
runZiggurat Ziggurat{..} = draw
    where
        draw = do
            (bin, frac) <- zGetIU
            decide bin frac (frac * zTable_xs ! bin)

        decide bin frac candidate
            | abs frac < zTable_x_ratios ! bin  = return $! candidate
            | bin == 0                          = fromTail candidate
            | otherwise                         = acceptReject bin candidate

        -- the tail sample takes the sign of the candidate
        fromTail candidate
            | candidate < 0 = fmap negate zTailDist
            | otherwise     = zTailDist

        -- compare a height drawn between the two table y values around
        -- this bin against the function value at the candidate
        acceptReject bin candidate = do
            height <- zUniform (zTable_ys ! (bin + 1)) (zTable_ys ! bin)
            if height < zFunc (abs candidate)
                then return $! candidate
                else draw

-- |Build the tables to implement the "ziggurat algorithm" devised by
-- Marsaglia & Tsang, from explicitly supplied R and V values.
--
-- Arguments:
--
--  * flag indicating whether to mirror the distribution
--
--  * the (one-sided antitone) density function (PDF)
--
--  * the inverse of the PDF
--
--  * the number of bins
--
--  * R, the x value of the first bin
--
--  * V, the volume of each bin
--
--  * an RVar providing a random tuple consisting of:
--
--    - a bin index, uniform over [0,c) :: Int
--
--    - a uniformly distributed fractional value, from -1 to 1
--      if mirrored, from 0 to 1 otherwise ('runZiggurat' uses the sign
--      of this value to pick the sign of the result).
--
--  * an RVar sampling from the tail (the region where x > R)
mkZiggurat_ :: (RealFloat t, Storable t,
               Distribution Uniform t) =>
              Bool
              -> (t -> t)
              -> (t -> t)
              -> Int
              -> t
              -> t
              -> RVar (Int, t)
              -> RVar t
              -> Ziggurat t
mkZiggurat_ m f fInv c r v getIU tailDist = Ziggurat
    { zTable_xs         = xs
    , zTable_x_ratios   = precomputeRatios xs
    , zTable_ys         = Vec.map f xs
    , zGetIU            = getIU
    , zUniform          = uniform
    , zFunc             = f
    , zTailDist         = tailDist
    , zMirror           = m
    }
    where
        -- the x table is shared by both derived tables
        xs = zigguratTable f fInv c r v

-- |Build the tables to implement the "ziggurat algorithm" devised by
-- Marsaglia & Tsang, attempting to automatically compute the R and V
-- values (via 'findBin0').
--
-- Arguments are the same as for 'mkZigguratRec', with an additional
-- argument for the tail distribution as a function of the selected
-- R value.
mkZiggurat :: (RealFloat t, Storable t,
               Distribution Uniform t) =>
              Bool
              -> (t -> t)
              -> (t -> t)
              -> (t -> t)
              -> t
              -> Int
              -> RVar (Int, t)
              -> (t -> RVar t)
              -> Ziggurat t
mkZiggurat m f fInv fInt fVol c getIU tailDist =
    -- lazy tuple binding: R and V are only searched for once a table
    -- (or the tail) is actually demanded
    let (binR, binV) = findBin0 c f fInv fInt fVol
     in mkZiggurat_ m f fInv c binR binV getIU (tailDist binR)

-- |Build a lazy recursive ziggurat.  Uses a lazily-constructed ziggurat
-- as its tail distribution (with another as its tail, ad nauseam).
--
-- Arguments:
--
--  * flag indicating whether to mirror the distribution
--  * the (one-sided antitone) density function (PDF)
--  * the inverse of the PDF
--  * the integral of the PDF (definite, from 0)
--  * the estimated volume under the PDF (from 0 to +infinity)
--  * the chunk size (number of bins).  64 seems to perform well in practice.
--  * an RVar providing a random tuple consisting of:
--
--    - a bin index, uniform over [0,c) :: Int
--
--    - a uniformly distributed fractional value, from -1 to 1
--      if mirrored, from 0 to 1 otherwise.
mkZigguratRec :: (RealFloat t, Storable t,
                  Distribution Uniform t) =>
                 Bool
                 -> (t -> t)
                 -> (t -> t)
                 -> (t -> t)
                 -> t
                 -> Int
                 -> RVar (Int, t)
                 -> Ziggurat t
mkZigguratRec m f fInv fInt fVol c getIU =
    mkZiggurat m f fInv fInt fVol c getIU tailFrom
    where
        -- Tail sampler for a ziggurat whose first bin ends at r: build
        -- another ziggurat for the function shifted left by r (using
        -- this same sampler for *its* tail), sample it, and shift the
        -- result back out by r in the direction of its sign.
        tailFrom r = do
            x <- rvar (mkZiggurat m f' fInv' fInt' fVol' c getIU tailFrom)
            return (x + r * signum x)
            where
                fIntR = fInt r

                f' x    | x < 0     = f r
                        | otherwise = f (x + r)
                fInv' = subtract r . fInv
                fInt' x | x < 0     = 0
                        | otherwise = fInt (x + r) - fIntR

                fVol' = fVol - fIntR

-- |Compute the bin x coordinates (see 'zigguratXs') and pack them into
-- a storable vector.
zigguratTable :: (Fractional a, Storable a, Ord a) =>
                 (a -> a) -> (a -> a) -> Int -> a -> a -> Vector a
zigguratTable f fInv c r v = pack (fst (zigguratXs f fInv c r v))

-- |The \"excess\" component of 'zigguratXs' for the given R and V.
zigguratExcess :: (Fractional a, Ord a) =>
                  (a -> a) -> (a -> a) -> Int -> a -> a -> a
zigguratExcess f fInv c r v = snd (zigguratXs f fInv c r v)

-- |Compute the list of @c + 1@ bin x coordinates together with an
-- \"excess\" value.
--
-- Arguments: the function, its inverse, the number of bins @c@, R and V.
--
-- The coordinates are:
--
--  * index 0: @v / f r@
--
--  * index 1: @r@
--
--  * index @c@: 0
--
--  * any other index @i@: derived from the previous coordinate @x@ as
--    @fInv (f x + v / x)@, or -1 once the previous coordinate is no
--    longer positive.
--
-- The excess is @x * (f 0 - f x) - v@ where @x@ is the coordinate at
-- index @c - 1@.
zigguratXs :: (Fractional a, Ord a) =>
              (a -> a) -> (a -> a) -> Int -> a -> a -> ([a], a)
zigguratXs f fInv c r v = (xs, excess)
    where
        xs = Prelude.map x [0..c]
        ys = Prelude.map f xs

        -- Clause order matters: the literal cases win over the @i == c@
        -- guard.  The last case was previously written with an n+k
        -- pattern, which needs the NPlusKPatterns extension; for the
        -- indices reached here (>= 2) plain subtraction is equivalent.
        x 0 = v / f r
        x 1 = r
        x i | i == c    = 0
            | otherwise = next (i - 1)

        -- coordinate following the one at index i (elements of xs are
        -- shared lazily, so each is computed once)
        next i = let x_i = xs !! i
                  in if x_i <= 0 then -1 else fInv (ys !! i + (v / x_i))

        excess = xs !! (c-1) * (f 0 - ys !! (c-1)) - v

-- |For a table of @c@ x coordinates, build the table of @c - 1@ ratios
-- of each coordinate to its predecessor.
precomputeRatios :: (Fractional a, Storable a) => Vector a -> Vector a
precomputeRatios xs = sample (c-1) $ \i -> xs!(i+1) / xs!i
    where
        c = Vec.length xs

-- |I suspect this isn't completely right, but it works well so far.
-- Search the distribution for an appropriate R and V.
--
-- Arguments:
--
--  * Number of bins
--  * function (one-sided antitone PDF, not necessarily normalized)
--  * function inverse
--  * function definite integral (from 0 to _)
--  * estimate of total volume under function (integral from 0 to infinity)
--
-- Result: (R,V)
findBin0 cInt f fInv fInt fVol = (bestR, binVolume bestR)
    where
        bins = fromIntegral cInt

        -- V as a function of a candidate R
        binVolume x = x * f x + fVol - fInt x

        -- excess of the table built from a candidate R and its volume
        excessAt x = zigguratExcess f fInv cInt x (binVolume x)

        -- a candidate is acceptable when its excess is a non-negative
        -- number (NaN is rejected explicitly)
        acceptable x =
            let e = excessAt x
             in e >= 0 && not (isNaN e)

        -- initial R guess: bin volume no larger than an even share of
        -- the total volume
        guessR = findMin (\x -> binVolume x <= fVol / bins)

        -- find a better R, searching from the guess:
        bestR = findMinFrom guessR 1 acceptable

-- |Sampling a 'Ziggurat' as a distribution runs 'runZiggurat'.
instance (Num t, Ord t, Storable t) => Distribution Ziggurat t where
    rvar = runZiggurat