{-# LANGUAGE
        MultiParamTypeClasses,
        RankNTypes,
        FlexibleInstances, FlexibleContexts,
        RecordWildCards, BangPatterns
  #-}

-- |A generic \"ziggurat algorithm\" implementation.  Fairly rough right
--  now.
--
--  There is a lot of room for improvement in 'findBin0' especially.
--  It needs a fair amount of cleanup and elimination of redundant
--  calculation, as well as either a justification for using the simple
--  'findMinFrom' or a proper root-finding algorithm.
--
--  It would also be nice to add (preferably by pulling in an
--  external package) support for numerical integration and
--  differentiation, so that tables can be derived from only a
--  PDF (if the end user is willing to take the performance and
--  accuracy hit for the convenience).
module Data.Random.Distribution.Ziggurat
    ( Ziggurat(..)
    , mkZigguratRec
    , mkZiggurat
    , mkZiggurat_
    , findBin0
    , runZiggurat
    ) where

import Data.Random.Internal.Find

import Data.Random.Distribution.Uniform
import Data.Random.Distribution
import Data.Random.RVar
import Data.Vector.Generic as Vec
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as UV

-- |A data structure containing all the data that is needed
-- to implement Marsaglia & Tsang's \"ziggurat\" algorithm for
-- sampling certain kinds of random distributions.
--
-- The documentation here is probably not sufficient to tell a user exactly
-- how to build one of these from scratch, but it is not really intended to
-- be.  There are several helper functions that will build 'Ziggurat's.
-- The pathologically curious may wish to read the 'runZiggurat' source.
-- That is the ultimate specification of the semantics of all these fields.
data Ziggurat v t = Ziggurat
    { -- |The X locations of each bin in the distribution.  Bin 0 is the
      -- 'infinite' one.
      --
      -- In the case of bin 0, the value given is sort of magical - x[0] is
      -- defined to be V/f(R).  It's not actually the location of any bin,
      -- but a value computed to make the algorithm more concise and slightly
      -- faster by not needing to specially-handle bin 0 quite as often.
      -- If you really need to know why it works, see the 'runZiggurat'
      -- source or \"the literature\" - it's a fairly standard setup.
      zTable_xs         :: !(v t)

    , -- |The quick-accept threshold for each bin: the ratio of the next
      -- bin's X value to this bin's X value, @x[i+1] / x[i]@.  Despite the
      -- field's name, the helper constructors compute it from the X table.
      -- 'runZiggurat' accepts @U * x[i]@ immediately whenever
      -- @abs U < zTable_y_ratios ! i@.
      zTable_y_ratios   :: !(v t)

    , -- |The Y value (zFunc x) of each bin
      zTable_ys         :: !(v t)

    , -- |An RVar providing a random tuple consisting of:
      --
      --  * a bin index, uniform over [0,c) :: Int (where @c@ is the
      --    number of bins in the tables)
      --
      --  * a uniformly distributed fractional value, from -1 to 1
      --    if not mirrored, from 0 to 1 otherwise.
      --
      -- This is provided as a single 'RVar' because it can be implemented
      -- more efficiently than naively sampling 2 separate values - a
      -- single random word (64 bits) can be efficiently converted to
      -- a double (using 52 bits) and a bin number (using up to 12 bits),
      -- for example.
      zGetIU            :: !(forall m. RVarT m (Int, t))

    , -- |The distribution for the final \"virtual\" bin
      -- (the ziggurat algorithm does not handle distributions
      -- that wander off to infinity, so another distribution is needed
      -- to handle the last \"bin\" that stretches to infinity)
      --
      -- NOTE: this is the only non-strict field, presumably so that
      -- 'mkZigguratRec' can build its chain of tail ziggurats lazily;
      -- confirm before adding a strictness annotation.
      zTailDist         :: (forall m. RVarT m t)

    , -- |A copy of the uniform RVar generator for the base type,
      -- so that @Distribution Uniform t@ is not needed when sampling
      -- from a Ziggurat (makes it a bit more self-contained).
      zUniform          :: !(forall m. t -> t -> RVarT m t)

    , -- |The (one-sided antitone) PDF, not necessarily normalized
      zFunc             :: !(t -> t)

    , -- |A flag indicating whether the distribution should be
      -- mirrored about the origin (the ziggurat algorithm in
      -- its native form only samples from one-sided distributions.
      -- By mirroring, we can extend it to symmetric distributions
      -- such as the normal distribution)
      zMirror           :: !Bool
    }

-- |Sample from the distribution encoded in a 'Ziggurat' data structure.
--
-- Each attempt draws a bin index @i@ and a fraction @U@ from 'zGetIU'
-- and lets @X = U * x[i]@:
--
--  * if @abs U < zTable_y_ratios ! i@, @X@ is accepted immediately;
--
--  * otherwise, if @i == 0@, the result comes from 'zTailDist'
--    (negated when 'zMirror' is set and @U < 0@);
--
--  * otherwise a height is drawn uniformly between @y[i+1]@ and @y[i]@
--    and @X@ is accepted if that height lies below @zFunc (abs X)@;
--    if not, the whole procedure restarts.
{-# INLINE runZiggurat #-}
{-# SPECIALIZE runZiggurat :: Ziggurat UV.Vector Float  -> RVarT m Float #-}
{-# SPECIALIZE runZiggurat :: Ziggurat UV.Vector Double -> RVarT m Double #-}
{-# SPECIALIZE runZiggurat :: Ziggurat  V.Vector Float  -> RVarT m Float #-}
{-# SPECIALIZE runZiggurat :: Ziggurat  V.Vector Double -> RVarT m Double #-}
runZiggurat :: (Num a, Ord a, Vector v a) =>
               Ziggurat v a -> RVarT m a
runZiggurat !Ziggurat{..} = go
    where
        {-# NOINLINE go #-}
        go = do
            -- Select a bin (I) and a uniform value (U) from -1 to 1
            -- (or 0 to 1 if not mirroring the distribution).
            -- Let X be U scaled to the size of the selected bin.
            (!i, !u) <- zGetIU

            -- if the uniform value U falls in the area "clearly inside" the
            -- bin, accept X immediately.
            -- Otherwise, depending on the bin selected, use either the
            -- tail distribution or an accept/reject test.
            if abs u < zTable_y_ratios ! i
                then return $! (u * zTable_xs ! i)
                else if i == 0
                    then sampleTail u
                    else sampleGreyArea i $! (u * zTable_xs ! i)

        -- when the sample falls in the "grey area" (the area between
        -- the Y values of the selected bin and the bin after that one),
        -- use an accept/reject method based on the target PDF.
        {-# INLINE sampleGreyArea #-}
        sampleGreyArea i x = do
            !v <- zUniform (zTable_ys ! (i + 1)) (zTable_ys ! i)
            if v < zFunc (abs x)
                then return $! x
                else go

        -- if the selected bin is the "infinite" one, call it quits and
        -- defer to the tail distribution (mirroring if needed to ensure
        -- the result has the sign already selected by zGetIU)
        {-# INLINE sampleTail #-}
        sampleTail x
            | zMirror && x < 0  = fmap negate zTailDist
            | otherwise         = zTailDist


-- |Build the tables to implement the \"ziggurat algorithm\" devised by
-- Marsaglia & Tsang, using explicitly supplied R and V values (see
-- 'mkZiggurat' for a variant that searches for them with 'findBin0').
--
-- Arguments:
--
--  * flag indicating whether to mirror the distribution
--
--  * the (one-sided antitone) PDF, not necessarily normalized
--
--  * the inverse of the PDF
--
--  * the number of bins
--
--  * R, the x value of the first bin
--
--  * V, the volume of each bin
--
--  * an RVar providing the 'zGetIU' random tuple
--
--  * an RVar sampling from the tail (the region where x > R)
--
-- The X table comes from these parameters, the Y table is the PDF
-- applied to each X, and the quick-accept ratios are @x[i+1] / x[i]@.
-- The stored uniform generator is 'uniformT'.
{-# INLINE mkZiggurat_ #-}
{-# SPECIALIZE mkZiggurat_ :: Bool -> (Float  ->  Float) -> (Float  ->  Float) -> Int -> Float  -> Float  -> (forall m. RVarT m (Int,  Float)) -> (forall m. RVarT m Float ) -> Ziggurat UV.Vector Float #-}
{-# SPECIALIZE mkZiggurat_ :: Bool -> (Double -> Double) -> (Double -> Double) -> Int -> Double -> Double -> (forall m. RVarT m (Int, Double)) -> (forall m. RVarT m Double) -> Ziggurat UV.Vector Double #-}
{-# SPECIALIZE mkZiggurat_ :: Bool -> (Float  ->  Float) -> (Float  ->  Float) -> Int -> Float  -> Float  -> (forall m. RVarT m (Int,  Float)) -> (forall m. RVarT m Float ) -> Ziggurat V.Vector Float #-}
{-# SPECIALIZE mkZiggurat_ :: Bool -> (Double -> Double) -> (Double -> Double) -> Int -> Double -> Double -> (forall m. RVarT m (Int, Double)) -> (forall m. RVarT m Double) -> Ziggurat V.Vector Double #-}
mkZiggurat_ :: (RealFloat t, Vector v t,
               Distribution Uniform t) =>
              Bool
              -> (t -> t)
              -> (t -> t)
              -> Int
              -> t
              -> t
              -> (forall m. RVarT m (Int, t))
              -> (forall m. RVarT m t)
              -> Ziggurat v t
mkZiggurat_ m f fInv c r v getIU tailDist = Ziggurat
    { zTable_xs         = xs
    , zTable_y_ratios   = precomputeRatios xs
    , zTable_ys         = Vec.map f xs
    , zGetIU            = getIU
    , zUniform          = uniformT
    , zFunc             = f
    , zTailDist         = tailDist
    , zMirror           = m
    }
    where
        -- shared by the X, Y and ratio tables
        xs = zigguratTable f fInv c r v

-- |Build the tables to implement the \"ziggurat algorithm\" devised by
-- Marsaglia & Tsang, attempting to automatically compute the R and V
-- values (via 'findBin0').
--
-- Arguments:
--
--  * flag indicating whether to mirror the distribution
--
--  * the (one-sided antitone) PDF, not necessarily normalized
--
--  * the inverse of the PDF
--
--  * the integral of the PDF (definite, from 0)
--
--  * the estimated volume under the PDF (from 0 to +infinity)
--
--  * the number of bins
--
--  * an RVar providing the 'zGetIU' random tuple
--
--  * the tail distribution as a function of the selected R value
--
-- These are the same as for 'mkZigguratRec', plus the tail argument.
mkZiggurat :: (RealFloat t, Vector v t,
               Distribution Uniform t) =>
              Bool
              -> (t -> t)
              -> (t -> t)
              -> (t -> t)
              -> t
              -> Int
              -> (forall m. RVarT m (Int, t))
              -> (forall m. t -> RVarT m t)
              -> Ziggurat v t
mkZiggurat m f fInv fInt fVol c getIU tailDist =
    mkZiggurat_ m f fInv c r v getIU (tailDist r)
    where
        -- the tail is parameterised by the R that findBin0 selects
        (r, v) = findBin0 c f fInv fInt fVol

-- |Build a lazy recursive ziggurat.  Uses a lazily-constructed ziggurat
-- as its tail distribution (with another as its tail, ad nauseam).
--
-- Arguments:
--
--  * flag indicating whether to mirror the distribution
--
--  * the (one-sided antitone) PDF, not necessarily normalized
--
--  * the inverse of the PDF
--
--  * the integral of the PDF (definite, from 0)
--
--  * the estimated volume under the PDF (from 0 to +infinity)
--
--  * the chunk size (number of bins in each layer).  64 seems to
--    perform well in practice.
--
--  * an RVar providing the 'zGetIU' random tuple
--
mkZigguratRec ::
  (RealFloat t, Vector v t,
   Distribution Uniform t) =>
  Bool
  -> (t -> t)
  -> (t -> t)
  -> (t -> t)
  -> t
  -> Int
  -> (forall m. RVarT m (Int, t))
  -> Ziggurat v t
mkZigguratRec m f fInv fInt fVol c getIU = z
    where
        -- A fixed-point combinator specialised to rank-2 tail samplers;
        -- 'Data.Function.fix' cannot be used because its argument type
        -- would have to be impredicative.
        fix :: ((forall m. a -> RVarT m a) -> (forall m. a -> RVarT m a)) -> (forall m. a -> RVarT m a)
        fix g = g (fix g)

        -- 'z' is passed to mkTail only so that each nested ziggurat gets
        -- the same vector type (see the 'asTypeOf' in mkTail).
        z = mkZiggurat m f fInv fInt fVol c getIU (fix (mkTail m f fInv fInt fVol c getIU z))

-- |Tail sampler used by 'mkZigguratRec'.
--
-- Given the cut-off @r@ chosen for the enclosing ziggurat, this builds a
-- new ziggurat over the PDF shifted left by @r@ (@f' x = f (x + r)@, with
-- its inverse, integral and volume adjusted to match) and samples from it.
-- It then moves the sample back out beyond @r@: positive samples become
-- @x + r@ and negative ones @x - r@.
--
-- The @typeRep@ argument is used only with 'asTypeOf' to fix the vector
-- type of the inner ziggurat.  @nextTail@ is the tail sampler handed to
-- that inner ziggurat.
mkTail ::
    (RealFloat a, Vector v a, Distribution Uniform a) =>
    Bool
    -> (a -> a) -> (a -> a) -> (a -> a)
    -> a
    -> Int
    -> (forall m. RVarT m (Int, a))
    -> Ziggurat v a
    -> (forall m. a -> RVarT m a)
    -> (forall m. a -> RVarT m a)
mkTail m f fInv fInt fVol c getIU typeRep nextTail r = do
    x <- rvarT (mkZiggurat m f' fInv' fInt' fVol' c getIU nextTail `asTypeOf` typeRep)
    -- Shift away from the origin by r.  This used to be
    -- @x + r * signum x@, which sent an inner sample of exactly 0 (or
    -- -0.0) to 0, i.e. outside the tail region.  It now goes to r.
    -- The result is unchanged for every other x, including NaN.
    return (if x < 0 then x - r else x + r)
    where
        fIntR = fInt r

        -- the PDF shifted left by r; negative arguments are clamped to f r
        f' x
            | x < 0     = f r
            | otherwise = f (x + r)

        fInv' = subtract r . fInv

        -- integral of the shifted PDF from 0 (i.e. of f from r)
        fInt' x
            | x < 0     = 0
            | otherwise = fInt (x + r) - fIntR

        -- remaining volume beyond r
        fVol' = fVol - fIntR


-- |The X table of a ziggurat (see 'zigguratXs') as a vector; the excess
-- volume is discarded.
zigguratTable :: (Fractional a, Vector v a, Ord a) =>
                 (a -> a) -> (a -> a) -> Int -> a -> a -> v a
zigguratTable f fInv c r v = case zigguratXs f fInv c r v of
    (xs, _excess) -> Vec.fromList xs

-- |Only the excess volume reported by 'zigguratXs' for the given
-- parameters; 'findBin0' searches for an R that keeps this non-negative.
zigguratExcess :: (Fractional a, Ord a) => (a -> a) -> (a -> a) -> Int -> a -> a -> a
zigguratExcess f fInv c r v = snd (zigguratXs f fInv c r v)

-- |Compute the X coordinates of the bin edges of a ziggurat with @c@ bins,
-- first-bin edge @r@ (R) and per-bin volume @v@ (V), together with the
-- \"excess\" volume left over at the top.
--
-- For @c >= 2@ the list has @c+1@ entries:
--
--  * index 0 is the \"magic\" value V/f(R) used for the tail bin,
--
--  * index 1 is R,
--
--  * indices 2 to @c-1@ stack bins of volume V: each one is
--    @fInv (f x + V/x)@ of its predecessor @x@, or -1 once the predecessor
--    is no longer positive,
--
--  * index @c@ is 0.
--
-- Degenerate sizes give @[V/f(R), R]@ for @c == 1@, @[V/f(R)]@ for
-- @c == 0@ and @[]@ for negative @c@.
--
-- The excess is @x[c-1] * (f 0 - f x[c-1]) - V@.  It is only defined for
-- @c >= 1@; forcing it otherwise raises an indexing error.
zigguratXs :: (Fractional a, Ord a) => (a -> a) -> (a -> a) -> Int -> a -> a -> ([a], a)
zigguratXs f fInv c r v = (xs, excess)
    where
        x0 = v / f r

        -- Each edge depends only on its predecessor.  Generating them with
        -- 'iterate' is O(c); the previous version re-indexed the list with
        -- '!!' for every edge, which is O(c^2) and this function is called
        -- repeatedly by findBin0.
        step x_i
            | x_i <= 0  = -1
            | otherwise = fInv (f x_i + v / x_i)

        xs
            | c < 0     = []
            | c == 0    = [x0]
            | c == 1    = [x0, r]
            | otherwise = x0 : Prelude.take (c - 1) (Prelude.iterate step r) Prelude.++ [0]

        excess =
            let xTop = xs !! (c - 1)
             in xTop * (f 0 - f xTop) - v


-- |For an X table of length n, the n-1 ratios @x[i+1] / x[i]@.
-- 'runZiggurat' uses these as its quick-accept thresholds
-- (stored in 'zTable_y_ratios').
precomputeRatios :: (Vector v a, Fractional a) => v a -> v a
precomputeRatios xs = Vec.generate (n - 1) $ \i -> xs ! (i + 1) / xs ! i
    where
        n = Vec.length xs

-- |I suspect this isn't completely right, but it works well so far.
-- Search the distribution for an appropriate R and V.
--
-- Arguments:
--
--  * Number of bins
--
--  * target function (one-sided antitone PDF, not necessarily normalized)
--
--  * function inverse
--
--  * function definite integral (from 0 to _)
--
--  * estimate of total volume under function (integral from 0 to infinity)
--
-- Result: (R,V)
--
-- For a candidate R, V is @R * f R + fVol - fInt R@: the base rectangle
-- plus the volume beyond R.  An initial R is found with 'findMin' from
-- the condition @V <= fVol / c@.  'findMinFrom' then searches from that
-- guess for an R whose 'zigguratExcess' is non-negative and not NaN.
-- NOTE(review): the exact search semantics live in
-- "Data.Random.Internal.Find".
findBin0 :: (RealFloat b) =>
    Int -> (b -> b) -> (b -> b) -> (b -> b) -> b -> (b, b)
findBin0 cInt f fInv fInt fVol = (rMin, v rMin)
    where
        c = fromIntegral cInt

        -- volume of bin 0 (base rectangle + tail) for a given R
        v r = r * f r + fVol - fInt r

        -- initial R guess:
        r0 = findMin (\r -> v r <= fVol / c)

        -- find a better R:
        rMin = findMinFrom r0 1 $ \r ->
            let e = exc r
             in e >= 0 && not (isNaN e)

        exc x = zigguratExcess f fInv cInt x (v x)

-- |Sampling a 'Ziggurat' runs 'runZiggurat'.  It is polymorphic in the
-- base monad, so 'rvarT' is defined directly rather than by lifting
-- 'rvar'.
instance (Num t, Ord t, Vector v t) => Distribution (Ziggurat v) t where
    rvar  = runZiggurat
    rvarT = runZiggurat